Source code for minerva.analysis.clustering_analysis

from minerva.analysis.model_analysis import _ModelAnalysis
import lightning as L
from typing import Optional, Tuple
import torch
from minerva.data.data_module_tools import get_full_data_split
from minerva.utils.typing import PathLike
from sklearn.metrics import silhouette_score, davies_bouldin_score
import numpy as np


class ClusteringAnalysis(_ModelAnalysis):
    """
    Perform a clustering analysis on the embeddings generated by some model,
    using the Silhouette score and Davies-Bouldin score, functions implemented
    in sklearn. The results are returned in a dictionary.
    """

    def __init__(self, data_split: str = "predict"):
        """
        Initialize the analysis with the specified data split.

        Parameters
        ----------
        data_split : str, optional
            The data split to use for the analysis, by default "predict".
            This specifies which part of the dataset to analyze. Can be one
            of: ["train", "validation", "test", "predict"].
        """
        super().__init__()
        self.data_split = data_split

    def compute(self, model: L.LightningModule, data: L.LightningDataModule):
        """
        Compute the clustering analysis metrics.

        The full requested split is loaded at once, passed through
        ``model.backbone`` in a single forward pass, and the resulting
        embeddings are scored against the split's labels, which are used as
        the cluster assignments.

        Parameters
        ----------
        model : L.LightningModule
            The trained model from which to extract embeddings. It must
            expose a ``backbone`` callable. The model is switched to
            evaluation mode and is left in that mode.
        data : L.LightningDataModule
            The data module containing the dataset to analyze.

        Returns
        -------
        dict
            A dictionary with the keys ``"silhouette-score"`` and
            ``"davies-bouldin-score"``, both plain Python floats.

        Raises
        ------
        ValueError
            Propagated from sklearn when the labels are not valid for the
            metrics (for instance, when fewer than two distinct labels are
            present).
        """
        # Establish the data to be used
        samples, labels = get_full_data_split(data, self.data_split)
        labels = np.array(labels)
        inputs = torch.tensor(np.array(samples), dtype=torch.float32)

        # Run the forward pass on whatever device the model lives on.
        # Objects without a ``device`` attribute are used as-is.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = inputs.to(device)

        model.eval()
        # Gradients are never needed here; disabling autograd avoids building
        # a graph over the entire split.
        with torch.no_grad():
            embeddings = model.backbone(inputs)

        # Convert once; ``.cpu()`` is required before ``.numpy()`` whenever
        # the embeddings were produced on an accelerator.
        features = embeddings.detach().cpu().numpy()
        # sklearn expects (n_samples, n_features); flatten any extra
        # per-sample dimensions produced by the backbone.
        if features.ndim != 2:
            features = features.reshape(features.shape[0], -1)

        silhouette = silhouette_score(features, labels)
        davies_bouldin = davies_bouldin_score(features, labels)

        # Saving the results
        result = {
            "silhouette-score": float(silhouette),
            "davies-bouldin-score": float(davies_bouldin),
        }
        return result