Source code for minerva.models.nets.time_series.imu_transformer
"""
IMUTransformerEncoder model
"""

from typing import Tuple

import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from minerva.models.nets.base import SimpleSupervisedModel
class _IMUTransformerEncoder(nn.Module):
    def __init__(
        self,
        input_shape: tuple = (6, 60),
        transformer_dim: int = 64,
        encode_position: bool = True,
        nhead: int = 8,
        dim_feedforward: int = 128,
        transformer_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        num_encoder_layers: int = 6,
        permute: bool = False,
    ):
        """
        input_shape: (tuple) shape of the input data
        transformer_dim: (int) dimension of the transformer
        encode_position: (bool) whether to encode position or not
        nhead: (int) number of attention heads
        dim_feedforward: (int) dimension of the feedforward network
        transformer_dropout: (float) dropout rate for the transformer
        transformer_activation: (str) activation function for the transformer
        num_encoder_layers: (int) number of transformer encoder layers
        num_classes: (int) number of output classes
        permute: bool, optional. If `True` the input data will be permuted before passing through the model, by default False.
        """
        super().__init__()
        self.input_shape = input_shape
        self.transformer_dim = transformer_dim
        self.permute = permute
        self.input_proj = nn.Sequential(
            nn.Conv1d(input_shape[0], self.transformer_dim, 1),
            nn.GELU(),
            nn.Conv1d(self.transformer_dim, self.transformer_dim, 1),
            nn.GELU(),
            nn.Conv1d(self.transformer_dim, self.transformer_dim, 1),
            nn.GELU(),
            nn.Conv1d(self.transformer_dim, self.transformer_dim, 1),
            nn.GELU(),
        )
        self.encode_position = encode_position
        encoder_layer = TransformerEncoderLayer(
            d_model=self.transformer_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=transformer_dropout,
            activation=transformer_activation,
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers,
            norm=nn.LayerNorm(self.transformer_dim),
        )
        self.cls_token = nn.Parameter(
            torch.zeros((1, self.transformer_dim)), requires_grad=True
        )
        if self.encode_position:
            self.position_embed = nn.Parameter(
                torch.randn(input_shape[1] + 1, 1, self.transformer_dim)
            )
        # Initialize weight matrices with Xavier uniform
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            A tensor of shape (B, C, S) with B = batch size, C = channels,
            S = sequence length.

        Returns
        -------
        torch.Tensor
            The encoded class token, of shape (B, transformer_dim).
        """
        if self.permute:
            x = x.permute(0, 2, 1)
        # Embed in a higher-dimensional space with pointwise convolutions
        x = self.input_proj(x)
        # Reshape to the transformer's expected (S, B, E) layout
        x = x.permute(2, 0, 1)
        # Prepend class token
        cls_token = self.cls_token.unsqueeze(1).repeat(1, x.shape[1], 1)
        x = torch.cat([cls_token, x])
        # Add the position embedding
        if self.encode_position:
            x += self.position_embed
        # Transformer encoder pass; keep only the class-token output
        target = self.transformer_encoder(x)[0]
        return target
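# A minimal usage sketch for the backbone (not part of the original module):
# with the default input_shape=(6, 60), a batch of IMU windows of shape
# (B, 6, 60) is projected to transformer_dim channels, a class token is
# prepended, and the encoder output at that token is returned.
#
#     encoder = _IMUTransformerEncoder(input_shape=(6, 60), transformer_dim=64)
#     x = torch.randn(8, 6, 60)   # (B, C, S)
#     features = encoder(x)       # shape (8, 64): one embedding per sample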
class IMUTransformerEncoder(SimpleSupervisedModel):
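    """Transformer-based classifier for IMU time series. Wraps
    `_IMUTransformerEncoder` as the backbone and a small MLP head as the
    classifier, trained with cross-entropy loss through
    `SimpleSupervisedModel`.
    """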
    def __init__(
        self,
        input_shape: tuple = (6, 60),
        transformer_dim: int = 64,
        encode_position: bool = True,
        nhead: int = 8,
        dim_feedforward: int = 128,
        transformer_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        num_encoder_layers: int = 6,
        num_classes: int = 6,
        learning_rate: float = 1e-3,
        # Arguments passed to the SimpleSupervisedModel constructor
        *args,
        **kwargs,
    ):
        self.input_shape = input_shape
        backbone = self._create_backbone(
            input_shape=input_shape,
            transformer_dim=transformer_dim,
            encode_position=encode_position,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            transformer_dropout=transformer_dropout,
            transformer_activation=transformer_activation,
            num_encoder_layers=num_encoder_layers,
        )
        fc = self._create_fc(transformer_dim, num_classes)
        super().__init__(
            backbone=backbone,
            fc=fc,
            learning_rate=learning_rate,
            loss_fn=torch.nn.CrossEntropyLoss(),
            *args,
            **kwargs,
        )
    def _create_backbone(
        self,
        input_shape,
        transformer_dim,
        encode_position,
        nhead,
        dim_feedforward,
        transformer_dropout,
        transformer_activation,
        num_encoder_layers,
    ):
        backbone = _IMUTransformerEncoder(
            input_shape=input_shape,
            transformer_dim=transformer_dim,
            encode_position=encode_position,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            transformer_dropout=transformer_dropout,
            transformer_activation=transformer_activation,
            num_encoder_layers=num_encoder_layers,
        )
        return backbone
    def _create_fc(self, transformer_dim, num_classes):
        imu_head = nn.Sequential(
            nn.LayerNorm(transformer_dim),
            nn.Linear(transformer_dim, transformer_dim // 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(transformer_dim // 4, num_classes),
        )
        return imu_head
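# A minimal usage sketch for the classifier (not part of the original module);
# it assumes SimpleSupervisedModel applies the backbone and then the fc head,
# so the logits have shape (B, num_classes).
#
#     model = IMUTransformerEncoder(input_shape=(6, 60), num_classes=6)
#     x = torch.randn(8, 6, 60)   # (B, C, S)
#     logits = model(x)           # expected shape: (8, 6)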
class IMUCNN(SimpleSupervisedModel):
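    """1D-CNN classifier for IMU time series. The backbone is a stack of
    pointwise convolutions followed by max pooling; the flattened features
    feed a two-layer MLP head, trained with cross-entropy loss through
    `SimpleSupervisedModel`.
    """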
    def __init__(
        self,
        input_shape: tuple = (6, 60),
        hidden_dim: int = 64,
        num_classes: int = 6,
        dropout_factor: float = 0.1,
        learning_rate: float = 1e-3,
        # Arguments passed to the SimpleSupervisedModel constructor
        *args,
        **kwargs,
    ):
        self.input_shape = input_shape
        self.hidden_dim = hidden_dim
        self.dropout_factor = dropout_factor
        backbone = self._create_backbone(
            input_shape=input_shape,
            hidden_dim=hidden_dim,
            dropout_factor=dropout_factor,
        )
        self.fc_input_channels = self._calculate_fc_input_features(
            backbone, input_shape
        )
        fc = self._create_fc(self.fc_input_channels, hidden_dim, num_classes)
        super().__init__(
            backbone=backbone,
            fc=fc,
            learning_rate=learning_rate,
            loss_fn=torch.nn.CrossEntropyLoss(),
            flatten=True,
            *args,
            **kwargs,
        )
    def _create_backbone(self, input_shape, hidden_dim, dropout_factor):
        return torch.nn.Sequential(
            torch.nn.Conv1d(input_shape[0], hidden_dim, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_factor),
            torch.nn.MaxPool1d(kernel_size=2),
        )
    def _calculate_fc_input_features(
        self, backbone: torch.nn.Module, input_shape: Tuple[int, int]
    ) -> int:
        """Run a dummy forward pass through the backbone to compute the
        number of flattened features the fully connected head receives."""
        random_input = torch.randn(1, *input_shape)
        with torch.no_grad():
            out = backbone(random_input)
        return out.view(out.size(0), -1).size(1)
    def _create_fc(self, input_features, hidden_dim, num_classes):
        return torch.nn.Sequential(
            torch.nn.Linear(input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes),
        )