{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Getting Started with Minerva for Human Activity Recognition\n", "\n", "Human Activity Recognition (HAR) is a challenging task that involves identifying actions performed by individuals based on sensor data—typically time-series signals from accelerometers and gyroscopes.\n", "\n", "In this notebook, we will use Minerva to train and evaluate an [1D ResNet-SE](https://ieeexplore.ieee.org/document/9771436) model from scratch for classifying human activities using [DAGHAR Dataset](https://www.nature.com/articles/s41597-024-03951-4).\n", "\n", "Thus, this notebook is a step-by-step guide to get you started with Minerva for HAR.\n", "It comprehends the following steps:\n", "\n", "1. Data Preparation\n", "2. Model Creation\n", "3. Model Training\n", "4. Model Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import lightning as L\n", "import torch\n", "from torchmetrics import Accuracy\n", "\n", "from minerva.data.datasets.series_dataset import MultiModalSeriesCSVDataset\n", "from minerva.data.data_modules.base import MinervaDataModule\n", "from minerva.models.nets.time_series.resnet import ResNetSE1D_5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Data Preparation\n", "\n", "We begin by preparing the data required for training and evaluation.\n", "\n", "For this tutorial, we will use the [standardized view of the DAGHAR Dataset](https://zenodo.org/records/13987073), as introduced in the following paper:\n", "\n", "```latex\n", "Napoli, O., Duarte, D., Alves, P., Soto, D.H.P., de Oliveira, H.E., Rocha, A., Boccato, L. and Borin, E., 2024. \n", "A benchmark for domain adaptation and generalization in smartphone-based human activity recognition. \n", "Scientific Data, 11(1), p.1192.\n", "```\n", "\n", "This dataset includes time-series data from two tri-axial sensors—an accelerometer and a gyroscope—collected via smartphones. It is organized into six different datasets:\n", "- KuHar \n", "- MotionSense \n", "- RealWorld-waist \n", "- RealWorld-thigh \n", "- UCI \n", "- WISDM \n", "\n", "In this notebook, we will work with the **standardized view of the MotionSense** dataset.\n", "\n", "You can download and extract the dataset using the commands below (`wget` and `unzip`):" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2025-03-22 17:41:07-- https://zenodo.org/records/13987073/files/standardized_view.zip?download=1\n", "Resolving zenodo.org (zenodo.org)... 188.185.45.92, 188.185.43.25, 188.185.48.194, ...\n", "Connecting to zenodo.org (zenodo.org)|188.185.45.92|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 191454525 (183M) [application/octet-stream]\n", "Saving to: ‘daghar_standardized_view.zip’\n", "\n", "daghar_standardized 100%[===================>] 182.58M 5.08MB/s in 9m 27s \n", "\n", "2025-03-22 17:50:35 (330 KB/s) - ‘daghar_standardized_view.zip’ saved [191454525/191454525]\n", "\n", "Archive: daghar_standardized_view.zip\n", " creating: daghar_standardized_view/standardized_view/\n", " creating: daghar_standardized_view/standardized_view/KuHar/\n", " inflating: daghar_standardized_view/standardized_view/KuHar/validation.csv \n", " inflating: daghar_standardized_view/standardized_view/KuHar/train.csv \n", " inflating: daghar_standardized_view/standardized_view/KuHar/test.csv \n", " creating: daghar_standardized_view/standardized_view/MotionSense/\n", " inflating: daghar_standardized_view/standardized_view/MotionSense/validation.csv \n", " inflating: daghar_standardized_view/standardized_view/MotionSense/train.csv \n", " inflating: daghar_standardized_view/standardized_view/MotionSense/test.csv \n", " creating: daghar_standardized_view/standardized_view/RealWorld_thigh/\n", " inflating: daghar_standardized_view/standardized_view/RealWorld_thigh/validation.csv \n", " inflating: daghar_standardized_view/standardized_view/RealWorld_thigh/train.csv \n", " inflating: daghar_standardized_view/standardized_view/RealWorld_thigh/test.csv \n", " creating: daghar_standardized_view/standardized_view/RealWorld_waist/\n", " inflating: daghar_standardized_view/standardized_view/RealWorld_waist/validation.csv \n", " inflating: daghar_standardized_view/standardized_view/RealWorld_waist/train.csv \n", " inflating: daghar_standardized_view/standardized_view/RealWorld_waist/test.csv \n", " creating: daghar_standardized_view/standardized_view/UCI/\n", " inflating: daghar_standardized_view/standardized_view/UCI/validation.csv \n", " inflating: daghar_standardized_view/standardized_view/UCI/train.csv \n", " inflating: daghar_standardized_view/standardized_view/UCI/test.csv \n", " creating: daghar_standardized_view/standardized_view/WISDM/\n", " inflating: daghar_standardized_view/standardized_view/WISDM/validation.csv \n", " inflating: daghar_standardized_view/standardized_view/WISDM/train.csv \n", " inflating: daghar_standardized_view/standardized_view/WISDM/test.csv \n" ] } ], "source": [ "!wget https://zenodo.org/records/13987073/files/standardized_view.zip?download=1 -O daghar_standardized_view.zip\n", "!unzip -o daghar_standardized_view.zip -d daghar_standardized_view" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Once extracted, the folder structure will look like this:\n", "\n", "```tree\n", "daghar_standardized_view/standardized_view\n", "├── KuHar\n", "│ ├── train.csv\n", "│ ├── test.csv\n", "│ └── validation.csv\n", "├── MotionSense\n", "│ ├── train.csv\n", "│ ├── test.csv\n", "│ └── validation.csv\n", "├── RealWorld_waist\n", "│ ├── train.csv\n", "│ ├── test.csv\n", "│ └── validation.csv\n", "├── RealWorld_thigh\n", "│ ├── train.csv\n", "│ ├── test.csv\n", "│ └── validation.csv\n", "├── UCI\n", "│ ├── train.csv\n", "│ ├── test.csv\n", "│ └── validation.csv\n", "└── WISDM\n", " ├── train.csv\n", " ├── test.csv\n", " └── validation.csv\n", "```\n", "\n", "Each dataset is split into `train.csv`, `validation.csv`, and `test.csv` files. Each file contains time-series data with the following structure:\n", "\n", "| Column Range | Description |\n", "|-------------------------------|---------------------------------------------|\n", "| `accel-x-0` to `accel-x-59` | 60 time steps of accelerometer x-axis |\n", "| `accel-y-0` to `accel-y-59` | 60 time steps of accelerometer y-axis |\n", "| `accel-z-0` to `accel-z-59` | 60 time steps of accelerometer z-axis |\n", "| `gyro-x-0` to `gyro-x-59` | 60 time steps of gyroscope x-axis |\n", "| `gyro-y-0` to `gyro-y-59` | 60 time steps of gyroscope y-axis |\n", "| `gyro-z-0` to `gyro-z-59` | 60 time steps of gyroscope z-axis |\n", "| `standard activity code` | Encoded activity label |\n", "\n", "Each row represents one sample, composed of 6 channels (3 from each sensor) and 60 time steps per channel, representing 3 seconds of data at a sampling rate of 20 Hz.\n", "\n", "All datasets in DAGHAR share the same structure and label set. The activity codes are mapped as follows:\n", "\n", "| Standard Activity Code | Activity |\n", "|------------------------|--------------|\n", "| 0 | Sit |\n", "| 1 | Stand |\n", "| 2 | Walk |\n", "| 3 | Stair-up |\n", "| 4 | Stair-down |\n", "| 5 | Run |\n" ] },
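{ "cell_type": "markdown", "metadata": {}, "source": [ "As an optional sanity check, you can inspect one of the CSV files directly with pandas (assuming it is installed) and confirm that the columns follow the layout described above:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: peek at the raw CSV to confirm the column layout.\n", "import pandas as pd\n", "\n", "df = pd.read_csv(\"daghar_standardized_view/standardized_view/MotionSense/train.csv\")\n", "# The layout above implies 6 * 60 = 360 feature columns plus the label column;\n", "# the standardized view may also carry extra metadata columns.\n", "print(df.shape)\n", "print(list(df.columns[:3]), \"...\", list(df.columns[-3:]))\n", "print(sorted(df[\"standard activity code\"].unique()))" ] },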
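{ "cell_type": "markdown", "metadata": {}, "source": [ "To build intuition for how `feature_prefixes` turns flat CSV columns into a `(channels, time_steps)` sample, here is a small standalone sketch using only pandas and numpy. It mimics the grouping behavior described above and is not Minerva's actual implementation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy illustration (not Minerva code): group flat CSV columns into channels.\n", "import numpy as np\n", "import pandas as pd\n", "\n", "toy = pd.DataFrame({\"accel-x-0\": [0.1], \"accel-x-1\": [0.2],\n", "                    \"accel-y-0\": [0.3], \"accel-y-1\": [0.4]})\n", "prefixes = [\"accel-x\", \"accel-y\"]\n", "# One channel per prefix, in the order the prefixes are listed\n", "channels = [toy.filter(regex=f\"^{p}\").to_numpy()[0] for p in prefixes]\n", "sample = np.stack(channels)\n", "print(sample.shape)  # (2 channels, 2 time steps)" ] },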
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "root_data_dir = Path(\"daghar_standardized_view/standardized_view/MotionSense\")\n", "\n", "# One channel per prefix: 3 accelerometer axes followed by 3 gyroscope axes\n", "feature_prefixes = [\n", "    \"accel-x\",\n", "    \"accel-y\",\n", "    \"accel-z\",\n", "    \"gyro-x\",\n", "    \"gyro-y\",\n", "    \"gyro-z\",\n", "]\n", "\n", "# Create the train dataset\n", "train_dataset = MultiModalSeriesCSVDataset(\n", "    data_path=root_data_dir / \"train.csv\",\n", "    feature_prefixes=feature_prefixes,\n", "    features_as_channels=True,\n", "    label=\"standard activity code\",\n", ")\n", "\n", "# Create the validation dataset\n", "val_dataset = MultiModalSeriesCSVDataset(\n", "    data_path=root_data_dir / \"validation.csv\",\n", "    feature_prefixes=feature_prefixes,\n", "    features_as_channels=True,\n", "    label=\"standard activity code\",\n", ")\n", "\n", "# Create the test dataset\n", "test_dataset = MultiModalSeriesCSVDataset(\n", "    data_path=root_data_dir / \"test.csv\",\n", "    feature_prefixes=feature_prefixes,\n", "    features_as_channels=True,\n", "    label=\"standard activity code\",\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train dataset: MultiModalSeriesCSVDataset at daghar_standardized_view/standardized_view/MotionSense/train.csv (3558 samples)\n", "Validation dataset: MultiModalSeriesCSVDataset at daghar_standardized_view/standardized_view/MotionSense/validation.csv (420 samples)\n", "Test dataset: MultiModalSeriesCSVDataset at daghar_standardized_view/standardized_view/MotionSense/test.csv (1062 samples)\n" ] } ], "source": [ "print(f\"Train dataset: {train_dataset}\")\n", "print(f\"Validation dataset: {val_dataset}\")\n", "print(f\"Test dataset: {test_dataset}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at the first sample of the training dataset. \n", "As each sample has 6 channels and 60 time steps, the shape of the sample will be `(6, 60)`.\n", "The label will be a single integer representing the activity code." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The shape of the input is (6, 60) and the label is 4\n" ] } ], "source": [ "X, y = train_dataset[0]\n", "print(f\"The shape of the input is {X.shape} and the label is {y}\")" ] },
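{ "cell_type": "markdown", "metadata": {}, "source": [ "It is also worth checking the class balance of the training split. The snippet below is a sketch that relies on the same full-slice indexing (`dataset[:]`) used later in the evaluation section, which returns all samples and labels as numpy arrays:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: inspect how many samples each activity has in the training split.\n", "import numpy as np\n", "\n", "_, y_train = train_dataset[:]\n", "labels, counts = np.unique(y_train, return_counts=True)\n", "for label, count in zip(labels, counts):\n", "    print(f\"Activity {label}: {count} samples\")" ] },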
\n", "Thus, to train a model we should create a `LightningDataModule` object that will handle the data loading and preprocessing.\n", "Minerva provides a `MinervaDataModule` class that extends Pytorch Lightning's `LightningDataModule` class and standardizes the data loading process.\n", "\n", "We may create a `MinervaDataModule` object by passing the training, validation, and testing datasets, as well as the batch size and the number of workers for data loading" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", " 🆔 MotionSense Dataset\n", "==================================================\n", "├── Predict Split: test\n", "└── Dataloader class: \n", "📂 Datasets:\n", " ├── Train Dataset:\n", " │ MultiModalSeriesCSVDataset at daghar_standardized_view/standardized_view/MotionSense/train.csv (3558 samples)\n", " ├── Val Dataset:\n", " │ MultiModalSeriesCSVDataset at daghar_standardized_view/standardized_view/MotionSense/validation.csv (420 samples)\n", " └── Test Dataset:\n", " MultiModalSeriesCSVDataset at daghar_standardized_view/standardized_view/MotionSense/test.csv (1062 samples)\n", "\n", "🛠 **Dataloader Configurations:**\n", " ├── Train Dataloader Kwargs:\n", " ├── batch_size: 64\n", " ├── num_workers: 4\n", " ├── shuffle: true\n", " ├── drop_last: false\n", " ├── Val Dataloader Kwargs:\n", " ├── batch_size: 64\n", " ├── num_workers: 4\n", " ├── shuffle: false\n", " ├── drop_last: false\n", " └── Test Dataloader Kwargs:\n", " ├── batch_size: 64\n", " ├── num_workers: 4\n", " ├── shuffle: false\n", " ├── drop_last: false\n", "==================================================\n" ] } ], "source": [ "data_module = MinervaDataModule(\n", " train_dataset=train_dataset, \n", " val_dataset=val_dataset, \n", " test_dataset=test_dataset, \n", " batch_size=64,\n", " name=\"MotionSense Dataset\",\n", " num_workers=4\n", ")\n", "\n", "print(data_module)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Creating the Model\n", "\n", "Minerva provides several models to be used in Human Activity Recognition. \n", "In this notebook we will use the [`ResNetSE1D_5` model](https://ieeexplore.ieee.org/document/9771436) which is a 1D ResNet model with Squeeze-and-Excitation blocks and 5 residual blocks.\n", "\n", "To create the model, we just need to call the `ResNetSE1D_5` class from Minerva and pass the following parameters:\n", "- `input_shape`: Shape of each input samples, in the format `(channels, time_steps)`. In this case, it will be `(6, 60)`.\n", "- `num_classes`: Number of classes in the dataset. In this case, it will be 6." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[W322 18:28:16.875825496 NNPACK.cpp:62] Could not initialize NNPACK! 
Reason: Unsupported hardware.\n" ] }, { "data": { "text/plain": [ "ResNetSE1D_5(\n", " (backbone): _ResNet1D(\n", " (conv_block): ConvolutionalBlock(\n", " (block): Sequential(\n", " (0): Conv1d(6, 64, kernel_size=(5,), stride=(1,))\n", " (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " )\n", " (residual_blocks): Sequential(\n", " (0): ResNetSEBlock(\n", " (block): Sequential(\n", " (0): Conv1d(64, 32, kernel_size=(5,), stride=(1,), padding=same)\n", " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=same)\n", " (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SqueezeAndExcitation1D(\n", " (block): Sequential(\n", " (0): Linear(in_features=64, out_features=32, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=32, out_features=64, bias=True)\n", " (3): Sigmoid()\n", " )\n", " )\n", " )\n", " )\n", " (1): ResNetSEBlock(\n", " (block): Sequential(\n", " (0): Conv1d(64, 32, kernel_size=(5,), stride=(1,), padding=same)\n", " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=same)\n", " (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SqueezeAndExcitation1D(\n", " (block): Sequential(\n", " (0): Linear(in_features=64, out_features=32, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=32, out_features=64, bias=True)\n", " (3): Sigmoid()\n", " )\n", " )\n", " )\n", " )\n", " (2): ResNetSEBlock(\n", " (block): Sequential(\n", " (0): Conv1d(64, 32, kernel_size=(5,), stride=(1,), padding=same)\n", " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=same)\n", " (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SqueezeAndExcitation1D(\n", " (block): Sequential(\n", " (0): Linear(in_features=64, out_features=32, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=32, out_features=64, bias=True)\n", " (3): Sigmoid()\n", " )\n", " )\n", " )\n", " )\n", " (3): ResNetSEBlock(\n", " (block): Sequential(\n", " (0): Conv1d(64, 32, kernel_size=(5,), stride=(1,), padding=same)\n", " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=same)\n", " (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SqueezeAndExcitation1D(\n", " (block): Sequential(\n", " (0): Linear(in_features=64, out_features=32, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=32, out_features=64, bias=True)\n", " (3): Sigmoid()\n", " )\n", " )\n", " )\n", " )\n", " (4): ResNetSEBlock(\n", " (block): Sequential(\n", " (0): Conv1d(64, 32, kernel_size=(5,), stride=(1,), padding=same)\n", " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=same)\n", " (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): SqueezeAndExcitation1D(\n", " (block): Sequential(\n", " 
(0): Linear(in_features=64, out_features=32, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=32, out_features=64, bias=True)\n", " (3): Sigmoid()\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (global_avg_pool): AdaptiveAvgPool1d(output_size=1)\n", " )\n", " (fc): Linear(in_features=64, out_features=6, bias=True)\n", " (loss_fn): CrossEntropyLoss()\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = ResNetSE1D_5(\n", "    input_shape=(6, 60),\n", "    num_classes=6,\n", ")\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Defining the Trainer\n", "\n", "As we are using PyTorch Lightning, we need to define a `Trainer` object to train the model.\n", "We can define the trainer by passing the following parameters:\n", "- `max_epochs`: Maximum number of epochs to train the model.\n", "- `accelerator`: Device type to use for training. It can be `cpu` or `gpu`.\n", "- `devices`: The number of devices (or a list of device indices) to use for training.\n", "\n", "For this example, we will disable logging and checkpointing by setting `logger=False` and `enable_checkpointing=False`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ "trainer = L.Trainer(\n", "    max_epochs=100,\n", "    devices=1,\n", "    accelerator=\"gpu\",\n", "    logger=False,\n", "    enable_checkpointing=False\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1. Training the Model\n", "\n", "To train the model, we need three objects: the model, the data module, and the trainer.\n", "We can train the model by calling the `fit` method from the trainer and passing the model and the data module.\n", "\n", "The `fit` method trains the model for the number of epochs defined in the trainer. The training dataloader is used for optimization, and the validation dataloader is used for validation at the end of each epoch." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", " | Name | Type | Params | Mode \n", "------------------------------------------------------\n", "0 | backbone | _ResNet1D | 126 K | train\n", "1 | fc | Linear | 390 | train\n", "2 | loss_fn | CrossEntropyLoss | 0 | train\n", "------------------------------------------------------\n", "127 K Trainable params\n", "0 Non-trainable params\n", "127 K Total params\n", "0.509 Total estimated model params size (MB)\n", "76 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \r" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py:512: You called `self.log('val_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 2%|▏ | 1/56 [00:00<00:13, 4.19it/s]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py:512: You called `self.log('train_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 99: 100%|██████████| 56/56 [00:03<00:00, 16.92it/s, val_loss=1.960, train_loss=0.0358]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=100` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 99: 100%|██████████| 56/56 [00:03<00:00, 16.90it/s, val_loss=1.960, train_loss=0.0358]\n" ] } ], "source": [ "trainer.fit(model, data_module)" ] },
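{ "cell_type": "markdown", "metadata": {}, "source": [ "Since automatic checkpointing was disabled (`enable_checkpointing=False`), the trained weights live only in memory. If you want to reuse them later, you can save a checkpoint manually with Lightning's `save_checkpoint`; the file name below is arbitrary:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: persist the trained weights manually (the file name is arbitrary).\n", "trainer.save_checkpoint(\"resnet_se1d5_motionsense.ckpt\")" ] },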
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Evaluating the Model\n", "\n", "Once the model is trained, we can evaluate its performance on the test dataset.\n", "The performance is evaluated using the accuracy metric.\n", "To evaluate the model, we perform the following steps:\n", "1. Perform inference on the test dataset using the trained model. This is done using the `trainer.predict` method, which returns the predicted logits for each sample in the test dataset.\n", "2. Calculate the predicted labels by taking the argmax of the logits.\n", "3. Obtain the true labels from the test dataset.\n", "4. Create the accuracy metric object and pass it the predicted labels and the true labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 100%|██████████| 17/17 [00:00<00:00, 94.88it/s]\n" ] } ], "source": [ "# 1. Obtain predictions for the test set\n", "predictions = trainer.predict(model, data_module)\n", "# As predictions is a list of batches, we concatenate them along the first dimension\n", "predictions = torch.cat(predictions, dim=0) # type: ignore" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The shape of the predicted labels is torch.Size([1062]) and dtype torch.int64\n" ] } ], "source": [ "# 2. We can use the torch.argmax function to obtain the class with the highest probability\n", "predicted_classes = torch.argmax(predictions, dim=1) # type: ignore\n", "# Let's print the predicted classes\n", "print(f\"The shape of the predicted labels is {predicted_classes.shape} and dtype {predicted_classes.dtype}\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The shape of the true labels is torch.Size([1062]) and dtype torch.int64\n" ] } ], "source": [ "# 3. Let's obtain the true labels\n", "_, y = data_module.test_dataset[:] # type: ignore\n", "y = torch.from_numpy(y)\n", "# Let's print the true labels\n", "print(f\"The shape of the true labels is {y.shape} and dtype {y.dtype}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The accuracy of the model is 81.45%\n" ] } ], "source": [ "# 4. 
Let's create the accuracy metric object and compute the accuracy\n", "acc_metric = Accuracy(task=\"multiclass\", num_classes=6)\n", "score = acc_metric(predicted_classes, y)\n", "print(f\"The accuracy of the model is {score.item()*100:.2f}%\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }