# Source code for das.tracking
"""Utilities for logging training runs.
We currently have integrations for `tensorboard <https://www.tensorflow.org/tensorboard>`_ and `wandb.ai <https://wandb.ai>`_.
Tensorboard is integrated with tensorflow; to use wandb you'll have to install the wandb API:
``pip install wandb`` or ``conda install wandb -c conda-forge``.
"""
import logging
import os
from typing import Optional, Dict
# wandb is an optional dependency: record whether it (and its keras callback)
# can be imported so the Wandb class can degrade to a no-op instead of crashing.
try:
    import wandb
    from wandb.keras import WandbCallback

    HAS_WANDB = True
except ImportError:
    logging.debug("Could not import wandb libraries.")
    HAS_WANDB = False
class Wandb:
    """Utility class for logging to wandb.ai during training.

    If wandb could not be imported, or if login/initialisation fails, the
    instance is still created but ``self.run`` is ``None`` and the logging
    methods (``finish``, ``callback``, ``log_test_results``) become no-ops.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        entity: Optional[str] = None,
        params: Optional[Dict] = None,
        infer_from_env: bool = False,
    ):
        """
        Args:
            project (Optional[str], optional): Project to log to. If None, it is read
                from the WANDB_PROJECT environment variable. Defaults to None.
            api_token (Optional[str], optional): API token. If None, it is read from
                the WANDB_API_TOKEN environment variable. Defaults to None.
            entity (Optional[str], optional): Entity (user/team name). Defaults to None.
            params (Optional[Dict], optional): Dict to log to `config`. Defaults to None.
            infer_from_env (bool, optional): Currently unused - the environment
                variables above are consulted whenever `project` or `api_token`
                is None, regardless of this flag. Defaults to False.
        """
        # Initialise all attributes up front so the other methods never hit an
        # AttributeError, even if wandb is unavailable or setup fails below.
        self.run = None
        self.project = project
        self.entity = entity

        if not HAS_WANDB:
            logging.error("Could not import wandb in das.tracking.")
            return

        try:
            if project is None:
                project = os.environ["WANDB_PROJECT"]
            if api_token is None:
                api_token = os.environ["WANDB_API_TOKEN"]
            wandb.login(key=api_token)
            self.project = project
            self.run = wandb.init(
                project=self.project,
                entity=self.entity,
                settings=wandb.Settings(start_method="fork"),
            )
            if params is not None:
                wandb.config.update(params)
        except Exception:
            # Best effort: tracking problems (missing env vars, login/network
            # errors) should not abort training, so log and disable tracking.
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            self.run = None
            logging.exception("Wandb stuff went wrong.")

    def reinit(self, params: Optional[Dict] = None):
        """Start a new wandb run using the stored project and entity.

        Args:
            params (Optional[Dict], optional): Dict to log to `config` of the new run.
                Defaults to None.
        """
        if not HAS_WANDB:
            # Without this guard the call below would fail with a NameError.
            logging.error("Could not import wandb in das.tracking.")
            return
        self.run = wandb.init(reinit=True, project=self.project, entity=self.entity)
        if params is not None:
            wandb.config.update(params)

    def finish(self):
        """Finish the current wandb run; does nothing if there is no run."""
        if self.run is not None:
            self.run.finish()

    def callback(self, save_model=False) -> "Optional[WandbCallback]":
        """Get callback for auto-logging from tensorflow/keras.

        Args:
            save_model (bool, optional): Passed through to WandbCallback. Defaults to False.

        Returns:
            Optional[WandbCallback]: A new callback, or None if there is no active run.
        """
        # CHECK: Is callback re-usable across reinits?
        if self.run is None:
            return None
        return WandbCallback(save_model=save_model)

    def log_test_results(self, report: Dict):
        """Log final classification result from test data.

        Does nothing if there is no active run.

        Args:
            report (Dict): dictionary containing the classification report;
                it is merged into the run summary.
        """
        if self.run is not None:
            wandb.summary.update(report)