# Source code for minerva.models.nets.time_series.ts2vec_classifier
from typing import Literal, Tuple
from minerva.models.nets.tnc import TSEncoder, DilatedConvEncoder
from minerva.models.nets.base import SimpleSupervisedModel
from minerva.models.nets.mlp import MLP
import torch
# [docs]  (Sphinx viewcode link marker from the rendered page; not part of the module)
class TS2VecClassifier(SimpleSupervisedModel):
    """Supervised classifier made of a ``TSEncoder`` backbone and an MLP head.

    The backbone is a ``TSEncoder`` configured with ``DilatedConvEncoder`` as
    its ``encoder_cls`` and ``permute=True``. The head is an ``MLP`` built
    from the size list ``[fc_input_features, hidden_dims, num_classes]``, and
    the loss function is ``torch.nn.CrossEntropyLoss``. ``flatten=True`` is
    forwarded to the base class (presumably so the backbone output is
    flattened before reaching the head -- confirm against
    ``SimpleSupervisedModel``).

    Attributes
    ----------
    fc_input_features : int
        Number of features per sample produced by the backbone for the
        ``input_shape`` given at construction time.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int],
        ts_input_dims: int,
        ts_output_dims: int,
        ts_hidden_dims: int = 64,
        ts_depth: int = 10,
        hidden_dims: int = 128,
        num_classes: int = 6,
    ):
        """Build the encoder, measure its output size and create the head.

        Parameters
        ----------
        input_shape : Tuple[int, int]
            Shape of a single sample, without the batch dimension. It is only
            used to create a dummy batch of size one for probing the encoder.
        ts_input_dims : int
            Forwarded to ``TSEncoder`` as ``input_dims``.
        ts_output_dims : int
            Forwarded to ``TSEncoder`` as ``output_dims``.
        ts_hidden_dims : int, optional
            Forwarded to ``TSEncoder`` as ``hidden_dims``, by default 64.
        ts_depth : int, optional
            Forwarded to ``TSEncoder`` as ``depth``, by default 10.
        hidden_dims : int, optional
            Middle entry of the size list given to the ``MLP`` head, by
            default 128.
        num_classes : int, optional
            Last entry of the size list given to the ``MLP`` head, by
            default 6.
        """
        encoder = TSEncoder(
            input_dims=ts_input_dims,
            hidden_dims=ts_hidden_dims,
            output_dims=ts_output_dims,
            depth=ts_depth,
            permute=True,
            encoder_cls=DilatedConvEncoder,
        )
        # The head's input width depends on what the encoder emits for this
        # input shape, so it has to be measured before the head is created.
        # The helper does not read any instance state, which makes it safe to
        # call before the parent initializer has run.
        fc_input_features = self._calculate_fc_input_features(encoder, input_shape)
        super().__init__(
            backbone=encoder,
            fc=MLP([fc_input_features, hidden_dims, num_classes]),
            loss_fn=torch.nn.CrossEntropyLoss(),
            flatten=True,
        )
        # Stored on the instance only after the parent initializer has run:
        # torch.nn.Module requires its __init__ to be called before
        # attributes are assigned on the child.
        self.fc_input_features = fc_input_features

    def _calculate_fc_input_features(
        self, backbone: torch.nn.Module, input_shape: Tuple[int, int]
    ) -> int:
        """Run a single forward pass with a random input to get the number of
        features produced by the backbone.

        Parameters
        ----------
        backbone : torch.nn.Module
            The backbone of the network.
        input_shape : Tuple[int, int]
            The shape of a single input sample, without the batch dimension.

        Returns
        -------
        int
            The number of features per sample once the backbone output is
            flattened.
        """
        # A batch of one random sample is enough to discover the output size.
        random_input = torch.randn(1, *input_shape)
        # Gradients are not needed for a shape probe.
        with torch.no_grad():
            out = backbone(random_input)
        # Collapse everything except the batch dimension and count what is left.
        return out.reshape(out.size(0), -1).size(1)