# Source code for minerva.analysis.model_analysis

from pathlib import Path
from typing import Dict, Optional, Union

import lightning as L
from sklearn.manifold import TSNE
import plotly.express as px
import pandas as pd

import torch

from minerva.data.data_module_tools import get_full_data_split
from minerva.utils.typing import PathLike


class _ModelAnalysis:
    """Main interface for model analysis.

    A model analysis is a post-training analysis that can be run on a trained
    model to generate insights about the model's performance. It has a `path`
    attribute that specifies the directory where the analysis results will be
    saved.

    Subclasses must implement the `compute` method, which performs the actual
    analysis. All insights generated by the analysis should be saved in the
    `path` directory: unlike `Metric`, a `_ModelAnalysis` is not meant to hand
    its results back to the caller.

    Inside a pipeline, the path will be automatically set to the
    `pipeline.log_dir` attribute.
    """

    def __init__(self, path: Optional[PathLike] = None):
        """
        Parameters
        ----------
        path : Optional[PathLike], optional
            Directory where the analysis results will be saved. It is stored
            as a `Path`. By default None (not set).
        """
        # Normalise here as well as in the setter so `self.path` has the same
        # type regardless of whether it was given to the constructor or
        # assigned later.
        self._path = None if path is None else Path(path)

    @property
    def path(self) -> Optional[Path]:
        """Directory where the analysis results are saved, or None if unset."""
        return self._path

    @path.setter
    def path(self, path: Optional[PathLike]):
        # Accept None so the path can be explicitly unset again.
        self._path = None if path is None else Path(path)

    def compute(self, model: L.LightningModule, data: L.LightningDataModule):
        """Run the analysis on a trained model, saving results under `path`.

        Parameters
        ----------
        model : L.LightningModule
            The trained model to analyse.
        data : L.LightningDataModule
            The data used to run the analysis.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement the `compute` method."
        )
class TSNEAnalysis(_ModelAnalysis):
    """Perform t-SNE analysis on the embeddings generated by a model.

    The model's `backbone` is applied to the whole ``predict`` split of the
    data module. The flattened embeddings are projected with t-SNE, and a
    scatter plot coloured by label is saved in the `path` directory (as a PNG
    image with the default `output_filename`). Both 2-D and 3-D projections
    are supported.
    """

    def __init__(
        self,
        label_names: Optional[Dict[Union[int, str], str]] = None,
        height: int = 800,
        width: int = 800,
        text_size: int = 12,
        title: Optional[str] = None,
        x_axis_title: str = "x",
        y_axis_title: str = "y",
        legend_title: str = "Label",
        output_filename: PathLike = "tsne.png",
        seed: int = 42,
        n_components: int = 2,
        z_axis_title: str = "z",
    ):
        """Plot a t-SNE plot of the embeddings generated by a model.

        Parameters
        ----------
        label_names : Optional[Dict[Union[int, str], str]], optional
            Labels to use for the plot, instead of the original labels in the
            data (`y`). The keys are the original labels and the values are
            the new labels to use in the plot. Labels missing from the mapping
            are shown as they are. If None, the original labels are used as
            they are. By default None
        height : int, optional
            Height of the figure, by default 800
        width : int, optional
            Width of the figure, by default 800
        text_size : int, optional
            Size of font used in plot, by default 12
        title : str, optional
            Title of graph, by default None
        x_axis_title : str, optional
            Name of x-axis, by default "x"
        y_axis_title : str, optional
            Name of y-axis, by default "y"
        legend_title : str, optional
            Name for legend title, by default "Label"
        output_filename : PathLike, optional
            Name of the output image file. The file will be saved in the
            `path` directory with this name. By default "tsne.png"
        seed : int, optional
            Random seed for t-SNE, by default 42
        n_components : int, optional
            Number of t-SNE components. Must be 2 (2-D scatter plot) or
            3 (3-D scatter plot). By default 2
        z_axis_title : str, optional
            Name of z-axis, only used when `n_components` is 3, by default "z"

        Raises
        ------
        ValueError
            If `n_components` is not 2 or 3.
        """
        super().__init__()
        # Only 2-D and 3-D projections can be plotted; fail at construction
        # instead of deep inside `compute`.
        if n_components not in (2, 3):
            raise ValueError(
                f"n_components must be 2 or 3 to be plotted, got {n_components}."
            )
        self.label_names = label_names
        self.height = height
        self.width = width
        self.text_size = text_size
        self.title = title
        self.output_filename = Path(output_filename)
        self.x_axis_title = x_axis_title
        self.y_axis_title = y_axis_title
        self.z_axis_title = z_axis_title
        self.legend_title = legend_title
        self.seed = seed
        self.n_components = n_components

    def compute(self, model: L.LightningModule, data: L.LightningDataModule):
        """Run the t-SNE analysis and save the resulting plot.

        Parameters
        ----------
        model : L.LightningModule
            Trained model. It must expose a `backbone` attribute that maps a
            batch of inputs to embeddings.
        data : L.LightningDataModule
            Data module whose ``predict`` split is embedded and plotted.

        Returns
        -------
        str
            Absolute path of the saved plot.

        Raises
        ------
        ValueError
            If `path` has not been set.
        """
        if not self.path:
            raise ValueError(
                "Path is not set. Please set the path before running the analysis."
            )

        # NOTE(review): assumes `get_full_data_split` returns (inputs, labels)
        # as array-likes covering the whole split — confirm.
        X, y = get_full_data_split(data, "predict")
        embeddings = self._compute_embeddings(model, X)

        tsne_embeddings = TSNE(
            n_components=self.n_components, random_state=self.seed
        ).fit_transform(embeddings)

        df = self._build_dataframe(tsne_embeddings, y)
        fig = self._make_figure(df)

        # Save the figure, creating the output directory if it does not exist.
        path = (Path(self.path) / self.output_filename).resolve()
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(path)
        print(f"t-SNE plot saved to {path}")
        return str(path)

    @staticmethod
    def _compute_embeddings(model: L.LightningModule, X):
        """Return the flattened backbone embeddings of `X` as a NumPy array."""
        # Remember every submodule's mode so it can be restored exactly;
        # `model.train()` would force *all* submodules back into training
        # mode, including ones the caller deliberately kept in eval mode.
        modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            # Put the inputs where the model lives (LightningModule exposes
            # `device`); fall back to CPU for objects without it.
            device = getattr(model, "device", torch.device("cpu"))
            # `as_tensor` avoids a copy (and a warning) when X is a tensor.
            inputs = torch.as_tensor(X, device=device)
            # No gradients are needed; skip building the autograd graph.
            with torch.no_grad():
                # Call the module rather than `.forward` so hooks still run.
                embeddings = model.backbone(inputs)  # type: ignore
        finally:
            for module, mode in modes:
                module.training = mode
        return embeddings.flatten(start_dim=1).detach().cpu().numpy()

    def _build_dataframe(self, tsne_embeddings, y) -> pd.DataFrame:
        """Return t-SNE coordinates plus a string ``label`` column, sorted by label."""
        columns = ["x", "y", "z"][: self.n_components]
        df = pd.DataFrame(data=tsne_embeddings, columns=columns)
        df["label"] = y

        if self.label_names is not None:
            # Fall back to the original label for values missing from the
            # mapping; a plain `Series.map(dict)` would turn them into NaN.
            names = self.label_names
            df["label"] = df["label"].map(lambda value: str(names.get(value, value)))

        # Sort so legend entries appear in a stable order, then convert to
        # string so plotly uses discrete colours.
        df = df.sort_values(by="label")
        df["label"] = df["label"].astype(str)
        return df

    def _make_figure(self, df: pd.DataFrame):
        """Create the 2-D or 3-D scatter figure for `df`, coloured by label."""
        if self.n_components == 3:
            fig = px.scatter_3d(
                df, x="x", y="y", z="z", color="label", title=self.title
            )
            # 3-D axes are configured through the layout's `scene`.
            axis_titles = dict(
                scene=dict(
                    xaxis_title=self.x_axis_title,
                    yaxis_title=self.y_axis_title,
                    zaxis_title=self.z_axis_title,
                )
            )
        else:
            fig = px.scatter(df, x="x", y="y", color="label", title=self.title)
            axis_titles = dict(
                xaxis_title=self.x_axis_title,
                yaxis_title=self.y_axis_title,
            )

        fig.update_layout(
            height=self.height,
            width=self.width,
            legend_title_text=self.legend_title,
            title=self.title,
            font=dict(size=self.text_size),
            **axis_titles,
        )
        return fig