from torch.nn.modules.loss import _Loss
import torch
import numpy as np
[docs]
class TopologicalLoss(_Loss):
    def __init__(self, p: int = 2):
        """
        Initialize the TopologicalLoss class.
        Parameters
        ----------
        p : int, optional
            Order of norm used for distance computation, by default 2
        """
        super(TopologicalLoss, self).__init__()
        self.p = p
        self.topological_signature_distance = TopologicalSignatureDistance()
        self.latent_norm = torch.nn.Parameter(data=torch.ones(1), requires_grad=True)
[docs]
    def forward(self, x, x_encoded):
        x_distances = self._compute_distance_matrix(x, p=self.p)
        if len(x.size()) == 4:
            # If the input is an image (has 4 dimensions), normalize using theoretical maximum
            _, ch, b, w = x.size()
            # Compute the maximum distance we could get in the data space (this
            # is only valid for images wich are normalized between -1 and 1)
            max_distance = (2**2 * ch * b * w) ** 0.5
        else:
            # Else just take the max distance we got in the data
            max_distance = x_distances.max()
        x_distances = x_distances / max_distance
        # Latent distances
        x_encoded_distances = self._compute_distance_matrix(x_encoded, p=self.p)
        x_encoded_distances = x_encoded_distances / self.latent_norm
        # Compute the topological signature distance
        topological_error, _ = self.topological_signature_distance(
            x_distances, x_encoded_distances
        )
        # Normalize the topological error according to batch size
        topological_error = topological_error / x.size(0)
        return topological_error 
[docs]
    @staticmethod
    def _compute_distance_matrix(x, p=2):
        x_flat = x.view(x.size(0), -1)
        distances = torch.norm(x_flat[:, None] - x_flat, dim=2, p=p)
        return distances 
 
# Borrowed from https://github.com/BorgwardtLab/topological-autoencoders/blob/master/src/models/approx_based.py
[docs]
class TopologicalSignatureDistance(torch.nn.Module):
    """Topological signature."""
    def __init__(self, sort_selected=False, use_cycles=False, match_edges=None):
        """Topological signature computation.
        Args:
            p: Order of norm used for distance computation
            use_cycles: Flag to indicate whether cycles should be used
                or not.
        """
        super().__init__()
        self.use_cycles = use_cycles
        self.match_edges = match_edges
        # if use_cycles:
        #     use_aleph = True
        # else:
        #     if not sort_selected and match_edges is None:
        #         use_aleph = True
        #     else:
        #         use_aleph = False
        # if use_aleph:
        #     print('Using aleph to compute signatures')
        ##self.signature_calculator = AlephPersistenHomologyCalculation(
        ##    compute_cycles=use_cycles, sort_selected=sort_selected)
        # else:
        print("Using python to compute signatures")
        self.signature_calculator = PersistentHomologyCalculation()
[docs]
    def _get_pairings(self, distances):
        pairs_0, pairs_1 = self.signature_calculator(distances.detach().cpu().numpy())
        return pairs_0, pairs_1 
[docs]
    def _select_distances_from_pairs(self, distance_matrix, pairs):
        # Split 0th order and 1st order features (edges and cycles)
        pairs_0, pairs_1 = pairs
        selected_distances = distance_matrix[(pairs_0[:, 0], pairs_0[:, 1])]
        if self.use_cycles:
            edges_1 = distance_matrix[(pairs_1[:, 0], pairs_1[:, 1])]
            edges_2 = distance_matrix[(pairs_1[:, 2], pairs_1[:, 3])]
            edge_differences = edges_2 - edges_1
            selected_distances = torch.cat((selected_distances, edge_differences))
        return selected_distances 
[docs]
    @staticmethod
    def sig_error(signature1, signature2):
        """Compute distance between two topological signatures."""
        return ((signature1 - signature2) ** 2).sum(dim=-1) 
[docs]
    @staticmethod
    def _count_matching_pairs(pairs1, pairs2):
        def to_set(array):
            return set(tuple(elements) for elements in array)
        return float(len(to_set(pairs1).intersection(to_set(pairs2)))) 
[docs]
    @staticmethod
    def _get_nonzero_cycles(pairs):
        all_indices_equal = np.sum(pairs[:, [0]] == pairs[:, 1:], axis=-1) == 3
        return np.sum(np.logical_not(all_indices_equal)) 
    # pylint: disable=W0221
[docs]
    def forward(self, distances1, distances2):
        """Return topological distance of two pairwise distance matrices.
        Args:
            distances1: Distance matrix in space 1
            distances2: Distance matrix in space 2
        Returns:
            distance, dict(additional outputs)
        """
        pairs1 = self._get_pairings(distances1)
        pairs2 = self._get_pairings(distances2)
        distance_components = {
            "metrics.matched_pairs_0D": self._count_matching_pairs(pairs1[0], pairs2[0])
        }
        # Also count matched cycles if present
        if self.use_cycles:
            distance_components["metrics.matched_pairs_1D"] = (
                self._count_matching_pairs(pairs1[1], pairs2[1])
            )
            nonzero_cycles_1 = self._get_nonzero_cycles(pairs1[1])
            nonzero_cycles_2 = self._get_nonzero_cycles(pairs2[1])
            distance_components["metrics.non_zero_cycles_1"] = nonzero_cycles_1
            distance_components["metrics.non_zero_cycles_2"] = nonzero_cycles_2
        if self.match_edges is None:
            sig1 = self._select_distances_from_pairs(distances1, pairs1)
            sig2 = self._select_distances_from_pairs(distances2, pairs2)
            distance = self.sig_error(sig1, sig2)
        elif self.match_edges == "symmetric":
            sig1 = self._select_distances_from_pairs(distances1, pairs1)
            sig2 = self._select_distances_from_pairs(distances2, pairs2)
            # Selected pairs of 1 on distances of 2 and vice versa
            sig1_2 = self._select_distances_from_pairs(distances2, pairs1)
            sig2_1 = self._select_distances_from_pairs(distances1, pairs2)
            distance1_2 = self.sig_error(sig1, sig1_2)
            distance2_1 = self.sig_error(sig2, sig2_1)
            distance_components["metrics.distance1-2"] = distance1_2
            distance_components["metrics.distance2-1"] = distance2_1
            distance = distance1_2 + distance2_1
        elif self.match_edges == "random":
            # Create random selection in oder to verify if what we are seeing
            # is the topological constraint or an implicit latent space prior
            # for compactness
            n_instances = len(pairs1[0])
            pairs1 = torch.cat(
                [
                    torch.randperm(n_instances)[:, None],
                    torch.randperm(n_instances)[:, None],
                ],
                dim=1,
            )
            pairs2 = torch.cat(
                [
                    torch.randperm(n_instances)[:, None],
                    torch.randperm(n_instances)[:, None],
                ],
                dim=1,
            )
            sig1_1 = self._select_distances_from_pairs(distances1, (pairs1, None))
            sig1_2 = self._select_distances_from_pairs(distances2, (pairs1, None))
            sig2_2 = self._select_distances_from_pairs(distances2, (pairs2, None))
            sig2_1 = self._select_distances_from_pairs(distances1, (pairs2, None))
            distance1_2 = self.sig_error(sig1_1, sig1_2)
            distance2_1 = self.sig_error(sig2_1, sig2_2)
            distance_components["metrics.distance1-2"] = distance1_2
            distance_components["metrics.distance2-1"] = distance2_1
            distance = distance1_2 + distance2_1
        return distance, distance_components 
 
# Borrowed from https://github.com/BorgwardtLab/topological-autoencoders/blob/master/src/topology.py
"""
Methods for calculating lower-dimensional persistent homology.
"""
[docs]
class UnionFind:
    """
    An implementation of a Union--Find class. The class performs path
    compression by default. It uses integers for storing one disjoint
    set, assuming that vertices are zero-indexed.
    """
    def __init__(self, n_vertices):
        """
        Initializes an empty Union--Find data structure for a given
        number of vertices.
        """
        self._parent = np.arange(n_vertices, dtype=int)
[docs]
    def find(self, u):
        """
        Finds and returns the parent of u with respect to the hierarchy.
        """
        if self._parent[u] == u:
            return u
        else:
            # Perform path collapse operation
            self._parent[u] = self.find(self._parent[u])
            return self._parent[u] 
[docs]
    def merge(self, u, v):
        """
        Merges vertex u into the component of vertex v. Note the
        asymmetry of this operation.
        """
        if u != v:
            self._parent[self.find(u)] = self.find(v) 
[docs]
    def roots(self):
        """
        Generator expression for returning roots, i.e. components that
        are their own parents.
        """
        for vertex, parent in enumerate(self._parent):
            if vertex == parent:
                yield vertex 
 
[docs]
class PersistentHomologyCalculation:
[docs]
    def __call__(self, matrix):
        n_vertices = matrix.shape[0]
        uf = UnionFind(n_vertices)
        triu_indices = np.triu_indices_from(matrix)
        edge_weights = matrix[triu_indices]
        edge_indices = np.argsort(edge_weights, kind="stable")
        # 1st dimension: 'source' vertex index of edge
        # 2nd dimension: 'target' vertex index of edge
        persistence_pairs = []
        for edge_index, edge_weight in zip(edge_indices, edge_weights[edge_indices]):
            u = triu_indices[0][edge_index]
            v = triu_indices[1][edge_index]
            younger_component = uf.find(u)
            older_component = uf.find(v)
            # Not an edge of the MST, so skip it
            if younger_component == older_component:
                continue
            elif younger_component > older_component:
                uf.merge(v, u)
            else:
                uf.merge(u, v)
            if u < v:
                persistence_pairs.append((u, v))
            else:
                persistence_pairs.append((v, u))
        # Return empty cycles component
        return np.array(persistence_pairs), np.array([]) 
 
[docs]
class AlephPersistenHomologyCalculation:
    def __init__(self, compute_cycles, sort_selected):
        """Calculate persistent homology using aleph.
        Args:
            compute_cycles: Whether to compute cycles
            sort_selected: Whether to sort the selected pairs using the
                distance matrix (such that they are in the order of the
                filteration)
        """
        self.compute_cycles = compute_cycles
        self.sort_selected = sort_selected
[docs]
    def __call__(self, distance_matrix):
        """Do PH calculation.
        Args:
            distance_matrix: numpy array of distances
        Returns: tuple(edge_featues, cycle_features)
        """
        import aleph
        if self.compute_cycles:
            pairs_0, pairs_1 = aleph.vietoris_rips_from_matrix_2d(distance_matrix)
            pairs_0 = np.array(pairs_0)
            pairs_1 = np.array(pairs_1)
        else:
            pairs_0 = aleph.vietoris_rips_from_matrix_1d(distance_matrix)
            pairs_0 = np.array(pairs_0)
            # Return empty cycles component
            pairs_1 = np.array([])
        if self.sort_selected:
            selected_distances = distance_matrix[(pairs_0[:, 0], pairs_0[:, 1])]
            indices_0 = np.argsort(selected_distances)
            pairs_0 = pairs_0[indices_0]
            if self.compute_cycles:
                cycle_creation_times = distance_matrix[(pairs_1[:, 0], pairs_1[:, 1])]
                cycle_destruction_times = distance_matrix[
                    (pairs_1[:, 2], pairs_1[:, 3])
                ]
                cycle_persistences = cycle_destruction_times - cycle_creation_times
                # First sort by destruction time and then by persistence of the
                # create cycles in order to recover original filtration order.
                indices_1 = np.lexsort((cycle_destruction_times, cycle_persistences))
                pairs_1 = pairs_1[indices_1]
        return pairs_0, pairs_1