#!/usr/bin/env python3
import uuid
import pytorch_lightning as pl
from dask_pytorch_ddp.results import DaskResultsHandler
from torch.utils.data import DataLoader
from dasf.ml.dl.clusters import DaskClusterEnvironment
from dasf.transforms.base import Fit
from dasf.utils.funcs import (
    get_dask_gpu_count,
    get_dask_running_client,
    get_gpu_count,
    get_worker_info,
    sync_future_loop,
)

class TorchDataLoader(pl.LightningDataModule):
    """LightningDataModule wrapping train, validation and test datasets."""

    def __init__(self, train, val=None, test=None, batch_size=64):
        super().__init__()

        self._train = train
        self._val = val
        self._test = test
        self._batch_size = batch_size

    def prepare_data(self):
        # Download each dataset only if it exposes a ``download`` hook.
        if self._train is not None and hasattr(self._train, "download"):
            self._train.download()

        if self._val is not None and hasattr(self._val, "download"):
            self._val.download()

        if self._test is not None and hasattr(self._test, "download"):
            self._test.download()

    def setup(self, stage=None):
        # Load each dataset only if it exposes a ``load`` hook.
        if self._train is not None and hasattr(self._train, "load"):
            self._train.load()

        if self._val is not None and hasattr(self._val, "load"):
            self._val.load()

        if self._test is not None and hasattr(self._test, "load"):
            self._test.load()

    def train_dataloader(self):
        return DataLoader(self._train, batch_size=self._batch_size)

    def val_dataloader(self):
        return DataLoader(self._val, batch_size=self._batch_size)

    def test_dataloader(self):
        return DataLoader(self._test, batch_size=self._batch_size)
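
# A minimal usage sketch of the data module above. ``MyTrainSet``, ``MyValSet``
# and ``MyLightningModule`` are hypothetical stand-ins, not part of this
# module:
#
#     datamodule = TorchDataLoader(train=MyTrainSet(), val=MyValSet(),
#                                  batch_size=128)
#     trainer = pl.Trainer(max_epochs=10, accelerator="cpu", devices=1)
#     trainer.fit(MyLightningModule(), datamodule=datamodule)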

def run_dask_clustered(func, client=None, **kwargs):
    if client is None:
        client = get_dask_running_client()

    all_workers = get_worker_info(client)

    futures = []
    for worker in all_workers:
        # Include the worker metadata into kwargs so ``func`` knows which
        # node it runs on.
        kwargs['meta'] = worker

        # Pin the task to this specific worker and collect its future.
        futures.append(client.submit(func, **kwargs, workers=[worker["worker"]]))

    # Wait for every submitted future to finish.
    sync_future_loop(futures)
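
# A minimal sketch of how the helper above can be driven. ``train_on_worker``
# is a hypothetical function; note that it must accept the ``meta`` keyword
# injected for every worker:
#
#     def train_on_worker(model, meta=None):
#         ...  # executed once on each Dask worker
#
#     run_dask_clustered(train_on_worker, model=my_model)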

def fit(model, X, y, max_iter, accel, strategy, devices, ngpus, batch_size=32,
        plugins=None, meta=None):
    if meta is not None:
        plugin = DaskClusterEnvironment(metadata=meta)

        nodes = plugin.world_size()

        if plugins is None:
            plugins = list()
        plugins.append(plugin)
    else:
        nodes = 1

    # Use it for heterogeneous workers: -1 asks Lightning to use every
    # device available on the node.
    if ngpus < 0:
        ngpus = -1

    # ``y`` plays the role of the validation dataset here.
    dataloader = TorchDataLoader(train=X, val=y, batch_size=batch_size)

    # Recent PyTorch Lightning releases removed the ``gpus`` argument; the
    # device count is passed through ``devices`` instead.
    trainer = pl.Trainer(
        max_epochs=max_iter,
        accelerator=accel,
        strategy=strategy,
        plugins=plugins,
        devices=devices,
        num_nodes=nodes,
    )

    trainer.fit(model, datamodule=dataloader)
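
# ``fit`` is normally not called by hand; it is shipped to every Dask worker,
# which supplies ``meta``. A sketch with hypothetical names:
#
#     run_dask_clustered(
#         fit,
#         model=my_module,        # a pl.LightningModule
#         X=train_set,            # training dataset
#         y=val_set,              # validation dataset
#         max_iter=10, accel="gpu", strategy="ddp",
#         devices=1, ngpus=1, batch_size=32,
#     )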

class NeuralNetClassifier(Fit):
    def __init__(self, model, max_iter=100, batch_size=32):
        self._model = model

        self._accel = None
        self._strategy = None
        self._max_iter = max_iter
        self._ndevices = 0
        self._ngpus = 0

        self._batch_size = batch_size

        self.__trainer = None
        self.__handler = DaskResultsHandler(uuid.uuid4().hex)

    def _lazy_fit_generic(self, X, y, accel, ngpus):
        self._accel = accel
        self._strategy = "ddp"
        self._ngpus = self._ndevices = ngpus

        plugins = [DaskClusterEnvironment()]

        run_dask_clustered(
            fit,
            model=self._model,
            X=X,
            y=y,
            max_iter=self._max_iter,
            accel=self._accel,
            strategy=self._strategy,
            devices=self._ndevices,
            ngpus=self._ngpus,
            batch_size=self._batch_size,
            plugins=plugins,
        )

    def _lazy_fit_gpu(self, X, y=None):
        self._lazy_fit_generic(X=X, y=y, accel="gpu", ngpus=get_dask_gpu_count())

    def _lazy_fit_cpu(self, X, y=None):
        self._lazy_fit_generic(X=X, y=y, accel="cpu", ngpus=get_dask_gpu_count())

    def __fit_generic(self, X, y, accel, ngpus):
        self._accel = accel
        self._strategy = "dp"
        self._ngpus = self._ndevices = ngpus

        dataloader = TorchDataLoader(train=X, val=y, batch_size=self._batch_size)

        # Lightning rejects ``devices=0``, so fall back to "auto" when no
        # explicit device count is requested (e.g. the CPU path).
        self.__trainer = pl.Trainer(
            max_epochs=self._max_iter, accelerator=accel,
            devices=ngpus if ngpus > 0 else "auto"
        )

        self.__trainer.fit(self._model, datamodule=dataloader)

    def _fit_gpu(self, X, y=None):
        self.__fit_generic(X, y, "gpu", get_gpu_count())

    def _fit_cpu(self, X, y=None):
        self.__fit_generic(X, y, "cpu", 0)
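
# An end-to-end sketch with hypothetical names: ``MyLightningModule`` is any
# pl.LightningModule and ``train_set``/``val_set`` are map-style datasets. In
# a DASF pipeline the executor is expected to select the matching
# ``_fit_*``/``_lazy_fit_*`` variant; calling one directly just illustrates
# the contract:
#
#     clf = NeuralNetClassifier(MyLightningModule(), max_iter=5, batch_size=64)
#     clf._fit_cpu(train_set, val_set)    # local, single-node CPU training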