# Source code for dasf.utils.decorators

""" Implementations of important library decorators. """
#!/usr/bin/env python3

from functools import wraps

from dasf.utils.funcs import (
    get_dask_running_client,
    is_dask_gpu_supported,
    is_dask_supported,
    is_gpu_supported,
)
from dasf.utils.types import is_dask_array, is_gpu_array


def is_forced_local(cls):
    """
    Report whether the object is forced to run locally, on the CPU.

    Returns the object's ``_run_local`` attribute. Returns ``None`` when
    the attribute is absent or explicitly set to ``None``.
    """
    # pylint: disable=protected-access
    return getattr(cls, "_run_local", None)
def is_forced_gpu(cls):
    """
    Report whether the object is forced to run on a GPU.

    Returns the object's ``_run_gpu`` attribute. Returns ``None`` when the
    attribute is absent or explicitly set to ``None``.
    """
    # pylint: disable=protected-access
    return getattr(cls, "_run_gpu", None)
def fetch_from_dask(*args, **kwargs) -> tuple:
    """
    Materialize every Dask-typed parameter by calling ``.compute()`` on it.

    Values for which ``is_dask_array`` is false pass through unchanged.
    Keyword arguments are processed before positional ones.

    Returns
    -------
    tuple
        ``(new_args, new_kwargs)``: a list of positional values and a dict
        of keyword values, with the same order and keys as the input.
    """
    def _materialize(item):
        # Only Dask arrays are computed; everything else is returned as-is.
        return item.compute() if is_dask_array(item) else item

    converted_kwargs = {name: _materialize(item) for name, item in kwargs.items()}
    converted_args = [_materialize(item) for item in args]
    return converted_args, converted_kwargs
def fetch_from_gpu(*args, **kwargs) -> tuple:
    """
    Copy every GPU-typed parameter to the CPU by calling ``.get()`` on it.

    Values for which ``is_gpu_array`` is false pass through unchanged.
    Keyword arguments are processed before positional ones.

    Returns
    -------
    tuple
        ``(new_args, new_kwargs)``: a list of positional values and a dict
        of keyword values, with the same order and keys as the input.
    """
    def _to_host(item):
        # Only GPU arrays are fetched; everything else is returned as-is.
        return item.get() if is_gpu_array(item) else item

    converted_kwargs = {name: _to_host(item) for name, item in kwargs.items()}
    converted_args = [_to_host(item) for item in args]
    return converted_args, converted_kwargs
def fetch_args_from_dask(func):
    """
    Decorator that fetches to CPU all function parameters in a Dask data type.

    Before ``func`` runs, its positional and keyword arguments are passed
    through ``fetch_from_dask``, which replaces Dask arrays with the result
    of ``.compute()``.

    ``functools.wraps`` preserves ``func``'s ``__name__``, ``__doc__`` and
    ``__wrapped__`` on the returned wrapper. Decorators that read
    ``__name__`` (such as ``task_handler`` in this module) therefore see the
    original function name instead of ``"wrapper"``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        """
        Wrapper to fetch parameters from Dask data.
        """
        new_args, new_kwargs = fetch_from_dask(*args, **kwargs)
        return func(*new_args, **new_kwargs)

    return wrapper
def fetch_args_from_gpu(func):
    """
    Decorator that fetches to CPU all function parameters in a GPU data type.

    Before ``func`` runs, its positional and keyword arguments are passed
    through ``fetch_from_gpu``, which replaces GPU arrays with the result of
    ``.get()``.

    ``functools.wraps`` preserves ``func``'s ``__name__``, ``__doc__`` and
    ``__wrapped__`` on the returned wrapper. Decorators that read
    ``__name__`` (such as ``task_handler`` in this module) therefore see the
    original function name instead of ``"wrapper"``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        """
        Wrapper to fetch parameters from GPU.
        """
        new_args, new_kwargs = fetch_from_gpu(*args, **kwargs)
        return func(*new_args, **new_kwargs)

    return wrapper
def task_handler(func):
    """
    Returns all mapped functions corresponding to the executor in place.

    The decorated method is dispatched at call time to a method of the same
    object named ``f"{func_type}_{func.__name__}_{arch}"``, where:

    - ``func_type`` is ``"_lazy"`` or ``""``;
    - ``arch`` is ``"cpu"`` or ``"gpu"``.

    For example, ``fit`` may be routed to ``_lazy_fit_gpu`` or ``_fit_cpu``.
    If no such specialized method exists, the decorated method itself is
    called with the same object and arguments.

    Parameters
    ----------
    func : callable
        Method to decorate. The first positional argument of every call is
        the object that owns the specialized implementations.

    Returns
    -------
    callable
        The dispatching wrapper.

    Raises
    ------
    NotImplementedError
        Raised by the wrapper when the object has neither the specialized
        method nor an attribute named ``func.__name__``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        """
        Wrapper of the function to map the proper object function.
        """
        cls = args[0]
        new_args = args[1:]
        func_name = func.__name__
        func_type = ""
        arch = "cpu"
        forced_local = is_forced_local(cls)

        client = get_dask_running_client()
        if client is not None:
            # Runs task according to current client configuration, i.e,
            # Pipeline Executor.
            func_type = "_lazy"
            arch = "gpu" if getattr(client, "backend", None) == "cupy" else "cpu"
        else:
            dask_gpu = is_dask_gpu_supported()
            if not forced_local and (dask_gpu or is_dask_supported()):
                func_type = "_lazy"
            if dask_gpu or is_gpu_supported():
                arch = "gpu"

        if forced_local:
            # Forced local: use the non-lazy variant and compute any Dask
            # arrays among the arguments first.
            func_type = ""
            new_args, kwargs = fetch_from_dask(*new_args, **kwargs)

        if is_forced_gpu(cls):
            arch = "gpu"

        if arch == "cpu":
            # Targeting CPU: fetch any GPU arrays among the arguments to CPU.
            new_args, kwargs = fetch_from_gpu(*new_args, **kwargs)

        wrapper_func_attr = f"{func_type}_{func_name}_{arch}"

        if hasattr(cls, wrapper_func_attr):
            return getattr(cls, wrapper_func_attr)(*new_args, **kwargs)

        if hasattr(cls, func_name):
            # No specialized implementation: fall back to the decorated
            # method. ``func`` is the plain function captured at decoration
            # time, so the object must be passed explicitly; ``new_args``
            # has already had it stripped off.
            return func(cls, *new_args, **kwargs)

        raise NotImplementedError(
            f"There is no implementation of {wrapper_func_attr} nor "
            f"{func_name}"
        )

    return wrapper