Source code for deep_bottleneck.plotter.informationplane_movie
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
from deep_bottleneck.plotter.base import BasePlotter
class InformationPlaneMoviePlotter(BasePlotter):
    """Plot the infoplane movie for several runs of the same network.

    Every frame of the movie corresponds to one epoch.  The upper panel is
    a scatter of I(X;M) against I(Y;M) with one point per row of that
    epoch's measures, coloured by layer; the lower panel shows the training
    and test accuracy accumulated up to that epoch.
    """

    plotname = 'infoplane_movie'
    file_ext = 'mp4'

    # Frame rate and resolution used when encoding the movie.
    fps = 7
    dpi = 600

    # Filled in by ``get_specifications`` before anything is drawn.
    num_layers = None
    total_number_of_epochs = None
    epoch_indexes = None
    layers_colors = None

    def __init__(self, run, dataset):
        """Store the experiment run and the dataset identifier.

        Args:
            run: Object the finished movie is attached to through its
                ``add_artifact`` method.
            dataset: Dataset identifier; ``'datasets.mnist'`` and
                ``'datasets.fashion_mnist'`` select wider axis limits.
        """
        self.dataset = dataset
        self.run = run

    def generate(self, measures_summary, suffix):
        """Render the movie and register it as an artifact of the run.

        Args:
            measures_summary: Mapping providing ``'measures_all_runs'`` and
                ``'activations_summary'`` (see ``plot``).
            suffix: Passed to ``make_filename`` and appended, after an
                underscore, to the artifact name.  Falsy values (``''`` or
                ``None``) add nothing to the artifact name.
        """
        self.filename = self.make_filename(suffix)
        self.plot(measures_summary)
        # A falsy suffix must not leak into the name; ``None`` would
        # otherwise be rendered as the literal text "None".
        suffix_part = f'_{suffix}' if suffix else ''
        artifact_name = f'{self.plotname}{suffix_part}'
        self.run.add_artifact(self.filename, name=artifact_name)

    def setup_infoplane_subplot(self, ax_infoplane):
        """Configure the information plane axes and create its artists.

        Args:
            ax_infoplane: Axes that will show the information plane.

        Returns:
            tuple: The (initially empty) scatter collection and the text
            artist that displays the current epoch.
        """
        if self.dataset in ('datasets.mnist', 'datasets.fashion_mnist'):
            ax_infoplane.set(xlim=[0, 14], ylim=[0, 3.5])
        else:
            ax_infoplane.set(xlim=[0, 12], ylim=[0, 1])
        ax_infoplane.set(xlabel='I(X;M)', ylabel='I(Y;M)')

        scatter = ax_infoplane.scatter([], [], s=20, edgecolor='none')
        # Configure the colormap on this artist only, instead of changing
        # the process-wide default with ``plt.set_cmap``.
        scatter.set_cmap('hsv')
        # ``layers_colors`` spans [0, 1]; pinning the limits keeps a layer's
        # colour independent of which layers appear in the first frame.
        scatter.set_clim(0, 1)

        # Axes coordinates keep the label just above the plot whatever the
        # y-limits are; data coordinates only did so for ylim == [0, 1].
        text = ax_infoplane.text(0, 1.05, "", fontsize=12,
                                 transform=ax_infoplane.transAxes)
        return scatter, text

    def fill_infoplane_subplot(self, ax_infoplane, mi_epoch):
        """Update the scatter with the mutual information of one epoch.

        Args:
            ax_infoplane: Despite the name, this is the scatter collection
                returned by ``setup_infoplane_subplot``, not an Axes.
            mi_epoch: Measures of a single epoch with the columns
                ``'MI_XM'`` and ``'MI_YM'``.  Its index values are used to
                look up entries of ``self.layers_colors``.

        Returns:
            The same scatter collection, updated in place.
        """
        points = np.column_stack((mi_epoch['MI_XM'], mi_epoch['MI_YM']))
        colors = self.layers_colors[mi_epoch.index]
        ax_infoplane.set_offsets(points)
        ax_infoplane.set_array(colors)
        return ax_infoplane

    def setup_accuracy_subplot(self, ax_accuracy):
        """Configure the accuracy axes and create the two empty curves.

        Requires ``get_specifications`` to have been called, because the
        x-axis is derived from the recorded epochs.

        Args:
            ax_accuracy: Axes that will show the accuracy curves.

        Returns:
            tuple: The training accuracy line and the test accuracy line.
        """
        acc_line, = ax_accuracy.plot([], [], 'b', label="training accuracy")
        val_acc_line, = ax_accuracy.plot([], [], 'g', label="test accuracy")

        ax_accuracy.set_ylim(0, 1)
        ax_accuracy.set_xlim(0, self.total_number_of_epochs)

        # The curves are drawn against the position of an epoch in
        # ``epoch_indexes``, so the ticks are always relabelled with the
        # real epoch numbers.  At most about 20 ticks are shown.
        step = max(1, self.total_number_of_epochs // 20)
        tick_positions = np.arange(0, self.total_number_of_epochs, step)
        ax_accuracy.set_xticks(tick_positions)
        ax_accuracy.set_xticklabels(self.epoch_indexes[tick_positions],
                                    rotation=90)

        ax_accuracy.legend(loc='lower right')
        ax_accuracy.set_xlabel('Epoch')
        ax_accuracy.set_ylabel('Accuracy')
        return acc_line, val_acc_line

    def fill_accuracy_subplot(self, acc_line, val_acc_line, activations_summary,
                              epoch_number, acc, val_acc):
        """Append the accuracies of one epoch and redraw both curves.

        Args:
            acc_line: Line showing the training accuracy.
            val_acc_line: Line showing the test accuracy.
            activations_summary: Mapping in which
                ``f'{epoch_number}/accuracy/'`` holds the entries
                ``'training'`` and ``'test'``.
            epoch_number: Epoch whose accuracies are appended.
            acc: List of training accuracies so far; extended in place.
            val_acc: List of test accuracies so far; extended in place.

        Returns:
            tuple: ``(acc, val_acc, acc_line, val_acc_line)``.
        """
        epoch_accuracies = activations_summary[f'{epoch_number}/accuracy/']
        acc.append(np.asarray(epoch_accuracies['training']))
        val_acc.append(np.asarray(epoch_accuracies['test']))

        # x-values are frame positions, not epoch numbers.
        xs = range(len(acc))
        acc_line.set_data(xs, acc)
        val_acc_line.set_data(xs, val_acc)
        return acc, val_acc, acc_line, val_acc_line

    def get_specifications(self, measures):
        """Derive layer and epoch information from the index of *measures*.

        Args:
            measures: Frame whose index has a level named ``'epoch'`` and
                whose second index level identifies the layer.
        """
        self.num_layers = measures.index.get_level_values(1).nunique()
        # One evenly spaced colormap position per layer.
        self.layers_colors = np.linspace(0, 1, self.num_layers)
        self.epoch_indexes = measures.index.get_level_values('epoch').unique()
        self.total_number_of_epochs = len(self.epoch_indexes)

    def plot(self, measures_summary):
        """Write the movie to ``self.filename``, one frame per epoch.

        Args:
            measures_summary: Mapping with ``'measures_all_runs'`` (the
                measures, grouped here by their outer index level) and
                ``'activations_summary'`` (source of the accuracies).
        """
        os.makedirs('plots/', exist_ok=True)

        measures = measures_summary['measures_all_runs']
        activations_summary = measures_summary['activations_summary']
        self.get_specifications(measures)

        fig, (ax_infoplane, ax_accuracy) = plt.subplots(
            2, 1, figsize=(6, 9), gridspec_kw={'height_ratios': [2, 1]})
        try:
            scatter, text = self.setup_infoplane_subplot(ax_infoplane)
            acc_line, val_acc_line = self.setup_accuracy_subplot(ax_accuracy)

            acc = []
            val_acc = []

            writer = FFMpegWriter(fps=self.fps)
            with writer.saving(fig, self.filename, self.dpi):
                for epoch_number, mi_epoch in measures.groupby(level=0):
                    # Drop the outer index level corresponding to the epoch.
                    # A new frame is built so the group handed out by
                    # ``groupby`` is not modified.
                    mi_epoch = mi_epoch.reset_index(level=0, drop=True)

                    scatter = self.fill_infoplane_subplot(scatter, mi_epoch)
                    text.set_text(f'Epoch: {epoch_number}')

                    acc, val_acc, acc_line, val_acc_line = self.fill_accuracy_subplot(
                        acc_line, val_acc_line, activations_summary,
                        epoch_number, acc, val_acc)

                    writer.grab_frame()
        finally:
            # pyplot keeps every figure alive until it is closed explicitly,
            # so release it even when encoding fails.
            plt.close(fig)