dasf.ml.dl
Init module for Deep Learning algorithms.
Classes
- LightningTrainer: Initialize the LightningTrainer class.
- NeuralNetClassifier: Class representing a Fit operation of the pipeline.
Package Contents
- class dasf.ml.dl.LightningTrainer(model, use_gpu=False, batch_size=1, max_epochs=1, limit_train_batches=None, limit_val_batches=None, devices='auto', num_nodes=1, shuffle=True, strategy='ddp', unsqueeze_dim=None)
Initialize the LightningTrainer class.
Parameters
- model : LightningModule
The LightningModule instance representing the model to be trained.
- use_gpu : bool, optional
Flag indicating whether to use a GPU for training, by default False.
- batch_size : int, optional
The batch size for training, by default 1.
- max_epochs : int, optional
The maximum number of epochs for training, by default 1.
- limit_train_batches : int, optional
The number of batches to consider for training, by default None.
- limit_val_batches : int, optional
The number of batches to consider for validation, by default None.
- devices : int or str, optional
The number of devices to use for training, by default "auto".
- num_nodes : int, optional
The number of nodes to use for distributed training, by default 1.
- shuffle : bool, optional
Flag indicating whether to shuffle the data during training, by default True.
- strategy : str, optional
The strategy to use for distributed training, by default "ddp".
- unsqueeze_dim : int, optional
The dimension along which to unsqueeze the input data, by default None.
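As a rough sketch of how the parameters above fit together, the snippet below collects the documented defaults into a keyword dictionary and forwards them to `LightningTrainer`. The helper names `default_trainer_kwargs` and `make_trainer` are hypothetical and are not part of DASF; the import is guarded so the snippet still loads where DASF is not installed.

```python
from typing import Any, Dict

try:
    # Only importable where DASF and its deep learning dependencies are installed.
    from dasf.ml.dl import LightningTrainer
except ImportError:
    LightningTrainer = None


def default_trainer_kwargs() -> Dict[str, Any]:
    """Return the keyword defaults listed in the class signature above."""
    return {
        "use_gpu": False,
        "batch_size": 1,
        "max_epochs": 1,
        "limit_train_batches": None,
        "limit_val_batches": None,
        "devices": "auto",
        "num_nodes": 1,
        "shuffle": True,
        "strategy": "ddp",
        "unsqueeze_dim": None,
    }


def make_trainer(model: Any, **overrides: Any) -> Any:
    """Hypothetical helper: build a LightningTrainer with selected overrides."""
    kwargs = default_trainer_kwargs()
    # Reject keywords that the documented signature does not list.
    unknown = set(overrides) - set(kwargs)
    if unknown:
        raise TypeError(f"unknown LightningTrainer arguments: {sorted(unknown)}")
    kwargs.update(overrides)
    if LightningTrainer is None:
        raise RuntimeError("dasf is not installed")
    return LightningTrainer(model, **kwargs)
```

With DASF available, a call such as `make_trainer(model, max_epochs=5, use_gpu=True)` would be expected to return a trainer configured for five epochs on a GPU, assuming `model` is a LightningModule instance.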
- model
- accelerator
- batch_size
- max_epochs
- limit_train_batches
- limit_val_batches
- devices
- num_nodes
- shuffle
- strategy
- unsqueeze_dim
- fit(train_data, val_data=None)
Train the model using PyTorch Lightning.
Parameters
- train_data : Any
A DASF map-style dataset containing the training data.
- val_data : Any, optional
A DASF map-style dataset containing the validation data, by default None.
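"Map-style" is read here in the PyTorch sense: an object that implements `__len__` and `__getitem__`. The class below is a hypothetical in-memory stand-in (`ListDataset` is not a DASF class) that only illustrates the general shape of object `fit` is described as accepting. Whether DASF expects additional methods on its datasets is not stated in this reference, so treat this as an assumption.

```python
from typing import List, Sequence, Tuple


class ListDataset:
    """Hypothetical map-style dataset backed by in-memory sequences."""

    def __init__(self, samples: Sequence[List[float]], labels: Sequence[int]) -> None:
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self._samples = list(samples)
        self._labels = list(labels)

    def __len__(self) -> int:
        # Number of (sample, label) pairs available.
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[List[float], int]:
        # Random access by integer index is what makes the dataset "map-style".
        return self._samples[index], self._labels[index]


train_data = ListDataset([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]], [0, 1, 0])
val_data = ListDataset([[0.2, 0.8]], [0])
```

Objects of this kind would then be passed as `trainer.fit(train_data, val_data)`; a real dataset would typically return tensors or arrays rather than Python lists.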
- class dasf.ml.dl.NeuralNetClassifier(model, max_iter=100, batch_size=32)
Bases: dasf.transforms.base.Fit
Class representing a Fit operation of the pipeline.
- _model
- _accel = None
- _strategy = None
- _max_iter
- _devices = 0
- _ngpus = 0
- _batch_size
- __trainer = False
- __handler
- __fit_generic(X, y, accel, ngpus)