Minerva Experimental API

This notebook is still under development and is not yet ready for general use.

[1]:
from minerva.pipelines.experiment import (
    ModelInstantiator,
    ModelInformation,
    ModelConfig,
    Experiment,
)
from minerva.data.data_modules.base import MinervaDataModule

from pathlib import Path
from typing import Dict, Tuple, Optional

import numpy as np

from minerva.data.datasets.supervised_dataset import SimpleDataset
from minerva.data.readers import TiffReader, PNGReader
from minerva.transforms.transform import (
    _Transform,
    TransformPipeline,
    Transpose,
    PadCrop,
    CastTo,
    Unsqueeze,
    Squeeze,
    Identity,
)
from minerva.utils.typing import PathLike
import pandas as pd

import torchvision.transforms as T
/usr/local/lib/python3.10/dist-packages/_distutils_hack/__init__.py:53: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml
  warnings.warn(
[2]:
class NumberChannels(_Transform):
    """Select the number of output channels of a (C, H, W) image.

    With ``num_channels=1`` the first channel is kept, yielding shape
    ``(1, H, W)``; with ``num_channels=3`` a 3-channel input passes through
    unchanged.
    """

    def __init__(self, num_channels: int):
        # Only single-channel and RGB-like outputs are supported.
        self.num_channels = num_channels
        assert self.num_channels in [1, 3], "Number of channels must be 1 or 3"

    def __call__(self, data: np.ndarray) -> np.ndarray:
        if data.ndim != 3:
            raise ValueError(
                f"Data must have 3 dimensions, but got {data.ndim}"
            )

        # Generalization: if the input already has the requested number of
        # channels, pass it through instead of rejecting it (previously only
        # 3-channel inputs were accepted, so e.g. a (1, H, W) grayscale input
        # with num_channels=1 raised).
        if data.shape[0] == self.num_channels:
            return data

        if data.shape[0] != 3:
            raise ValueError(
                f"Data must have 3 channels, but got {data.shape[0]}"
            )

        # num_channels == 1 here: keep the first channel. Slicing preserves
        # the leading dimension, so no reshape is needed.
        return data[:1]

    def __str__(self):
        return f"NumberChannels(num_channels={self.num_channels})"

    def __repr__(self):
        return str(self)


class MinMaxNormalize(_Transform):
    """Linearly rescale the input array to the [0, 1] range.

    Robustness fix: a constant-valued input (max == min) previously divided
    by zero and produced NaN/inf; it now maps to an all-zeros array of the
    same dtype the division would have produced.
    """

    def __call__(self, data: np.ndarray) -> np.ndarray:
        lo = data.min()
        span = data.max() - lo
        if span == 0:
            # Constant input: every value is both min and max. Multiplying
            # the shifted data by 0.0 yields zeros with the same dtype as
            # the normal division result.
            return (data - lo) * 0.0
        return (data - lo) / span

    def __str__(self):
        return "MinMaxNormalize()"

    def __repr__(self):
        # Consistent with NumberChannels: repr mirrors str.
        return str(self)


def get_paihaka_data_module(
    root_data_dir: Path,
    root_annotation_dir: Path,
    img_size: Tuple[int, int, int] = (3, 1006, 590),
    label_size: Optional[Tuple[int, int, int]] = None,
    batch_size: int = 1,
    seed: int = 42,
    padding_mode: str = "reflect",
    padding_input_constant: int = 0,
    padding_label_constant: int = 0,
    normalize: bool = False,
    pad_test: bool = False,
) -> MinervaDataModule:
    """Build the SEAM AI (Parihaka) segmentation data module.

    Each split pairs a ``TiffReader`` (seismic images, ``root_data_dir/<split>``)
    with a ``PNGReader`` (annotation masks, ``root_annotation_dir/<split>``).
    Train/val images are transposed to (C, H, W), optionally min-max
    normalized, padded/cropped to ``img_size`` and cast to float32; labels get
    a leading channel axis, the same pad/crop (to ``label_size``) and int32
    cast. Test images are only padded when ``pad_test`` is True, and test
    labels skip padding entirely (Unsqueeze -> cast -> Squeeze).

    Args:
        root_data_dir: Directory containing ``train``/``val``/``test`` image folders.
        root_annotation_dir: Directory containing the matching annotation folders.
        img_size: Target (channels, height, width) for the input images.
        label_size: Target (channels, height, width) for labels; defaults to ``img_size``.
        batch_size: Dataloader batch size for all splits.
        seed: Seed forwarded to the PadCrop transforms.
        padding_mode: numpy-style padding mode used by PadCrop.
        padding_input_constant: Constant value for image padding (constant modes).
        padding_label_constant: Constant value for label padding (constant modes).
        normalize: If True, apply per-image min-max normalization to inputs.
        pad_test: If True, also pad/crop the test images to ``img_size``.

    Returns:
        A configured ``MinervaDataModule`` (predict split = "test", train is
        shuffled with ``drop_last=True``; the test dataloader keeps partial
        batches).
    """
    name = f"seam_ai_padding_{padding_mode}"
    label_size = label_size or img_size

    def reader_kwargs() -> Dict:
        # Fresh kwargs per reader: files are keyed/sorted by the text and
        # numeric parts of their "_"-delimited names.
        return dict(
            sort_method=["text", "numeric"],
            delimiter="_",
            key_index=[0, 1],
        )

    def make_readers(split: str):
        # Paired image (TIFF) and annotation (PNG) readers for one split.
        return [
            TiffReader(path=root_data_dir / split, **reader_kwargs()),
            PNGReader(path=root_annotation_dir / split, **reader_kwargs()),
        ]

    def image_pipeline(pad: bool) -> TransformPipeline:
        # Transforms for the image reader (TIFF).
        return TransformPipeline(
            [
                Transpose([2, 0, 1]),
                MinMaxNormalize() if normalize else Identity(),
                PadCrop(
                    target_h_size=img_size[1],
                    target_w_size=img_size[2],
                    padding_mode=padding_mode,
                    constant_values=padding_input_constant,
                    seed=seed,
                ) if pad else Identity(),
                NumberChannels(num_channels=img_size[0]),
                CastTo("float32"),
            ]
        )

    def label_pipeline() -> TransformPipeline:
        # Transforms for the label reader (PNG) on train/val splits.
        return TransformPipeline(
            [
                Unsqueeze(0),
                PadCrop(
                    target_h_size=label_size[1],
                    target_w_size=label_size[2],
                    padding_mode=padding_mode,
                    constant_values=padding_label_constant,
                    seed=seed,
                ),
                CastTo("int32"),
            ]
        )

    # Test labels are never padded: the channel axis is added only for the
    # cast and removed again afterwards.
    test_label_pipeline = TransformPipeline(
        [
            Unsqueeze(0),
            CastTo("int32"),
            Squeeze(0),
        ]
    )

    return MinervaDataModule(
        name=name,
        predict_split="test",
        batch_size=batch_size,
        drop_last=True,
        shuffle_train=True,
        additional_test_dataloader_kwargs={
            "drop_last": False,
        },
        train_dataset=SimpleDataset(
            readers=make_readers("train"),
            transforms=[image_pipeline(pad=True), label_pipeline()],
        ),
        # Validation uses the same transforms as train.
        val_dataset=SimpleDataset(
            readers=make_readers("val"),
            transforms=[image_pipeline(pad=True), label_pipeline()],
        ),
        test_dataset=SimpleDataset(
            readers=make_readers("test"),
            transforms=[image_pipeline(pad=pad_test), test_label_pipeline],
        ),
    )

[3]:
# Build the SEAM AI (Parihaka) data module used throughout this notebook.
# NOTE(review): the data and annotation directories are hardcoded absolute
# paths tied to this dev container — consider moving them to a configurable
# constant before sharing the notebook.
data_module = get_paihaka_data_module(
    root_data_dir=Path("/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/"),
    root_annotation_dir=Path("/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/"),
    img_size=(3, 1006, 590),  # (channels, height, width) target for inputs
    batch_size=8,
    seed=42,
    padding_mode="reflect",
    pad_test=False,  # keep test images at their native size (no PadCrop)
    normalize=False,  # skip per-image min-max normalization
)

print(data_module)
==================================================
             🆔 seam_ai_padding_reflect
==================================================
├── Predict Split: test
└── Dataloader class: <class 'torch.utils.data.dataloader.DataLoader'>
📂 Datasets:
   ├── Train Dataset:
   │      ==================================================
   │                 📂 SimpleDataset Information
   │      ==================================================
   │      📌 Dataset Type: SimpleDataset
   │         └── Reader 0: TiffReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/train' (1121 files)
   │         │     └── Transform: TransformPipeline(transforms=[Transpose(axes=[2, 0, 1]), Identity(), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), NumberChannels(num_channels=3), CastTo(dtype=float32)])
   │         └── Reader 1: PNGReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/train' (1121 files)
   │         │     └── Transform: TransformPipeline(transforms=[Unsqueeze(axis=0), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), CastTo(dtype=int32)])
   │         │
   │         └── Total Readers: 2
   │      ==================================================
   ├── Val Dataset:
   │      ==================================================
   │                 📂 SimpleDataset Information
   │      ==================================================
   │      📌 Dataset Type: SimpleDataset
   │         └── Reader 0: TiffReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/val' (51 files)
   │         │     └── Transform: TransformPipeline(transforms=[Transpose(axes=[2, 0, 1]), Identity(), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), NumberChannels(num_channels=3), CastTo(dtype=float32)])
   │         └── Reader 1: PNGReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/val' (51 files)
   │         │     └── Transform: TransformPipeline(transforms=[Unsqueeze(axis=0), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), CastTo(dtype=int32)])
   │         │
   │         └── Total Readers: 2
   │      ==================================================
   └── Test Dataset:
          ==================================================
                     📂 SimpleDataset Information
          ==================================================
          📌 Dataset Type: SimpleDataset
             └── Reader 0: TiffReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/test' (200 files)
             │     └── Transform: TransformPipeline(transforms=[Transpose(axes=[2, 0, 1]), Identity(), Identity(), NumberChannels(num_channels=3), CastTo(dtype=float32)])
             └── Reader 1: PNGReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/test' (200 files)
             │     └── Transform: TransformPipeline(transforms=[Unsqueeze(axis=0), CastTo(dtype=int32), Squeeze(axis=0)])
             │
             └── Total Readers: 2
          ==================================================

🛠 **Dataloader Configurations:**
   ├── Train Dataloader Kwargs:
         ├── batch_size: 8
         ├── num_workers: 0
         ├── shuffle: true
         ├── drop_last: true
   ├── Val Dataloader Kwargs:
         ├── batch_size: 8
         ├── num_workers: 0
         ├── shuffle: false
         ├── drop_last: true
   └── Test Dataloader Kwargs:
         ├── batch_size: 8
         ├── num_workers: 0
         ├── shuffle: false
         ├── drop_last: false
==================================================
[4]:
from minerva.models.nets.image.deeplabv3 import DeepLabV3
import lightning as L
import torch
from minerva.models.loaders import FromPretrained
from minerva.losses.weighted_dice_loss import WeightedDiceLoss


class DeepLabV3Instantiator(ModelInstantiator):
    """Builds DeepLabV3 segmentation models for the experiment pipeline.

    Supports three construction modes: random initialization, loading only
    pretrained backbone weights, and restoring a full model checkpoint.
    """

    def __init__(
        self,
        num_classes: int = 6,
        epochs: int = 100,
        loss: str = "ce",
        optimizer="adamw",
        learning_rate=1e-4,
    ):
        self.num_classes = num_classes
        self.epochs = epochs
        self.loss = loss
        self.optimizer = optimizer
        self.learning_rate = learning_rate

        # Dispatch table mapping loss names to criterion factories.
        loss_factories = {
            "ce": torch.nn.CrossEntropyLoss,
            "wdice": lambda: WeightedDiceLoss(num_classes=num_classes),
        }
        if self.loss not in loss_factories:
            raise ValueError(f"Loss {self.loss} not supported")
        self.loss_fn = loss_factories[self.loss]()

    def create_model_randomly_initialized(self) -> L.LightningModule:
        """Return a freshly initialized DeepLabV3 with this config's hyperparameters."""
        return DeepLabV3(
            num_classes=self.num_classes,
            loss_fn=self.loss_fn,
            epochs=self.epochs,
            optimizer=self.optimizer,
            learning_rate=self.learning_rate,
        )

    def create_model_and_load_backbone(self, backbone_checkpoint_path):
        """Build a model and load only its backbone weights from a checkpoint.

        Checkpoint keys are re-prefixed to land under ``backbone.RN50model.``;
        missing keys are tolerated so the prediction head stays randomly
        initialized.
        """
        base_model = self.create_model_randomly_initialized()
        return FromPretrained(
            base_model,
            ckpt_path=backbone_checkpoint_path,
            strict=False,
            ckpt_key=None,
            keys_to_rename={"": "backbone.RN50model."},
            error_on_missing_keys=False,
        )

    def load_model_from_checkpoint(
        self, checkpoint_path: PathLike
    ) -> L.LightningModule:
        """Restore a complete model (backbone + head) from a checkpoint."""
        return FromPretrained(
            self.create_model_randomly_initialized(),
            ckpt_path=checkpoint_path,
            strict=False,
        )


# Assemble the DeepLabV3 model configuration: the instantiator knows how to
# build/load the model, the information record describes its I/O contract.
deeplab_instantiator = DeepLabV3Instantiator(
    num_classes=6,
    optimizer="adam+lr_scheduler",
    loss="wdice",
    epochs=100,
    learning_rate=1e-5,
)

deeplab_information = ModelInformation(
    name="DeepLabV3",
    input_shape=(3, 1006, 590),
    output_shape=(6, 1006, 590),
    num_classes=6,
    return_logits=True,
)

deeplabv3_config = ModelConfig(
    instantiator=deeplab_instantiator,
    information=deeplab_information,
)

print(deeplabv3_config)
/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
<minerva.pipelines.experiment.ModelConfig object at 0x6ffe8affe770>
[5]:
from torchmetrics import JaccardIndex

# Mean IoU under the three torchmetrics averaging schemes; the keys become
# the metric names reported by the experiment.
evaluation_metrics = {
    f"miou-{average}": JaccardIndex(
        task="multiclass", num_classes=6, average=average
    )
    for average in ("macro", "micro", "weighted")
}

experiment = Experiment(
    experiment_name="test",
    model_config=deeplabv3_config,
    data_module=data_module,
    root_log_dir="test_logs",
    max_epochs=3,
    seed=42,
    # Keep the best checkpoints by training and validation loss.
    checkpoint_metrics=[
        {"mode": "min", "monitor": "train_loss", "filename": "min_train_loss"},
        {"mode": "min", "monitor": "val_loss", "filename": "min_val_loss"},
    ],
    # Smoke-test settings: cap each epoch at 10 batches.
    limit_train_batches=10,
    limit_val_batches=10,
    evaluation_metrics=evaluation_metrics,
    save_predictions=False,
)

print(experiment)
================================================================================
                              🚀 Experiment: test 🚀
================================================================================

🛠 **Execution Details:**
   ├── Execution ID: 0
   ├── Log Dir: test_logs/DeepLabV3/seam_ai_padding_reflect/test/0
   ├── Seed: 42
   ├── Accelerator: gpu
   ├── Devices: 1
   └── Max Epochs: 3

🧠 **Model Information:**
   ├── Model Name: DeepLabV3
   ├── Input Shape: (3, 1006, 590)
   ├── Output Shape: (6, 1006, 590)
   └── Num Classes: 6

📂 **Dataset Information:**
      ==================================================
                   🆔 seam_ai_padding_reflect
      ==================================================
      ├── Predict Split: test
      └── Dataloader class: <class 'torch.utils.data.dataloader.DataLoader'>
      📂 Datasets:
         ├── Train Dataset:
         │      ==================================================
         │                 📂 SimpleDataset Information
         │      ==================================================
         │      📌 Dataset Type: SimpleDataset
         │         └── Reader 0: TiffReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/train' (1121 files)
         │         │     └── Transform: TransformPipeline(transforms=[Transpose(axes=[2, 0, 1]), Identity(), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), NumberChannels(num_channels=3), CastTo(dtype=float32)])
         │         └── Reader 1: PNGReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/train' (1121 files)
         │         │     └── Transform: TransformPipeline(transforms=[Unsqueeze(axis=0), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), CastTo(dtype=int32)])
         │         │
         │         └── Total Readers: 2
         │      ==================================================
         ├── Val Dataset:
         │      ==================================================
         │                 📂 SimpleDataset Information
         │      ==================================================
         │      📌 Dataset Type: SimpleDataset
         │         └── Reader 0: TiffReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/val' (51 files)
         │         │     └── Transform: TransformPipeline(transforms=[Transpose(axes=[2, 0, 1]), Identity(), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), NumberChannels(num_channels=3), CastTo(dtype=float32)])
         │         └── Reader 1: PNGReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/val' (51 files)
         │         │     └── Transform: TransformPipeline(transforms=[Unsqueeze(axis=0), PadCrop(target_h_size=1006, target_w_size=590, padding_mode=reflect, constant_values=0, seed=42), CastTo(dtype=int32)])
         │         │
         │         └── Total Readers: 2
         │      ==================================================
         └── Test Dataset:
                ==================================================
                           📂 SimpleDataset Information
                ==================================================
                📌 Dataset Type: SimpleDataset
                   └── Reader 0: TiffReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images/test' (200 files)
                   │     └── Transform: TransformPipeline(transforms=[Transpose(axes=[2, 0, 1]), Identity(), Identity(), NumberChannels(num_channels=3), CastTo(dtype=float32)])
                   └── Reader 1: PNGReader at '/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations/test' (200 files)
                   │     └── Transform: TransformPipeline(transforms=[Unsqueeze(axis=0), CastTo(dtype=int32), Squeeze(axis=0)])
                   │
                   └── Total Readers: 2
                ==================================================

      🛠 **Dataloader Configurations:**
         ├── Train Dataloader Kwargs:
               ├── batch_size: 8
               ├── num_workers: 0
               ├── shuffle: true
               ├── drop_last: true
         ├── Val Dataloader Kwargs:
               ├── batch_size: 8
               ├── num_workers: 0
               ├── shuffle: false
               ├── drop_last: true
         └── Test Dataloader Kwargs:
               ├── batch_size: 8
               ├── num_workers: 0
               ├── shuffle: false
               ├── drop_last: false
      ==================================================

[6]:
# Remove artifacts from any previous execution at this log dir so the
# experiment starts from a clean state.
experiment.cleanup()
Experiment at 'test_logs/DeepLabV3/seam_ai_padding_reflect/test/0' cleaned up.
[7]:
# Status before running: no checkpoints, metrics or results yet;
# state is 'not executed'.
experiment.status
[7]:
{'experiment_name': 'test',
 'log_dir': PosixPath('test_logs/DeepLabV3/seam_ai_padding_reflect/test/0'),
 'checkpoints': {},
 'training_metrics': None,
 'prediction_paths': {},
 'results_paths': {},
 'state': 'not executed'}
[ ]:
# Fit for the (truncated) 3 epochs, then evaluate on the test split.
# NOTE(review): the captured output shows the kernel crashed after epoch 0 —
# presumably resource exhaustion; confirm memory/GPU headroom before rerunning.
result = experiment.run(task="fit-evaluate")
** Seed set to: 42 **
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

================================================================================
                               Experiment: test
================================================================================
🧠 Model
   ├── Name: DeepLabV3
   ├── Finetune: No
   ├── Resumed From: Beginning
   ├── Expected Input Shape: (3, 1006, 590)
   ├── Expected Output Shape: (6, 1006, 590)
   ├── Total Params: 41,684,014
   └── Trainable Params: 41,684,014 (100.00%)

📊 Dataset
   ├── Train Samples: 1121
   |   ├── Input Shape: (3, 1006, 590)
   |   └── Label Shape: (1, 1006, 590)
   └── Val Samples: 51
       ├── Input Shape: (3, 1006, 590)
       └── Label Shape: (1, 1006, 590)

💾 Logging & Checkpoints
   ├── Log Dir: test_logs/DeepLabV3/seam_ai_padding_reflect/test/0
   ├── Metrics Path: test_logs/DeepLabV3/seam_ai_padding_reflect/test/0/metrics.csv
   └── Checkpoints Dir: test_logs/DeepLabV3/seam_ai_padding_reflect/test/0/checkpoints
       └── Files: min_train_loss.ckpt, min_val_loss.ckpt, last.ckpt

⚙️ Trainer Config
   ├── Max Epochs: 3
   ├── Train Batches: 10
   ├── Accelerator: gpu
   ├── Strategy: auto
   ├── Devices: 1
   ├── Num Nodes: 1
   └── Seed: 42

  | Name     | Type                    | Params | Mode
-------------------------------------------------------------
0 | backbone | DeepLabV3Backbone       | 25.6 M | train
1 | fc       | DeepLabV3PredictionHead | 16.1 M | train
2 | loss_fn  | WeightedDiceLoss        | 0      | train
-------------------------------------------------------------
41.7 M    Trainable params
0         Non-trainable params
41.7 M    Total params
166.736   Total estimated model params size (MB)
186       Modules in train mode
0         Modules in eval mode
Sanity Checking: |          | 0/? [00:00<?, ?it/s]
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0: 100%|██████████| 10/10 [00:20<00:00,  0.50it/s, v_num=0, val_loss=0.863, train_loss=0.842]
The Kernel crashed while executing code in the current cell or a previous cell.

Please review the code in the cell(s) to identify a possible cause of the failure.

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info.

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.
[ ]:
# Inspect the experiment state after the run attempt (checkpoints, metrics
# paths, and execution state).
experiment.status