Source code for deep_bottleneck.mi_estimator.base
from typing import *
import pandas as pd
import numpy as np
from deep_bottleneck import utils
from deep_bottleneck.mi_estimator import kde
class MutualInformationEstimator:
    """Base class for mutual information estimators on layer activations.

    ``compute_mi`` walks over every epoch and layer of an activation dump and
    fills a table with two quantities per (epoch, layer) pair:

    * ``MI_XM`` = ``H(M) - H(M|X)``, computed from ``_estimate_entropy`` and
      ``_estimate_conditional_entropy``.
    * ``MI_YM`` = ``H(M) - H(M|Y)``, where ``H(M|Y)`` is the label-weighted
      average of the per-class entropies.

    Both differences are multiplied by ``nats2bits``, i.e. the results are in
    bits provided the entropy estimators return nats.

    Subclasses must implement ``_estimate_entropy`` and
    ``_estimate_conditional_entropy``; the base versions raise
    ``NotImplementedError``.
    """

    nats2bits = 1.0 / np.log(2)
    """Nats to bits conversion factor."""

    def __init__(self, discretization_range, architecture, n_classes):
        """
        Args:
            discretization_range: Part of the constructor interface; this base
                class neither uses nor stores it.
            architecture: Sized collection whose length is the number of
                layers before the output layer.
            n_classes: Number of classes; class ids are taken to be
                ``0 .. n_classes - 1``.
        """
        self.architecture = architecture
        self.n_classes = n_classes

    def compute_mi(self, data, file_dump) -> pd.DataFrame:
        """Estimate the mutual information for every epoch and layer.

        Args:
            data: Object with a ``labels`` attribute (class id per sample) and
                a ``one_hot_labels`` attribute (averaged over axis 0 to get
                the proportion of each class).
            file_dump: Mapping keyed by the epoch number as a string. Each
                entry must provide ``['activations'][str(layer_index)]`` for
                every layer index in ``range(len(architecture) + 1)``.

        Returns:
            A DataFrame indexed by ``(epoch, layer)`` with the columns
            ``MI_XM`` and ``MI_YM``. Epochs are in ascending numeric order.
        """
        print(f'*** Start running {self.__class__.__name__}. ***')

        # Convert to an array so that the comparison below is element-wise
        # even when the labels arrive as a plain list or similar container;
        # otherwise ``labels == target_class`` collapses to a single bool.
        labels = np.asarray(data.labels)
        one_hot_labels = data.one_hot_labels

        # Proportion of instances that have a certain label.
        label_weights = np.mean(one_hot_labels, axis=0)
        label_masks = {
            target_class: labels == target_class
            for target_class in range(self.n_classes)
        }

        n_layers = len(self.architecture) + 1  # + 1 for output layer

        # Keys are strings; sort them numerically rather than lexically.
        epoch_numbers = sorted(int(value) for value in file_dump.keys())

        measures = self._init_dataframe(epoch_numbers=epoch_numbers, n_layers=n_layers)

        for epoch in epoch_numbers:
            print(f'Estimating mutual information for epoch {epoch}.')
            activations_per_layer = file_dump[str(epoch)]['activations']

            for layer_index in range(n_layers):
                layer_activations = activations_per_layer[str(layer_index)]
                mi_with_input, mi_with_label = self._compute_mi_per_epoch_and_layer(
                    layer_activations, label_weights, label_masks)

                measures.loc[(epoch, layer_index), 'MI_XM'] = mi_with_input
                measures.loc[(epoch, layer_index), 'MI_YM'] = mi_with_label

        return measures

    def _init_dataframe(self, epoch_numbers, n_layers):
        """Create the empty result table.

        Args:
            epoch_numbers: Epoch numbers forming the first index level.
            n_layers: Number of layers; the second index level is
                ``0 .. n_layers - 1``.

        Returns:
            A DataFrame without values, indexed by the product of epochs and
            layers (levels named ``epoch`` and ``layer``) with the columns
            ``MI_XM`` and ``MI_YM``.
        """
        info_measures = ['MI_XM', 'MI_YM']
        index_base_keys = [epoch_numbers, list(range(n_layers))]
        index = pd.MultiIndex.from_product(index_base_keys, names=['epoch', 'layer'])
        return pd.DataFrame(index=index, columns=info_measures)

    def _compute_mi_per_epoch_and_layer(self, activations, label_weights, label_masks) -> Tuple[float, float]:
        """Compute both mutual information values for one layer at one epoch.

        Args:
            activations: Activations of the layer; converted with
                ``np.asarray`` so that boolean masks can index it.
            label_weights: Proportion of each class, indexable by class id.
            label_masks: Mapping from class id to a boolean mask selecting the
                samples of that class.

        Returns:
            The tuple ``(mi_with_input, mi_with_label)``, each scaled by
            ``nats2bits``.
        """
        activations = np.asarray(activations)

        H_of_M = self._estimate_entropy(activations)
        H_of_M_given_X = self._estimate_conditional_entropy(activations)
        H_of_M_given_Y = self._compute_H_of_M_given_Y(activations, label_weights, label_masks)

        mi_with_input = self.nats2bits * (H_of_M - H_of_M_given_X)
        mi_with_label = self.nats2bits * (H_of_M - H_of_M_given_Y)

        return mi_with_input, mi_with_label

    def _compute_H_of_M_given_Y(self, activations, label_weights, label_masks):
        """Compute ``H(M|Y)`` as the label-weighted sum of per-class entropies.

        Args:
            activations: Array of activations, indexable by the boolean masks.
            label_weights: Proportion of each class, indexable by class id.
            label_masks: Mapping from class id to a boolean mask selecting the
                samples of that class.

        Returns:
            The sum over all classes of ``label_weights[y] * H(M | Y=y)``.
        """
        H_of_M_given_Y = 0.0
        for label, mask in label_masks.items():
            weight = label_weights[label]
            if weight == 0:
                # A class that never occurs contributes nothing. Skipping it
                # also avoids estimating entropy on an empty slice, where a
                # NaN result would turn the whole sum into NaN (0 * NaN).
                continue
            H_of_M_for_specific_y = self._estimate_entropy(activations[mask])
            H_of_M_given_Y += weight * H_of_M_for_specific_y
        return H_of_M_given_Y

    def _estimate_entropy(self, data: np.ndarray) -> float:
        """Estimate the entropy of the given data; implemented by subclasses.

        Args:
            data: The data to estimate entropy for.

        Returns:
            The estimated entropy.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _estimate_conditional_entropy(self, data: np.ndarray) -> float:
        """Estimate the entropy of the data given the input; implemented by subclasses.

        Args:
            data: The data to estimate the conditional entropy for.

        Returns:
            The estimated conditional entropy; used as ``H(M|X)`` when
            computing ``MI_XM``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError