Source code for deep_bottleneck.callbacks.activity_logger

from tensorflow import keras
from tensorflow.python.keras import backend as K

from deep_bottleneck import utils


[docs]class ActivityLogger(keras.callbacks.Callback): def __init__(self, examples, file_dump, do_save_func=None, *args, **kwargs): super().__init__(*args, **kwargs) self.examples = examples self.file_dump = file_dump # do_save_func(epoch) should return True if we should save on that epoch. self.do_save_func = do_save_func # Functions return activity of each layer. self.layerfuncs = []
[docs] def on_train_begin(self, logs=None): for layer in self.model.layers: if utils.is_dense_like(layer): self.layerfuncs.append(K.function(self.model.inputs, [layer.output]))
[docs] def on_epoch_end(self, epoch, logs=None): if self.do_save_func is not None and self.do_save_func(epoch): self._log_activity_for_epoch(epoch)
def _log_activity_for_epoch(self, epoch): self.file_dump.require_group(str(epoch)) self.file_dump[str(epoch)].create_group('activations') for i, current_layer_func in enumerate(self.layerfuncs): self.file_dump[str(epoch)]['activations'].create_dataset( name=str(i), data=current_layer_func([self.examples])[0] )