Source code for dasf.ml.inference.loader.base

from dask.distributed import Worker

from dasf.utils.decorators import task_handler
from dasf.utils.funcs import get_dask_running_client


class BaseLoader:
    """
    Base loader for DL models. When running in a Dask cluster, it
    instantiates one model per worker, and that model is reused on every
    subsequent prediction task.
    """

    def __init__(self):
        self.model_instances = {}
    def load_model(self):
        """
        Load Model method is specific to each framework/model.
        """
        raise NotImplementedError("Load Model must be implemented")
    def load_model_distributed(self, **kwargs):
        """
        Distributed model instantiation: attach the model to the Dask
        Worker so it can be reused by later tasks on that worker.
        """
        try:
            Worker.model = self.load_model(**kwargs)
            return "UP"
        except Exception:
            return "DOWN"
    def _lazy_load(self, **kwargs):
        client = get_dask_running_client()
        self.model_instances = {}
        if client:
            worker_addresses = list(client.scheduler_info()["workers"].keys())
            # Run load_model_distributed on every worker so that each one
            # holds its own model instance.
            self.model_instances = client.run(
                self.load_model_distributed, **kwargs, workers=worker_addresses
            )
    def _load(self, **kwargs):
        self.model_instances = {"local": self.load_model(**kwargs)}
    def _lazy_load_cpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "cpu"
        self._lazy_load(**kwargs)

    def _lazy_load_gpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "gpu"
        self._lazy_load(**kwargs)

    def _load_cpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "cpu"
        self._load(**kwargs)

    def _load_gpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "gpu"
        self._load(**kwargs)
    @task_handler
    def load(self, **kwargs):
        """
        Load the model. The task_handler decorator dispatches to the
        appropriate _load*/_lazy_load* variant for the current execution
        environment.
        """
        ...
    def predict(self, data):
        """
        Predict method called on prediction tasks.
        """
        if not self.model_instances:
            raise RuntimeError(
                "Models have not been loaded. The load method must be "
                "executed beforehand."
            )
        if "local" in self.model_instances:
            model = self.model_instances["local"]
        else:
            # In the distributed case, each worker uses the model previously
            # attached to it by load_model_distributed.
            model = Worker.model
        data = self.preprocessing(data)
        output = self.inference(model, data)
        return self.postprocessing(output)
    def preprocessing(self, data):
        """
        Preprocessing stage, called before inference.
        """
        return data
    def inference(self, model, data):
        """
        Inference method; receives the model and the input data.
        """
        raise NotImplementedError("Inference must be implemented")
    def postprocessing(self, data):
        """
        Postprocessing stage, called after inference.
        """
        return data
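
A minimal sketch of how the pieces fit together. IdentityModelLoader below is a hypothetical subclass, not part of dasf, with a trivial callable standing in for a real framework model; calling the private _load() directly is also an illustration-only shortcut.

# Minimal usage sketch (illustration only): IdentityModelLoader is a
# hypothetical subclass, not part of dasf.
class IdentityModelLoader(BaseLoader):
    def load_model(self, **kwargs):
        # A real loader would build or deserialize a framework model here;
        # this stand-in "model" simply echoes its input.
        return lambda x: x

    def inference(self, model, data):
        # Apply the loaded model to the (pre-processed) input data.
        return model(data)


loader = IdentityModelLoader()
# Calling the private _load() directly keeps the sketch self-contained;
# in a pipeline, load() is invoked through the task_handler machinery.
loader._load()
print(loader.predict([1, 2, 3]))  # -> [1, 2, 3]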