Source code for deep_bottleneck.callbacks.metrics_logger

from tensorflow.python.keras.callbacks import Callback


class MetricsLogger(Callback):
    """Callback that writes per-epoch loss and accuracy into a file dump.

    For every epoch selected by ``do_save_func`` the metrics are stored under
    the groups ``<epoch>/accuracy`` and ``<epoch>/loss``. Each group receives
    a ``training`` entry and, when validation metrics are present in the
    logs, a ``test`` entry.

    Args:
        file_dump: Storage object offering ``require_group(name)`` and item
            access by group name (presumably an h5py file or group --
            TODO confirm against the caller).
        do_save_func: Callable that receives the epoch index and returns a
            truthy value when the metrics of that epoch should be stored.
    """

    def __init__(self, file_dump, do_save_func):
        super().__init__()
        self._file_dump = file_dump
        self._do_save_func = do_save_func

    def on_epoch_end(self, epoch, logs=None):
        """Store the metrics of ``epoch`` if ``do_save_func`` selects it.

        Args:
            epoch: Index of the epoch that just finished; used as the
                top-level group name.
            logs: Mapping with the keys ``'acc'`` and ``'loss'`` and,
                optionally, ``'val_acc'`` and ``'val_loss'``.

        Raises:
            KeyError: If ``'acc'`` or ``'loss'`` is missing from ``logs``.
        """
        if not self._do_save_func(epoch):
            return

        accuracy_key = f'{epoch}/accuracy'
        loss_key = f'{epoch}/loss'
        self._file_dump.require_group(accuracy_key)
        self._file_dump.require_group(loss_key)

        self._file_dump[accuracy_key]['training'] = float(logs['acc'])
        # The loss group must hold the loss value, so read 'loss' here
        # rather than 'acc'.
        self._file_dump[loss_key]['training'] = float(logs['loss'])

        # Validation entries are optional: a missing key only produces a
        # notice instead of aborting training.
        try:
            self._file_dump[accuracy_key]['test'] = float(logs['val_acc'])
            self._file_dump[loss_key]['test'] = float(logs['val_loss'])
        except KeyError:
            print('Validation not enabled. Validation metrics cannot be logged')
class SacredMetricsLogger(Callback):
    """Callback that reports per-epoch loss and accuracy via ``run.log_scalar``.

    Args:
        run: Object exposing ``log_scalar(name, value, step=...)``
            (presumably a sacred ``Run`` -- TODO confirm against the caller).
    """

    def __init__(self, run):
        super().__init__()
        self._run = run

    def on_epoch_end(self, epoch, logs=None):
        """Report training metrics and, when available, validation metrics.

        Args:
            epoch: Index of the epoch that just finished; passed as ``step``.
            logs: Mapping with the keys ``'loss'`` and ``'acc'`` and,
                optionally, ``'val_loss'`` and ``'val_acc'``.

        Raises:
            KeyError: If ``'loss'`` or ``'acc'`` is missing from ``logs``.
        """
        training_metrics = (
            ("training.loss", 'loss'),
            ("training.accuracy", 'acc'),
        )
        validation_metrics = (
            ("test.loss", 'val_loss'),
            ("test.accuracy", 'val_acc'),
        )

        for metric_name, log_key in training_metrics:
            self._run.log_scalar(metric_name, float(logs[log_key]), step=epoch)

        # Validation metrics are optional; the first missing key stops the
        # loop and only prints a notice.
        try:
            for metric_name, log_key in validation_metrics:
                self._run.log_scalar(metric_name, float(logs[log_key]), step=epoch)
        except KeyError:
            print('Validation not enabled. Validation metrics cannot be logged')