Source code for minerva.models.nets.time_series.ts2vec_classifier

from typing import Literal, Tuple
from minerva.models.nets.tnc import TSEncoder, DilatedConvEncoder
from minerva.models.nets.base import SimpleSupervisedModel
from minerva.models.nets.mlp import MLP
import torch


class TS2VecClassifier(SimpleSupervisedModel):
    """Supervised classifier that puts an MLP head on a ``TSEncoder`` backbone.

    The backbone is a ``TSEncoder`` configured with ``DilatedConvEncoder`` as
    its ``encoder_cls`` and ``permute=True``. The base class is asked to
    flatten the backbone output (``flatten=True``) before passing it to an
    ``MLP`` built from the sizes
    ``[fc_input_features, hidden_dims, num_classes]``. The loss is
    ``torch.nn.CrossEntropyLoss``.

    Attributes
    ----------
    fc_input_features : int
        Number of flattened features the backbone produces for one sample of
        shape ``input_shape``; this is the input width of the MLP head.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int],
        ts_input_dims: int,
        ts_output_dims: int,
        ts_hidden_dims: int = 64,
        ts_depth: int = 10,
        hidden_dims: int = 128,
        num_classes: int = 6,
    ):
        """Build the encoder, size the classification head and init the base.

        Parameters
        ----------
        input_shape : Tuple[int, int]
            Shape of a single input sample, without the batch dimension. It
            is only used to probe the backbone's output size.
        ts_input_dims : int
            Forwarded to ``TSEncoder`` as ``input_dims``.
        ts_output_dims : int
            Forwarded to ``TSEncoder`` as ``output_dims``.
        ts_hidden_dims : int, optional
            Forwarded to ``TSEncoder`` as ``hidden_dims``, by default 64.
        ts_depth : int, optional
            Forwarded to ``TSEncoder`` as ``depth``, by default 10.
        hidden_dims : int, optional
            Middle size given to the ``MLP`` head, by default 128.
        num_classes : int, optional
            Last size given to the ``MLP`` head, by default 6.
        """
        encoder = TSEncoder(
            input_dims=ts_input_dims,
            hidden_dims=ts_hidden_dims,
            output_dims=ts_output_dims,
            depth=ts_depth,
            permute=True,
            encoder_cls=DilatedConvEncoder,
        )

        # The head cannot be built without knowing the backbone's output
        # size, so the probe has to run before the base class is initialised.
        # The result is kept in a local until then: attributes should not be
        # set on a torch module before ``Module.__init__`` has run.
        fc_input_features = self._calculate_fc_input_features(
            encoder, input_shape
        )

        super().__init__(
            backbone=encoder,
            fc=MLP([fc_input_features, hidden_dims, num_classes]),
            loss_fn=torch.nn.CrossEntropyLoss(),
            flatten=True,
        )
        self.fc_input_features = fc_input_features

    def _calculate_fc_input_features(
        self, backbone: torch.nn.Module, input_shape: Tuple[int, int]
    ) -> int:
        """Run a single forward pass with a random input to get the number of
        features the backbone produces once its output is flattened.

        The backbone is put in evaluation mode for the duration of the probe
        so that train-only behaviour (e.g. dropout or running-statistics
        updates) is not triggered; its previous mode is restored afterwards.
        Gradients are not tracked.

        Parameters
        ----------
        backbone : torch.nn.Module
            The backbone of the network.
        input_shape : Tuple[int, int]
            The shape of a single input sample, without the batch dimension.

        Returns
        -------
        int
            The number of features per sample after flattening the backbone
            output.
        """
        # A batch of exactly one sample is enough to learn the output shape.
        random_input = torch.randn(1, *input_shape)

        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                out = backbone(random_input)
        finally:
            # Leave the backbone in the mode the caller had it in, even if
            # the forward pass raised.
            backbone.train(was_training)

        # Collapse everything except the batch dimension and read its width.
        return out.reshape(out.size(0), -1).size(1)