"""Clustering analysis of model embeddings (minerva.analysis.clustering_analysis)."""
from minerva.analysis.model_analysis import _ModelAnalysis
import lightning as L
from typing import Optional, Tuple
import torch
from minerva.data.data_module_tools import get_full_data_split
from minerva.utils.typing import PathLike
from sklearn.metrics import silhouette_score, davies_bouldin_score
import numpy as np
[docs]
class ClusteringAnalysis(_ModelAnalysis):
    """
    Perform a clustering analysis on the embeddings generated by some model.

    The embeddings produced by ``model.backbone`` for one data split are scored
    against that split's labels with two metrics implemented in scikit-learn:
    the Silhouette score and the Davies-Bouldin score. The results are returned
    in a dictionary.
    """

    def __init__(self, data_split: str = "predict", batch_size: Optional[int] = None):
        """
        Initialize the analysis with the specified data split.

        Parameters
        ----------
        data_split : str, optional
            The data split to use for the analysis, by default "predict".
            This specifies which part of the dataset to analyze. Can be one of:
            ["train", "validation", "test", "predict"].
        batch_size : int, optional
            Number of samples fed to the backbone per forward pass. By default
            (``None``) the whole split is embedded in a single pass; set a
            positive value to bound peak memory on large splits.

        Raises
        ------
        ValueError
            If ``batch_size`` is given and is not a positive integer.
        """
        super().__init__()
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.data_split = data_split
        self.batch_size = batch_size

    def compute(self, model: L.LightningModule, data: L.LightningDataModule) -> dict:
        """
        Compute the clustering analysis metrics.

        Parameters
        ----------
        model : L.LightningModule
            The trained model from which to extract embeddings. Its
            ``backbone`` attribute is used as the embedding function.
        data : L.LightningDataModule
            The data module containing the dataset to analyze.

        Returns
        -------
        dict
            A dictionary with the keys "silhouette-score" and
            "davies-bouldin-score", each mapped to a float.

        Raises
        ------
        ValueError
            If the selected data split is empty (sklearn may also raise
            ValueError, e.g. when fewer than two distinct labels exist).
        """
        # Establish the data to be used
        samples, labels = get_full_data_split(data, self.data_split)
        if len(samples) == 0:
            raise ValueError(
                f"Data split '{self.data_split}' is empty; cannot compute clustering metrics."
            )
        inputs = torch.tensor(np.array(samples), dtype=torch.float32)
        labels = np.array(labels)

        embeddings = self._embed(model, inputs)
        # sklearn expects an (n_samples, n_features) matrix; collapse any
        # trailing dimensions (e.g. feature maps) into one feature axis.
        # This is a no-op for backbones that already return 2-D output.
        embeddings = embeddings.reshape(embeddings.shape[0], -1)

        silhouette = silhouette_score(embeddings, labels)
        davies_bouldin = davies_bouldin_score(embeddings, labels)

        # Saving the results
        return {
            "silhouette-score": float(silhouette),
            "davies-bouldin-score": float(davies_bouldin),
        }

    def _embed(self, model: L.LightningModule, inputs: torch.Tensor) -> np.ndarray:
        """Run ``model.backbone`` on ``inputs`` in eval/no-grad mode; return a CPU numpy array."""
        # Remember every submodule's mode so the caller's model is left exactly
        # as it was (e.g. a deliberately frozen backbone stays in eval mode).
        previous_modes = {module: module.training for module in model.modules()}
        # LightningModule exposes `.device`; fall back to CPU for plain modules.
        device = getattr(model, "device", torch.device("cpu"))
        chunk = self.batch_size or len(inputs)
        model.eval()
        try:
            # no_grad: inference only, so don't build an autograd graph over
            # the whole split.
            with torch.no_grad():
                outputs = [
                    model.backbone(inputs[start:start + chunk].to(device)).cpu()
                    for start in range(0, len(inputs), chunk)
                ]
        finally:
            for module, was_training in previous_modes.items():
                module.training = was_training
        return torch.cat(outputs).numpy()