from pathlib import Path
from typing import Any, Dict, Optional, Union, List
import lightning as L
from sklearn.manifold import TSNE
import plotly.express as px
import pandas as pd
from minerva.models.nets.base import SimpleSupervisedModel
import torch
from minerva.data.data_module_tools import get_full_data_split
from minerva.utils.typing import PathLike
import plotly.graph_objects as go
# Module-level mutable set shared by everything in this module. The original
# note says it exists "to control plotly.js inclusion"; judging by its name it
# presumably records directories that already received a t-SNE plot.
# NOTE(review): TSNEAnalysis.compute never reads or updates it -- confirm it is
# still used elsewhere before relying on it.
_plot_tsne_written_dirs: set = set()
class _ModelAnalysis:
    """Main interface for model analysis.

    A model analysis is a post-training analysis that can be run on a trained
    model to generate insights about the model's performance. It has a `path`
    attribute that specifies the directory where the analysis results will be
    saved. The `compute` method should be implemented by subclasses to perform
    the actual analysis. All insights generated by the analysis should be
    saved in the `path` directory.

    Note that, differently from `Metric`, `_ModelAnalysis` does not return any
    value. Instead, the results of the analysis should be saved in the `path`
    directory. All subclasses of `_ModelAnalysis` should implement the
    `compute` method. Inside a pipeline the path will be automatically set to
    the `pipeline.log_dir` attribute.
    """

    def __init__(self, path: Optional[PathLike] = None):
        """Create the analysis, optionally with its output directory.

        Parameters
        ----------
        path : Optional[PathLike], optional
            Directory where results will be saved. It is stored as a
            `pathlib.Path` (or None when not given). By default None
        """
        # Route through the property setter so `path` is normalized the same
        # way whether it is passed here or assigned later.
        self.path = path

    @property
    def path(self) -> Optional[Path]:
        """Directory where analysis results are saved, or None if unset."""
        return self._path

    @path.setter
    def path(self, path: Optional[PathLike]):
        # `Path(None)` raises TypeError, so keep None as "not set".
        self._path = None if path is None else Path(path)

    def compute(self, model: L.LightningModule, data: L.LightningDataModule):
        """Run the analysis and save its results under `path`.

        Subclasses must override this method.

        Parameters
        ----------
        model : lightning.LightningModule
            The trained model to analyze.
        data : lightning.LightningDataModule
            The data used by the analysis.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute()"
        )
class TSNEAnalysis(_ModelAnalysis):
    """Perform t-SNE analysis on the embeddings generated by a model.

    The embeddings produced by ``model.backbone`` for the ``"predict"`` split
    are projected to two dimensions with t-SNE and drawn as a scatter plot
    with one trace per label. The plot is saved in the `path` directory as a
    PNG file and, when ``write_html`` is True, also as an interactive HTML
    file next to it.
    """

    def __init__(
        self,
        label_names: Optional[Dict[Union[int, str], str]] = None,
        height: int = 1000,
        width: int = 1000,
        text_size: int = 30,
        title: Optional[str] = None,
        x_axis_title: str = "x",
        y_axis_title: str = "y",
        legend_title: str = "Label",
        output_filename: PathLike = "tsne.png",
        seed: int = 42,
        n_components: int = 2,
        marker_symbols: Optional[Dict[str, str]] = None,
        colors: Optional[Union[Dict[str, str], List[str]]] = None,
        write_html: bool = True,
    ):
        """Configure a t-SNE plot of the embeddings generated by a model.

        Parameters
        ----------
        label_names : Optional[Dict[Union[int, str], str]], optional
            Labels to use for the plot, instead of the original labels in the
            data (`y`). The keys are the original labels and the values are
            the new labels to use in the plot. Labels without an entry keep
            their original value. If None, the original labels are used as
            they are. By default None
        height : int, optional
            Height of the figure, by default 1000
        width : int, optional
            Width of the figure, by default 1000
        text_size : int, optional
            Base font size used in the plot; the title uses 1.5x and the
            legend 1.2x this size. By default 30
        title : str, optional
            Title of the graph. If None or empty, "t-SNE" is used; pass a
            blank string such as " " to show no visible title.
            By default None
        x_axis_title : str, optional
            Name of x-axis, by default "x"
        y_axis_title : str, optional
            Name of y-axis, by default "y"
        legend_title : str, optional
            Name for legend title, by default "Label"
        output_filename : PathLike, optional
            Name of the PNG file, relative to the `path` directory. Missing
            parent directories are created. The HTML file (if written) uses
            the same name with an ``.html`` suffix. By default "tsne.png"
        seed : int, optional
            Random seed for t-SNE, by default 42
        n_components : int, optional
            Number of t-SNE components. Only 2 is currently supported.
            By default 2
        marker_symbols : Optional[Dict[str, str]], optional
            Mapping from plotted label (after `label_names` is applied) to a
            plotly marker symbol. Keys are compared as strings. Labels
            without an entry use "circle". By default None (all "circle").
        colors : Optional[Union[Dict[str, str], List[str]]], optional
            Either a mapping from plotted label to a color (hex code or name),
            compared as strings, where labels without an entry fall back to
            plotly's default qualitative palette; or a sequence of colors
            cycled over the labels in sorted order. If None, plotly's default
            qualitative palette is used. By default None
        write_html : bool, optional
            If True, also save the plot as an interactive HTML file.
            By default True

        Raises
        ------
        ValueError
            If `n_components` is not 2.

        Examples
        --------
        Create and run a t-SNE analysis for a fine-tuned model on the
        MotionSense dataset ("stair up and down" has no marker symbol, so it
        is drawn with the default "circle"):

        >>> analysis = TSNEAnalysis(
        ...     height=1000,
        ...     width=1000,
        ...     legend_title="Activity",
        ...     title=" ",
        ...     output_filename="tsne_analysis.png",
        ...     label_names={
        ...         0: "sit",
        ...         1: "stand",
        ...         2: "walk",
        ...         3: "stair up",
        ...         4: "stair down",
        ...         5: "run",
        ...         6: "stair up and down",
        ...     },
        ...     marker_symbols={
        ...         "sit": "x-open",
        ...         "stand": "cross-open",
        ...         "stair up": "triangle-up-open",
        ...         "stair down": "triangle-down-open",
        ...         "walk": "circle-open",
        ...         "run": "star-open",
        ...     },
        ...     text_size=30,
        ...     x_axis_title="1st Component",
        ...     y_axis_title="2nd Component",
        ... )
        >>> analysis.set_path("logs/tsne")
        >>> analysis.compute(model, data_module)  # doctest: +SKIP
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if n_components != 2:
            raise ValueError(
                "For now, n_components must be set to 2 for t-SNE analysis"
            )
        super().__init__()
        self.label_names = label_names
        self.height = height
        self.width = width
        self.text_size = text_size
        self.title = title
        self.output_filename = Path(output_filename)
        self.x_axis_title = x_axis_title
        self.y_axis_title = y_axis_title
        self.legend_title = legend_title
        self.seed = seed
        self.n_components = n_components
        self.marker_symbols = marker_symbols
        self.colors = colors
        self.write_html = write_html

    def set_path(self, path: PathLike):
        """Set the output path for saving the plots."""
        self.path = Path(path)

    def compute(self, model: L.LightningModule, data: L.LightningDataModule):
        """Run the t-SNE analysis on the provided model and dataset.

        Extracts embeddings from ``model.backbone`` for the ``"predict"``
        split of `data`, reduces them with t-SNE, and plots one trace per
        label, colored and marked according to `colors` / `marker_symbols`.

        The resulting plot is saved as:
        - A static PNG file (`output_filename`) in the directory `self.path`.
        - An interactive HTML file next to it (if `write_html=True`).

        Note that `model` is switched to eval mode and left there.

        Parameters
        ----------
        model : lightning.LightningModule
            A trained model with a `.backbone` attribute that produces
            embeddings from the input data.
        data : lightning.LightningDataModule
            The DataModule providing the dataset to project. Data is
            extracted using the `"predict"` split.

        Returns
        -------
        Dict[str, Optional[str]]
            ``{"png_path": ..., "html_path": ...}`` with resolved file paths;
            ``"html_path"`` is None when `write_html` is False.

        Raises
        ------
        ValueError
            If `path` has not been set.
        """
        if not self.path:
            raise ValueError(
                "Path is not set. Please set the path before running the analysis."
            )
        embeddings, y = self._compute_embeddings(model, data)
        tsne_embeddings = TSNE(
            n_components=self.n_components, random_state=self.seed
        ).fit_transform(embeddings)
        df = self._build_dataframe(tsne_embeddings, y)
        fig = self._build_figure(df)
        return self._save_figure(fig)

    def _compute_embeddings(
        self, model: L.LightningModule, data: L.LightningDataModule
    ):
        """Return flattened backbone embeddings (NumPy) and the raw labels."""
        model.eval()
        X, y = get_full_data_split(data, "predict")
        # Put inputs on the model's device so GPU-resident models work too;
        # fall back to CPU for models that do not expose `device`.
        device = getattr(model, "device", torch.device("cpu"))
        # `as_tensor` avoids the copy + warning `torch.tensor` gives for tensors.
        X = torch.as_tensor(X).to(device)
        # Inference only: do not build an autograd graph for the whole split.
        with torch.no_grad():
            embeddings = model.backbone(X).flatten(start_dim=1)
        return embeddings.cpu().numpy(), y

    def _build_dataframe(self, tsne_embeddings: Any, y: Any) -> pd.DataFrame:
        """Assemble x/y coordinates and string labels into a DataFrame."""
        df = pd.DataFrame(tsne_embeddings, columns=["x", "y"])
        df["label"] = y
        if self.label_names:
            mapping = self.label_names
            # Labels missing from `label_names` keep their original value
            # instead of collapsing into a single "nan" group.
            df["label"] = df["label"].map(lambda value: mapping.get(value, value))
        df["label"] = df["label"].astype(str)
        return df

    def _resolve_colors(self, labels: List[str]) -> Dict[str, str]:
        """Map each plotted label to a color, honoring `self.colors`."""
        from collections.abc import Mapping

        default_palette = px.colors.qualitative.Plotly
        if isinstance(self.colors, Mapping):
            # A dict is looked up by label (it cannot be indexed by position).
            overrides = {str(key): value for key, value in self.colors.items()}
            return {
                label: overrides.get(
                    label, default_palette[i % len(default_palette)]
                )
                for i, label in enumerate(labels)
            }
        palette = self.colors or default_palette
        return {label: palette[i % len(palette)] for i, label in enumerate(labels)}

    def _resolve_markers(self, labels: List[str]) -> Dict[str, str]:
        """Map each plotted label to a marker symbol, defaulting to "circle"."""
        overrides = {
            str(key): value for key, value in (self.marker_symbols or {}).items()
        }
        return {label: overrides.get(label, "circle") for label in labels}

    def _build_figure(self, df: pd.DataFrame) -> go.Figure:
        """Create the scatter figure with one trace per label."""
        label_list = sorted(df["label"].unique())
        color_map = self._resolve_colors(label_list)
        marker_map = self._resolve_markers(label_list)

        fig = go.Figure()
        for label in label_list:
            sub_df = df[df["label"] == label]
            fig.add_trace(
                go.Scatter(
                    x=sub_df["x"],
                    y=sub_df["y"],
                    mode="markers",
                    name=label,
                    marker=dict(
                        size=10,
                        line=dict(width=2, color="black"),  # bold outline
                        color=color_map[label],
                        symbol=marker_map[label],
                    ),
                )
            )

        fig.update_layout(
            height=self.height,
            width=self.width,
            title=dict(
                text=self.title or "t-SNE",
                y=0.95,
                x=0.5,
                xanchor="center",
                yanchor="top",
                font=dict(size=self.text_size * 1.5),
            ),
            font=dict(size=self.text_size, family="Times New Roman"),
            margin=dict(l=10, r=10, t=80, b=10),
            legend=dict(
                orientation="h",
                y=1.1,
                x=0.5,
                xanchor="center",
                yanchor="bottom",
                font=dict(size=self.text_size * 1.2),
                itemsizing="constant",
            ),
            xaxis_title=self.x_axis_title,
            yaxis_title=self.y_axis_title,
            legend_title_text=self.legend_title,
        )
        return fig

    def _save_figure(self, fig: go.Figure) -> Dict[str, Optional[str]]:
        """Write the PNG (and optionally HTML) file and return their paths."""
        png_path = (Path(self.path) / self.output_filename).resolve()
        # Create the destination directory (including any sub-directories in
        # `output_filename`) so writing does not fail on a fresh log dir.
        png_path.parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(png_path)
        print(f"t-SNE PNG saved to {png_path}")

        html_path = png_path.with_suffix(".html")
        if self.write_html:
            fig.write_html(html_path)
            print(f"t-SNE HTML saved to {html_path}")

        return {
            "png_path": str(png_path),
            "html_path": str(html_path) if self.write_html else None,
        }