{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Getting Started with Minerva for Seismic Facies Classification\n", "\n", "Seismic facies classification is a challenging problem in geophysics. The goal is to predict the lithology of the subsurface from seismic data. In this notebook, we will use Minerva to train and evaluate a DeepLabV3 model from scratch for seismic facies classification.\n", "\n", "This notebook is a step-by-step guide to training a DeepLabV3 model for seismic facies classification using Minerva. It comprises the following steps:\n", "\n", "1. Data Preparation\n", "2. Model Creation\n", "3. Model Training\n", "4. Model Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/_distutils_hack/__init__.py:53: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from pathlib import Path\n", "import numpy as np\n", "from minerva.data.readers.patched_array_reader import NumpyArrayReader\n", "from minerva.data.data_modules.base import MinervaDataModule\n", "from minerva.transforms.transform import Repeat, Squeeze\n", "from minerva.data.datasets.base import SimpleDataset\n", "from minerva.models.nets.image.deeplabv3 import DeepLabV3\n", "import lightning as L\n", "import torch\n", "from torchmetrics import JaccardIndex\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Data Preparation\n", "\n", "We begin by preparing the data for training and evaluation.\n", "\n", "For this tutorial, we will use the [F3 dataset](https://zenodo.org/records/3755060/files/data.zip?download=1) from the seismic facies classification benchmark, introduced in the following work:\n", "\n", "```latex\n", "Alaudah, Y., Michałowicz, P., Alfarraj, M. and AlRegib, G., 2019. A machine-learning benchmark for facies classification. Interpretation, 7(3), pp.SE175-SE187.\n", "```\n", "\n", "This dataset is a 3D seismic volume from the F3 block in the Netherlands North Sea. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2025-03-31 17:44:57-- https://zenodo.org/records/3755060/files/data.zip?download=1\n", "Resolving zenodo.org (zenodo.org)... 188.185.43.25, 188.185.45.92, 188.185.48.194, ...\n", "Connecting to zenodo.org (zenodo.org)|188.185.43.25|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 1051449986 (1003M) [application/octet-stream]\n", "Saving to: ‘f3.zip’\n", "\n", "f3.zip 100%[===================>] 1003M 11.9MB/s in 2m 6s \n", "\n", "2025-03-31 17:47:04 (7.97 MB/s) - ‘f3.zip’ saved [1051449986/1051449986]\n", "\n", "Archive: f3.zip\n", " creating: f3/data/\n", " creating: f3/data/train/\n", " inflating: f3/data/train/train_seismic.npy \n", " inflating: f3/data/train/train_labels.npy \n", " creating: f3/data/test_once/\n", " inflating: f3/data/test_once/test1_seismic.npy \n", " inflating: f3/data/test_once/test2_labels.npy \n", " inflating: f3/data/test_once/test1_labels.npy \n", " inflating: f3/data/test_once/test2_seismic.npy \n", " inflating: f3/data/.dropbox \n" ] } ], "source": [ "!wget \"https://zenodo.org/records/3755060/files/data.zip?download=1\" -O f3.zip\n", "!unzip -o f3.zip -d f3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once extracted, the data is organized as follows:\n", "\n", "```tree\n", "f3/data/\n", "├── test_once\n", "│ ├── test1_labels.npy\n", "│ ├── test1_seismic.npy\n", "│ ├── test2_labels.npy\n", "│ └── test2_seismic.npy\n", "└── train\n", " ├── train_labels.npy\n", " └── train_seismic.npy\n", "```\n", "\n", "The `train` folder contains the training data, while the `test_once` folder holds the test data. Each volume comes as a pair of NumPy files: one for the seismic data and one for the corresponding labels.\n", "\n", "Each `.npy` file contains a 3D volume with the following dimensions:\n", "\n", "- `train_seismic`: `(401, 701, 255)`, where `(401, 701)` are the inline and crossline dimensions, and `255` is the number of depth samples per trace. \n", "- `test1_seismic`: `(200, 701, 255)`, with the same crossline and depth dimensions, but fewer inlines. 
\n", "- `test2_seismic`: `(601, 200, 255)`, with different spatial dimensions from training and Test 1.\n", "\n", "**Note**: The label volumes have the same shapes as their corresponding seismic volumes.\n", "\n", "To process the data, we iterate over the first dimension of the seismic volumes to extract 2D slices and their associated labels. For example, for the training data, we extract `401` slices of shape `(701, 255)` each, along with their matching label slices of the same shape.\n", "\n", "Test 1 slices have the same shape as the training slices, `(701, 255)`, which makes Test 1 suitable for direct evaluation. Test 2 slices have a different shape, `(200, 255)`, and may require separate handling. We will therefore focus on Test 1 for evaluation in this notebook.\n", "\n", "Each label slice is a 2D array whose values take one of 6 classes, representing different lithologies." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1. Creating Data Readers and Dataset\n", "\n", "In Minerva, a reader is responsible for loading single units of data in an ordered way; it behaves like an ordered collection of data. A dataset is a collection of readers, each with an associated transform.\n", "\n", "Let's first create 2 readers for the training data. The first reader iterates over the training seismic data, while the second iterates over the corresponding labels.\n", "\n", "We will use `NumpyArrayReader`, which allows us to read data from NumPy files.\n", "There are two required arguments for `NumpyArrayReader`:\n", "- `data`: the path to the NumPy file or the NumPy array.\n", "- `data_shape`: the shape of each sample. We will use `(1, 701, 255)` for both the seismic data and the labels, which gives 401 samples of shape `(1, 701, 255)` for the training data." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "root_data_dir = Path(\"f3/data/\")\n", "\n", "train_data_reader = NumpyArrayReader(\n", " data=root_data_dir / \"train\" / \"train_seismic.npy\",\n", " data_shape=(1, 701, 255),\n", ")\n", "\n", "train_labels_reader = NumpyArrayReader(\n", " data=root_data_dir / \"train\" / \"train_labels.npy\",\n", " data_shape=(1, 701, 255),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the readers are set up, we can create a dataset using them. For this purpose, we'll use the `SimpleDataset` class, which takes two main inputs:\n", "\n", "- A list of **readers** (data sources) \n", "- A list of **transforms** (optional preprocessing steps)\n", "\n", "When an item at index `i` is requested from the dataset (e.g., `dataset[i]`), the following steps occur:\n", "\n", "1. The dataset retrieves the item at index `i` from `reader[0]` and applies the corresponding transform `transform[0]`.\n", "2. It then retrieves the item at index `i` from `reader[1]` and applies `transform[1]`.\n", "3. Finally, it returns a 2-element tuple containing the transformed outputs:\n", "   - The first element is the transformed seismic data: `transform[0](reader[0][i])`\n", "   - The second element is the transformed label: `transform[1](reader[1][i])`\n", "\n", "This design pairs each data source with its own preprocessing logic in a flexible and consistent way, which suits training, validation, and testing workflows.\n", "\n", "\n", "We will use the following transforms:\n", "- For seismic data: each slice has a single channel, but our model expects 3-channel (RGB) input. We therefore use the `Repeat` transform to repeat the slice across 3 channels, giving data of shape `(3, 701, 255)`.\n", "- For labels: we apply no transform, so the data keeps the shape `(1, 701, 255)`." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", " 📂 SimpleDataset Information \n", "==================================================\n", "📌 Dataset Type: SimpleDataset\n", " └── Reader 0: NumpyArrayReader(samples=401, shape=(1, 701, 255), dtype=float64)\n", " │ └── Transform: Repeat(axis=0, n_repetitions=3)\n", " └── Reader 1: NumpyArrayReader(samples=401, shape=(1, 701, 255), dtype=uint8)\n", " │ └── Transform: None\n", " │\n", " └── Total Readers: 2\n", "==================================================\n" ] } ], "source": [ "train_dataset = SimpleDataset(\n", " readers=[train_data_reader, train_labels_reader],\n", " transforms=[Repeat(axis=0, n_repetitions=3), None],\n", ")\n", "\n", "print(train_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We build the test dataset the same way. The only difference is the label transform: we do not want the channel dimension for the test labels, so we use the `Squeeze` transform to remove it, giving label slices of shape `(701, 255)`. The seismic data uses the same `Repeat` transform as the training data." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", " 📂 SimpleDataset Information \n", "==================================================\n", "📌 Dataset Type: SimpleDataset\n", " └── Reader 0: NumpyArrayReader(samples=200, shape=(1, 701, 255), dtype=float64)\n", " │ └── Transform: Repeat(axis=0, n_repetitions=3)\n", " └── Reader 1: NumpyArrayReader(samples=200, shape=(1, 701, 255), dtype=uint8)\n", " │ └── Transform: Squeeze(axis=0)\n", " │\n", " └── Total Readers: 2\n", "==================================================\n" ] } ], "source": [ "test_data_reader = NumpyArrayReader(\n", " data=root_data_dir / \"test_once\" / \"test1_seismic.npy\",\n", " data_shape=(1, 701, 255),\n", ")\n", "\n", "test_labels_reader = NumpyArrayReader(\n", " data=root_data_dir / \"test_once\" / \"test1_labels.npy\",\n", " data_shape=(1, 701, 255),\n", ")\n", "\n", "test_dataset = SimpleDataset(\n", " readers=[test_data_reader, test_labels_reader],\n", " transforms=[Repeat(axis=0, n_repetitions=3), Squeeze(0)],\n", ")\n", "\n", "print(test_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2. Creating the `MinervaDataModule`\n", "\n", "Minerva models are implemented using PyTorch Lightning. \n", "Thus, to train a model we should create a `LightningDataModule` object that will handle the data loading and preprocessing.\n", "Minerva provides a `MinervaDataModule` class that extends PyTorch Lightning's `LightningDataModule` class and standardizes the data loading process.\n", "\n", "We may create a `MinervaDataModule` object by passing the training, validation, and testing datasets, as well as the batch size and the number of workers for data loading.\n", "\n", "We pass `drop_last=True` to the training dataloader to drop the last batch when it is smaller than the batch size. With 401 training samples and a batch size of 16, the last batch would otherwise contain a single sample, which the batch normalization layers in DeepLabV3 cannot process in training mode." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", " 🆔 F3 Dataset\n", "==================================================\n", "└── Predict Split: test\n", "📂 Datasets:\n", " ├── Train Dataset:\n", " │ ==================================================\n", " │ 📂 SimpleDataset Information \n", " │ ==================================================\n", " │ 📌 Dataset Type: SimpleDataset\n", " │ └── Reader 0: NumpyArrayReader(samples=401, shape=(1, 701, 255), dtype=float64)\n", " │ │ └── Transform: Repeat(axis=0, n_repetitions=3)\n", " │ └── Reader 1: NumpyArrayReader(samples=401, shape=(1, 701, 255), dtype=uint8)\n", " │ │ └── Transform: None\n", " │ │\n", " │ └── Total Readers: 2\n", " │ ==================================================\n", " ├── Val Dataset:\n", " │ None\n", " └── Test Dataset:\n", " ==================================================\n", " 📂 SimpleDataset Information \n", " ==================================================\n", " 📌 Dataset Type: SimpleDataset\n", " └── Reader 0: NumpyArrayReader(samples=200, shape=(1, 701, 255), dtype=float64)\n", " │ └── Transform: Repeat(axis=0, n_repetitions=3)\n", " └── Reader 1: NumpyArrayReader(samples=200, shape=(1, 701, 255), dtype=uint8)\n", " │ └── Transform: Squeeze(axis=0)\n", " │\n", " └── Total Readers: 2\n", " ==================================================\n", "\n", "🛠 **Dataloader Configurations:**\n", " ├── Dataloader class: \n", " ├── Train Dataloader Kwargs:\n", " ├── batch_size: 16\n", " ├── num_workers: 4\n", " ├── shuffle: true\n", " ├── drop_last: true\n", " ├── Val Dataloader Kwargs:\n", " ├── batch_size: 16\n", " ├── num_workers: 4\n", " ├── shuffle: false\n", " ├── drop_last: false\n", " └── Test Dataloader Kwargs:\n", " ├── batch_size: 16\n", " ├── num_workers: 4\n", " ├── shuffle: false\n", " ├── drop_last: false\n", 
"==================================================\n" ] } ], "source": [ "data_module = MinervaDataModule(\n", " train_dataset=train_dataset,\n", " test_dataset=test_dataset,\n", " batch_size=16,\n", " num_workers=4,\n", " additional_train_dataloader_kwargs={\"drop_last\": True},\n", " name=\"F3 Dataset\"\n", ")\n", "\n", "print(data_module)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Creating the Model\n", "\n", "\n", "For this tutorial, we will use a DeepLabV3 model for seismic facies classification. The DeepLabV3 model is a popular architecture for semantic segmentation tasks, such as image segmentation. It is based on a deep convolutional neural network with a ResNet backbone and an Atrous Spatial Pyramid Pooling (ASPP) module. Minerva provides a `DeepLabV3` model that can be used for seismic facies classification. We just need to pass the number of classes to the model, which is 6 in this case." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeepLabV3(\n", " (backbone): DeepLabV3Backbone(\n", " (RN50model): ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (3): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 
1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (3): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (4): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (5): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(512, 2048, 
kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=2048, out_features=1000, bias=True)\n", " )\n", " )\n", " (fc): DeepLabV3PredictionHead(\n", " (0): ASPP(\n", " (convs): ModuleList(\n", " (0): Sequential(\n", " (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (1): ASPPConv(\n", " (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (2): ASPPConv(\n", " (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(24, 24), dilation=(24, 24), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (3): ASPPConv(\n", " (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(36, 36), dilation=(36, 36), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (4): ASPPPooling(\n", " (0): AdaptiveAvgPool2d(output_size=1)\n", " (1): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): ReLU()\n", " )\n", " )\n", " (project): Sequential(\n", " (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " (3): Dropout(p=0.5, inplace=False)\n", " )\n", " )\n", " (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 
1), bias=False)\n", " (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): ReLU()\n", " (4): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (loss_fn): CrossEntropyLoss()\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = DeepLabV3(num_classes=6)\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Defining the trainer\n", "\n", "As we are using PyTorch Lightning, we need to define a `Trainer` object to train the model.\n", "We can define the trainer by passing the following parameters:\n", "- `max_epochs`: the maximum number of epochs to train the model.\n", "- `accelerator`: the type of device to use for training, for example `cpu` or `gpu`.\n", "- `devices`: the number of devices to use for training, or a list of device indices.\n", "\n", "For this example we will disable logging and checkpointing by setting `logger=False` and `enable_checkpointing=False`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = L.Trainer(\n", " max_epochs=100,\n", " accelerator=\"gpu\",\n", " devices=1,\n", " logger=False,\n", " enable_checkpointing=False,\n", ")\n", "\n", "trainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1. Training the model\n", "\n", "To train the model we need three objects: the model, the data module, and the trainer.\n", "We train the model by calling the trainer's `fit` method, passing the model and the data module.\n", "\n", "The `fit` method will train the model for the number of epochs defined in the trainer object. 
The training dataloader is used for training. A validation dataloader would be used for validation, but this data module defines no validation dataset, so the validation loop is skipped." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/vscode/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params | Mode \n", "-------------------------------------------------------------\n", "0 | backbone | DeepLabV3Backbone | 25.6 M | train\n", "1 | fc | DeepLabV3PredictionHead | 16.1 M | train\n", "2 | loss_fn | CrossEntropyLoss | 0 | train\n", "-------------------------------------------------------------\n", "41.7 M Trainable params\n", "0 Non-trainable params\n", "41.7 M Total params\n", "166.736 Total estimated model params size (MB)\n", "186 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 0%| | 0/25 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "index = 100\n", "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", "axes[0].imshow(predicted_classes[index].cpu().numpy().T)\n", "axes[0].set_title(\"Predicted\")\n", "axes[1].imshow(y[index].cpu().numpy().T)\n", "axes[1].set_title(\"True\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 2 }