# Source code for minerva.models.nets.time_series.imu_transformer

from typing import Tuple

import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from minerva.models.nets.base import SimpleSupervisedModel

"""
IMU time-series models: a Transformer-encoder classifier
(IMUTransformerEncoder) and a small 1-D CNN classifier (IMUCNN).
"""


class _IMUTransformerEncoder(nn.Module):
    """Transformer-encoder backbone for fixed-length multi-channel sequences.

    Each time step is embedded by a stack of kernel-size-1 convolutions, a
    learnable class token is prepended to the sequence, an optional learnable
    position embedding is added, and the result is passed through a
    ``TransformerEncoder``. The encoder output at the class-token position is
    returned as the representation of the whole sequence.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int] = (6, 60),
        transformer_dim: int = 64,
        encode_position: bool = True,
        nhead: int = 8,
        dim_feedforward: int = 128,
        transformer_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        num_encoder_layers: int = 6,
        permute: bool = False,
    ):
        """Build the projection, encoder, class token and position embedding.

        Parameters
        ----------
        input_shape : tuple, optional
            ``(channels, sequence_length)`` of one sample. ``input_shape[0]``
            sizes the first convolution and ``input_shape[1]`` sizes the
            position embedding. By default ``(6, 60)``.
        transformer_dim : int, optional
            Embedding size used throughout the transformer, by default 64.
        encode_position : bool, optional
            Whether to add a learnable position embedding, by default True.
        nhead : int, optional
            Number of attention heads, by default 8.
        dim_feedforward : int, optional
            Hidden size of the encoder feed-forward network, by default 128.
        transformer_dropout : float, optional
            Dropout rate inside the encoder layers, by default 0.1.
        transformer_activation : str, optional
            Activation used inside the encoder layers, by default "gelu".
        num_encoder_layers : int, optional
            Number of stacked encoder layers, by default 6.
        permute : bool, optional
            If ``True`` the last two axes of the input are swapped before the
            projection (i.e. the input arrives as ``(B, S, C)``), by default
            False.
        """
        super().__init__()

        self.input_shape = input_shape
        self.transformer_dim = transformer_dim
        self.permute = permute

        # Kernel-size-1 convolutions act on each time step independently,
        # lifting the channel axis from input_shape[0] to transformer_dim.
        self.input_proj = nn.Sequential(
            nn.Conv1d(input_shape[0], self.transformer_dim, 1),
            nn.GELU(),
            nn.Conv1d(self.transformer_dim, self.transformer_dim, 1),
            nn.GELU(),
            nn.Conv1d(self.transformer_dim, self.transformer_dim, 1),
            nn.GELU(),
            nn.Conv1d(self.transformer_dim, self.transformer_dim, 1),
            nn.GELU(),
        )

        self.encode_position = encode_position

        encoder_layer = TransformerEncoderLayer(
            d_model=self.transformer_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=transformer_dropout,
            activation=transformer_activation,
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers,
            norm=nn.LayerNorm(self.transformer_dim),
        )

        # Learnable class token, shape (1, transformer_dim).
        self.cls_token = nn.Parameter(
            torch.zeros((1, self.transformer_dim)), requires_grad=True
        )

        if self.encode_position:
            # One embedding per sequence position plus one for the class token.
            self.position_embed = nn.Parameter(
                torch.randn(input_shape[1] + 1, 1, self.transformer_dim)
            )

        # Xavier-initialize every parameter with more than one dimension.
        # This includes cls_token (2-D) and position_embed (3-D), so their
        # zeros/randn construction values above are overwritten here.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sequences into one vector per sample.

        Parameters
        ----------
        x : torch.Tensor
            A tensor of shape (B, C, S) with B = batch size, C = channels,
            S = sequence length. When ``permute`` is ``True`` the tensor is
            given as (B, S, C) instead and its last two axes are swapped
            first.

        Returns
        -------
        torch.Tensor
            The encoder output at the class-token position, of shape
            (B, transformer_dim).
        """
        if self.permute:
            x = x.permute(0, 2, 1)

        # Embed in a high dimensional space: (B, C, S) -> (B, D, S).
        x = self.input_proj(x)
        # Reorder to the sequence-first layout fed to the encoder: (S, B, D).
        x = x.permute(2, 0, 1)

        # Prepend the class token. ``expand`` yields a broadcast view of shape
        # (1, B, D) without copying; ``torch.cat`` then builds the new tensor.
        cls_token = self.cls_token.unsqueeze(1).expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token, x])

        # Add the position embedding (broadcast over the batch axis). ``x`` is
        # the fresh tensor produced by ``torch.cat``, so updating it in place
        # does not touch the caller's input.
        if self.encode_position:
            x += self.position_embed

        # Transformer encoder pass; index 0 along the sequence axis is the
        # class-token position.
        return self.transformer_encoder(x)[0]
class IMUTransformerEncoder(SimpleSupervisedModel):
    """Supervised classifier: transformer-encoder backbone plus an MLP head.

    The backbone is an ``_IMUTransformerEncoder`` and the head maps its
    ``transformer_dim``-sized output to ``num_classes`` values. Both are
    handed to ``SimpleSupervisedModel`` together with a cross-entropy loss.
    """

    def __init__(
        self,
        input_shape: tuple = (6, 60),
        transformer_dim: int = 64,
        encode_position: bool = True,
        nhead: int = 8,
        dim_feedforward: int = 128,
        transformer_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        num_encoder_layers: int = 6,
        num_classes: int = 6,
        learning_rate: float = 1e-3,
        # Arguments passed to the SimpleSupervisedModel constructor
        *args,
        **kwargs,
    ):
        """Create the backbone and head, then initialize the base model.

        Parameters
        ----------
        input_shape : tuple, optional
            Shape of one input sample, forwarded to the backbone, by default
            ``(6, 60)``.
        transformer_dim : int, optional
            Embedding size of the transformer and input size of the head, by
            default 64.
        encode_position : bool, optional
            Whether the backbone adds a position embedding, by default True.
        nhead : int, optional
            Number of attention heads, by default 8.
        dim_feedforward : int, optional
            Hidden size of the encoder feed-forward network, by default 128.
        transformer_dropout : float, optional
            Dropout rate inside the encoder layers, by default 0.1.
        transformer_activation : str, optional
            Activation used inside the encoder layers, by default "gelu".
        num_encoder_layers : int, optional
            Number of stacked encoder layers, by default 6.
        num_classes : int, optional
            Number of outputs of the head, by default 6.
        learning_rate : float, optional
            Forwarded to ``SimpleSupervisedModel``, by default 1e-3.
        *args, **kwargs
            Extra arguments forwarded to ``SimpleSupervisedModel``.
        """
        self.input_shape = input_shape

        # Gather the backbone settings once and forward them by keyword.
        encoder_settings = {
            "input_shape": input_shape,
            "transformer_dim": transformer_dim,
            "encode_position": encode_position,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "transformer_dropout": transformer_dropout,
            "transformer_activation": transformer_activation,
            "num_encoder_layers": num_encoder_layers,
        }
        # The backbone is built before the head, as in the original ordering.
        encoder = self._create_backbone(**encoder_settings)
        head = self._create_fc(transformer_dim, num_classes)

        super().__init__(
            *args,
            backbone=encoder,
            fc=head,
            learning_rate=learning_rate,
            loss_fn=torch.nn.CrossEntropyLoss(),
            **kwargs,
        )

    def _create_backbone(
        self,
        input_shape,
        transformer_dim,
        encode_position,
        nhead,
        dim_feedforward,
        transformer_dropout,
        transformer_activation,
        num_encoder_layers,
    ):
        """Return an ``_IMUTransformerEncoder`` built from the given settings."""
        return _IMUTransformerEncoder(
            input_shape=input_shape,
            transformer_dim=transformer_dim,
            encode_position=encode_position,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            transformer_dropout=transformer_dropout,
            transformer_activation=transformer_activation,
            num_encoder_layers=num_encoder_layers,
        )

    def _create_fc(self, transform_dim, num_classes):
        """Return the classification head.

        The head is LayerNorm -> Linear(transform_dim, transform_dim // 4)
        -> GELU -> Dropout(0.1) -> Linear(transform_dim // 4, num_classes).
        """
        bottleneck_dim = transform_dim // 4
        head_layers = [
            nn.LayerNorm(transform_dim),
            nn.Linear(transform_dim, bottleneck_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(bottleneck_dim, num_classes),
        ]
        return nn.Sequential(*head_layers)
class IMUCNN(SimpleSupervisedModel):
    """Supervised classifier: small 1-D convolutional backbone plus MLP head.

    The backbone output is flattened (``flatten=True`` is passed to
    ``SimpleSupervisedModel``), so the head's input size is measured by
    running the backbone once on a dummy sample of ``input_shape``.
    """

    def __init__(
        self,
        input_shape: tuple = (6, 60),
        hidden_dim: int = 64,
        num_classes: int = 6,
        dropout_factor: float = 0.1,
        learning_rate: float = 1e-3,
        # Arguments passed to the SimpleSupervisedModel constructor
        *args,
        **kwargs,
    ):
        """Create the backbone and head, then initialize the base model.

        Parameters
        ----------
        input_shape : tuple, optional
            Shape of one input sample; ``input_shape[0]`` is the number of
            input channels of the first convolution. By default ``(6, 60)``.
        hidden_dim : int, optional
            Number of convolution channels and hidden size of the head, by
            default 64.
        num_classes : int, optional
            Number of outputs of the head, by default 6.
        dropout_factor : float, optional
            Dropout rate used in the backbone, by default 0.1.
        learning_rate : float, optional
            Forwarded to ``SimpleSupervisedModel``, by default 1e-3.
        *args, **kwargs
            Extra arguments forwarded to ``SimpleSupervisedModel``.
        """
        self.input_shape = input_shape
        self.hidden_dim = hidden_dim
        self.dropout_factor = dropout_factor

        conv_stack = self._create_backbone(
            input_shape=input_shape,
            hidden_dim=hidden_dim,
            dropout_factor=dropout_factor,
        )
        # Measure the flattened backbone output size to dimension the head.
        self.fc_input_channels = self._calculate_fc_input_features(
            conv_stack, input_shape
        )
        head = self._create_fc(self.fc_input_channels, hidden_dim, num_classes)

        super().__init__(
            *args,
            backbone=conv_stack,
            fc=head,
            learning_rate=learning_rate,
            loss_fn=torch.nn.CrossEntropyLoss(),
            flatten=True,
            **kwargs,
        )

    def _create_backbone(self, input_shape, hidden_dim, dropout_factor):
        """Return the convolutional backbone.

        Two kernel-size-1 convolutions with ReLU activations, followed by
        dropout and a max-pool of kernel size 2.
        """
        conv_layers = [
            torch.nn.Conv1d(input_shape[0], hidden_dim, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_factor),
            torch.nn.MaxPool1d(kernel_size=2),
        ]
        return torch.nn.Sequential(*conv_layers)

    def _calculate_fc_input_features(
        self, backbone: torch.nn.Module, input_shape: Tuple[int, int]
    ) -> int:
        """Return the per-sample flattened size of the backbone output.

        A single random sample of ``input_shape`` is passed through
        ``backbone`` with gradients disabled, and the number of elements per
        sample in the result is returned.
        """
        probe = torch.randn(1, *input_shape)
        with torch.no_grad():
            features = backbone(probe)
        flattened = features.view(features.size(0), -1)
        return flattened.size(1)

    def _create_fc(self, input_features, hidden_dim, num_classes):
        """Return the head: Linear -> ReLU -> Linear mapping to ``num_classes``."""
        head_layers = [
            torch.nn.Linear(input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes),
        ]
        return torch.nn.Sequential(*head_layers)