from torch.nn.modules.loss import _Loss
import torch
import numpy as np
[docs]
class TopologicalLoss(_Loss):
    """Topological loss between a data batch and its encoded representation.

    Pairwise distance matrices are computed for the input batch and for the
    encoded batch, each is normalised, and the two are compared with
    ``TopologicalSignatureDistance``. The result is divided by the batch size.
    """

    def __init__(self, p: int = 2):
        """
        Initialize the TopologicalLoss class.

        Parameters
        ----------
        p : int, optional
            Order of norm used for distance computation, by default 2
        """
        super().__init__()
        self.p = p
        self.topological_signature_distance = TopologicalSignatureDistance()
        # Learnable scale applied to the latent distance matrix.
        self.latent_norm = torch.nn.Parameter(data=torch.ones(1), requires_grad=True)

    def forward(self, x, x_encoded):
        """
        Compute the topological error between ``x`` and ``x_encoded``.

        Parameters
        ----------
        x : torch.Tensor
            Input batch; the first dimension is the batch dimension. A
            4-dimensional tensor is treated as a batch of images.
        x_encoded : torch.Tensor
            Encoded batch; the first dimension is the batch dimension.

        Returns
        -------
        torch.Tensor
            Topological signature distance divided by ``x.size(0)``.
        """
        x_distances = self._compute_distance_matrix(x, p=self.p)
        if x.dim() == 4:
            # If the input is an image (has 4 dimensions), normalize using the
            # theoretical maximum distance in data space. This is only valid
            # for images which are normalized between -1 and 1: every
            # coordinate can then differ by at most 2, so the largest p-norm
            # is (2**p * n_values) ** (1 / p). For p == 2 this is the
            # previously hard-coded sqrt(4 * n_values).
            _, channels, height, width = x.size()
            n_values = channels * height * width
            max_distance = (2 ** self.p * n_values) ** (1 / self.p)
        else:
            # Else just take the max distance we got in the data. Clamp it away
            # from zero so a batch of identical points yields zeros, not NaNs.
            smallest = torch.finfo(x_distances.dtype).tiny
            max_distance = x_distances.max().clamp(min=smallest)
        x_distances = x_distances / max_distance

        # Latent distances
        x_encoded_distances = self._compute_distance_matrix(x_encoded, p=self.p)
        x_encoded_distances = x_encoded_distances / self.latent_norm

        # Compute the topological signature distance
        topological_error, _ = self.topological_signature_distance(
            x_distances, x_encoded_distances)

        # Normalize the topological error according to batch size
        topological_error = topological_error / x.size(0)
        return topological_error

    @staticmethod
    def _compute_distance_matrix(x, p=2):
        """
        Return the matrix of pairwise p-norm distances between batch items.

        Parameters
        ----------
        x : torch.Tensor
            Batch whose first dimension indexes the items; all remaining
            dimensions are flattened into one feature vector per item.
        p : int, optional
            Order of the norm, by default 2

        Returns
        -------
        torch.Tensor
            Square matrix with one row and one column per batch item.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x_flat = x.reshape(x.size(0), -1)
        distances = torch.norm(x_flat[:, None] - x_flat, dim=2, p=p)
        return distances
# Borrowed from https://github.com/BorgwardtLab/topological-autoencoders/blob/master/src/models/approx_based.py
class TopologicalSignatureDistance(torch.nn.Module):
    """Topological signature."""

    # Values of ``match_edges`` that ``forward`` knows how to handle.
    _MATCH_EDGES_MODES = (None, 'symmetric', 'random')

    def __init__(self, sort_selected=False, use_cycles=False,
                 match_edges=None):
        """Topological signature computation.

        Args:
            sort_selected: Accepted for interface compatibility; it is not
                used by the pure-Python calculator selected below.
            use_cycles: Flag to indicate whether cycles should be used
                or not.
            match_edges: ``None``, ``'symmetric'`` or ``'random'``; selects
                how the two signatures are compared in ``forward``.
        """
        super().__init__()
        self.use_cycles = use_cycles
        self.match_edges = match_edges

        # The aleph-based calculator is disabled here; the pure-Python
        # implementation is always used.
        print('Using python to compute signatures')
        self.signature_calculator = PersistentHomologyCalculation()

    def _get_pairings(self, distances):
        """Run the signature calculator on a detached NumPy copy of `distances`.

        Returns:
            tuple(pairs_0, pairs_1) as produced by the calculator.
        """
        pairs_0, pairs_1 = self.signature_calculator(
            distances.detach().cpu().numpy())
        return pairs_0, pairs_1

    def _select_distances_from_pairs(self, distance_matrix, pairs):
        """Gather the entries of `distance_matrix` addressed by `pairs`.

        Args:
            distance_matrix: Pairwise distance matrix (torch tensor).
            pairs: tuple(pairs_0, pairs_1). ``pairs_0`` holds one
                (row, column) index pair per row. ``pairs_1`` holds four
                indices per row and is only read when ``use_cycles`` is set;
                it may be empty or ``None``.

        Returns:
            1-D tensor of the selected edge distances, followed (when cycles
            are used and present) by the per-cycle difference between the
            edge at columns (2, 3) and the edge at columns (0, 1).
        """
        # Split 0th order and 1st order features (edges and cycles)
        pairs_0, pairs_1 = pairs

        if len(pairs_0) == 0:
            # Nothing to select (e.g. a single-point batch). An empty slice
            # keeps the result attached to distance_matrix.
            selected_distances = distance_matrix.flatten()[:0]
        else:
            selected_distances = distance_matrix[
                (pairs_0[:, 0], pairs_0[:, 1])]

        # The pure-Python calculator returns an empty cycle array and the
        # 'random' mode passes None; neither can be indexed column-wise.
        has_cycles = pairs_1 is not None and len(pairs_1) > 0
        if self.use_cycles and has_cycles:
            edges_1 = distance_matrix[(pairs_1[:, 0], pairs_1[:, 1])]
            edges_2 = distance_matrix[(pairs_1[:, 2], pairs_1[:, 3])]
            edge_differences = edges_2 - edges_1

            selected_distances = torch.cat(
                (selected_distances, edge_differences))

        return selected_distances

    @staticmethod
    def sig_error(signature1, signature2):
        """Compute distance between two topological signatures."""
        return ((signature1 - signature2)**2).sum(dim=-1)

    @staticmethod
    def _count_matching_pairs(pairs1, pairs2):
        """Return, as a float, how many index tuples occur in both arrays."""
        set1 = {tuple(elements) for elements in pairs1}
        set2 = {tuple(elements) for elements in pairs2}
        return float(len(set1.intersection(set2)))

    @staticmethod
    def _get_nonzero_cycles(pairs):
        """Count rows of `pairs` whose four indices are not all identical.

        An empty `pairs` array contains no cycles and yields 0.
        """
        pairs = np.asarray(pairs)
        if pairs.size == 0:
            return 0
        all_indices_equal = np.sum(pairs[:, [0]] == pairs[:, 1:], axis=-1) == 3
        return np.sum(np.logical_not(all_indices_equal))

    # pylint: disable=W0221
    def forward(self, distances1, distances2):
        """Return topological distance of two pairwise distance matrices.

        Args:
            distances1: Distance matrix in space 1
            distances2: Distance matrix in space 2

        Returns:
            distance, dict(additional outputs)

        Raises:
            ValueError: If ``match_edges`` is not one of ``None``,
                ``'symmetric'`` or ``'random'``.
        """
        if self.match_edges not in self._MATCH_EDGES_MODES:
            # Fail early with a clear message; otherwise no branch below
            # would assign `distance`.
            raise ValueError(
                "match_edges must be None, 'symmetric' or 'random', "
                "got {!r}".format(self.match_edges))

        pairs1 = self._get_pairings(distances1)
        pairs2 = self._get_pairings(distances2)

        distance_components = {
            'metrics.matched_pairs_0D': self._count_matching_pairs(
                pairs1[0], pairs2[0])
        }
        # Also count matched cycles if present
        if self.use_cycles:
            distance_components['metrics.matched_pairs_1D'] = \
                self._count_matching_pairs(pairs1[1], pairs2[1])
            nonzero_cycles_1 = self._get_nonzero_cycles(pairs1[1])
            nonzero_cycles_2 = self._get_nonzero_cycles(pairs2[1])
            distance_components['metrics.non_zero_cycles_1'] = nonzero_cycles_1
            distance_components['metrics.non_zero_cycles_2'] = nonzero_cycles_2

        if self.match_edges is None:
            sig1 = self._select_distances_from_pairs(distances1, pairs1)
            sig2 = self._select_distances_from_pairs(distances2, pairs2)
            distance = self.sig_error(sig1, sig2)

        elif self.match_edges == 'symmetric':
            sig1 = self._select_distances_from_pairs(distances1, pairs1)
            sig2 = self._select_distances_from_pairs(distances2, pairs2)
            # Selected pairs of 1 on distances of 2 and vice versa
            sig1_2 = self._select_distances_from_pairs(distances2, pairs1)
            sig2_1 = self._select_distances_from_pairs(distances1, pairs2)

            distance1_2 = self.sig_error(sig1, sig1_2)
            distance2_1 = self.sig_error(sig2, sig2_1)

            distance_components['metrics.distance1-2'] = distance1_2
            distance_components['metrics.distance2-1'] = distance2_1

            distance = distance1_2 + distance2_1

        else:  # self.match_edges == 'random'
            # Create random selection in order to verify if what we are seeing
            # is the topological constraint or an implicit latent space prior
            # for compactness
            n_instances = len(pairs1[0])
            pairs1 = torch.cat([
                torch.randperm(n_instances)[:, None],
                torch.randperm(n_instances)[:, None]
            ], dim=1)
            pairs2 = torch.cat([
                torch.randperm(n_instances)[:, None],
                torch.randperm(n_instances)[:, None]
            ], dim=1)

            sig1_1 = self._select_distances_from_pairs(
                distances1, (pairs1, None))
            sig1_2 = self._select_distances_from_pairs(
                distances2, (pairs1, None))

            sig2_2 = self._select_distances_from_pairs(
                distances2, (pairs2, None))
            sig2_1 = self._select_distances_from_pairs(
                distances1, (pairs2, None))

            distance1_2 = self.sig_error(sig1_1, sig1_2)
            distance2_1 = self.sig_error(sig2_1, sig2_2)
            distance_components['metrics.distance1-2'] = distance1_2
            distance_components['metrics.distance2-1'] = distance2_1

            distance = distance1_2 + distance2_1

        return distance, distance_components
# Borrowed from https://github.com/BorgwardtLab/topological-autoencoders/blob/master/src/topology.py
'''
Methods for calculating lower-dimensional persistent homology.
'''
class UnionFind:
    '''
    An implementation of a Union--Find class. The class performs path
    compression by default. It uses integers for storing one disjoint
    set, assuming that vertices are zero-indexed.
    '''

    def __init__(self, n_vertices):
        '''
        Initializes an empty Union--Find data structure for a given
        number of vertices. Every vertex starts as its own parent.
        '''
        self._parent = np.arange(n_vertices, dtype=int)

    def find(self, u):
        '''
        Finds and returns the parent of u with respect to the hierarchy.

        The lookup is iterative, so arbitrarily long parent chains cannot
        exhaust the interpreter's recursion limit.
        '''
        # First pass: walk up until a vertex that is its own parent.
        root = u
        while self._parent[root] != root:
            root = self._parent[root]

        # Second pass: path collapse -- point every vertex on the walked
        # path directly at the root.
        while self._parent[u] != root:
            next_vertex = self._parent[u]
            self._parent[u] = root
            u = next_vertex

        return root

    def merge(self, u, v):
        '''
        Merges vertex u into the component of vertex v. Note the
        asymmetry of this operation.
        '''
        if u != v:
            self._parent[self.find(u)] = self.find(v)

    def roots(self):
        '''
        Generator expression for returning roots, i.e. components that
        are their own parents.
        '''
        for vertex, parent in enumerate(self._parent):
            if vertex == parent:
                yield vertex
[docs]
class PersistentHomologyCalculation:
    '''
    Pure-Python calculation of 0-dimensional persistence pairs from a
    pairwise distance matrix, using a union-find structure over edges
    processed in order of increasing weight.
    '''

    def __call__(self, matrix):
        '''
        Computes the edges that merge two previously separate components.

        Args:
            matrix: square 2-D NumPy array of pairwise distances; only the
                part above the diagonal is read.

        Returns:
            tuple(pairs_0, pairs_1): `pairs_0` is an integer array with one
            (u, v) row, u < v, per merging edge, in order of increasing edge
            weight; it has shape (0, 2) when there are no such edges.
            `pairs_1` is always an empty array (no cycles are computed).
        '''
        n_vertices = matrix.shape[0]
        uf = UnionFind(n_vertices)

        # k=1 skips the diagonal: a self-edge always joins a vertex with its
        # own component and can never be selected.
        # 1st array: 'source' vertex index of edge
        # 2nd array: 'target' vertex index of edge
        sources, targets = np.triu_indices_from(matrix, k=1)
        edge_order = np.argsort(matrix[sources, targets], kind='stable')

        persistence_pairs = []
        for edge_index in edge_order:
            u = sources[edge_index]
            v = targets[edge_index]

            younger_component = uf.find(u)
            older_component = uf.find(v)

            # Not an edge of the MST, so skip it
            if younger_component == older_component:
                continue
            if younger_component > older_component:
                uf.merge(v, u)
            else:
                uf.merge(u, v)

            # Indices come from the strict upper triangle, so u < v holds.
            persistence_pairs.append((u, v))

            # A spanning tree has n_vertices - 1 edges; every remaining edge
            # would be skipped above.
            if len(persistence_pairs) == n_vertices - 1:
                break

        # reshape keeps the result two-dimensional even when no pair was
        # found, so callers can always index its columns.
        pairs_0 = np.array(persistence_pairs, dtype=np.intp).reshape(-1, 2)

        # Return empty cycles component
        return pairs_0, np.array([])
[docs]
class AlephPersistenHomologyCalculation():
    """Persistent homology calculation delegated to the `aleph` package."""

    def __init__(self, compute_cycles, sort_selected):
        """Calculate persistent homology using aleph.

        Args:
            compute_cycles: Whether to compute cycles
            sort_selected: Whether to sort the selected pairs using the
                distance matrix (such that they are in the order of the
                filtration)
        """
        self.compute_cycles = compute_cycles
        self.sort_selected = sort_selected

    def __call__(self, distance_matrix):
        """Do PH calculation.

        Args:
            distance_matrix: numpy array of distances

        Returns: tuple(edge_features, cycle_features)
        """
        # Imported lazily so the module loads without aleph installed.
        import aleph

        if self.compute_cycles:
            raw_edges, raw_cycles = aleph.vietoris_rips_from_matrix_2d(
                distance_matrix)
            edge_pairs = np.array(raw_edges)
            cycle_pairs = np.array(raw_cycles)
        else:
            raw_edges = aleph.vietoris_rips_from_matrix_1d(
                distance_matrix)
            edge_pairs = np.array(raw_edges)
            # Return empty cycles component
            cycle_pairs = np.array([])

        if not self.sort_selected:
            return edge_pairs, cycle_pairs

        # Order the edges by their distance.
        edge_distances = \
            distance_matrix[(edge_pairs[:, 0], edge_pairs[:, 1])]
        edge_pairs = edge_pairs[np.argsort(edge_distances)]

        if self.compute_cycles:
            creation = \
                distance_matrix[(cycle_pairs[:, 0], cycle_pairs[:, 1])]
            destruction = \
                distance_matrix[(cycle_pairs[:, 2], cycle_pairs[:, 3])]
            persistence = destruction - creation
            # lexsort uses its last key as the primary one: cycles are
            # ordered by persistence, with ties broken by destruction time,
            # in order to recover the original filtration order.
            cycle_order = np.lexsort((destruction, persistence))
            cycle_pairs = cycle_pairs[cycle_order]

        return edge_pairs, cycle_pairs