#!/usr/bin/env python3

import inspect
from typing import List

import graphviz
import networkx as nx

from dasf.utils.logging import init_logging

class PipelinePlugin:
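    """Base plugin interface for observing pipeline execution.

    Subclass it and override any of the hooks below to be notified of
    pipeline and task lifecycle events.
    """
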
    def on_pipeline_start(self, fn_keys):
        pass

    def on_pipeline_end(self):
        pass

    def on_task_start(self, func, params, name):
        pass

    def on_task_end(self, func, params, name, ret):
        pass

    def on_task_error(self, func, params, name, exception):
        pass


class Pipeline:
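    """A pipeline of tasks organized as a DAG.

    Nodes are callables extracted from functions, methods, datasets and
    transforms; edges represent the parameter dependencies between them.
    The optional `executor` runs each task (defaulting to LocalExecutor),
    and `callbacks` is a list of PipelinePlugin instances to notify.
    """
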
    def __init__(self,
                 name,
                 executor=None,
                 verbose=False,
                 callbacks: List[PipelinePlugin] = None):
        from dasf.pipeline.executors.wrapper import LocalExecutor

        self._name = name
        self._executor = executor if executor is not None else LocalExecutor()
        self._verbose = verbose

        self._dag = nx.DiGraph()
        self._dag_table = dict()
        self._dag_g = graphviz.Digraph(name, format="png")

        self._logger = init_logging()
        self._callbacks = callbacks or []

    def register_plugin(self, plugin):
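        """Register a plugin: PipelinePlugin instances are stored as
        pipeline callbacks, while any other plugin type is forwarded to
        the executor.
        """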
        if isinstance(plugin, PipelinePlugin):
            self._callbacks.append(plugin)
        else:
            self._executor.register_plugin(plugin)

    def info(self):
        print(self._executor.info)

    def execute_callbacks(self, func_name: str, *args, **kwargs):
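        """Invoke the hook named `func_name` on every registered callback."""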
        for callback in self._callbacks:
            getattr(callback, func_name)(*args, **kwargs)

    def __add_into_dag(self, obj, func_name, parameters=None, itself=None):
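        """Insert `obj` as a DAG node keyed by its hash and wire one edge per
        named parameter dependency. Every node and edge is mirrored in the
        graphviz graph used by `visualize()`.
        """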
        key = hash(obj)

        if key not in self._dag_table:
            self._dag.add_node(key)

            self._dag_table[key] = dict()
            self._dag_table[key]["fn"] = obj
            self._dag_table[key]["name"] = func_name
            self._dag_table[key]["parameters"] = None
            self._dag_table[key]["ret"] = None

        if parameters and isinstance(parameters, dict):
            if self._dag_table[key]["parameters"] is None:
                self._dag_table[key]["parameters"] = parameters
            else:
                self._dag_table[key]["parameters"].update(parameters)

            # If we are adding an object which requires parameters, we need
            # to make sure they are mapped into the DAG as well.
            for k, v in parameters.items():
                dep_obj, dep_func_name, _ = self.__inspect_element(v)

                self.add(dep_obj)

                if not self._dag.has_node(str(key)):
                    self._dag_g.node(str(key), func_name)

                if not self._dag.has_node(str(hash(dep_obj))):
                    self._dag_g.node(str(hash(dep_obj)), dep_func_name)

                self._dag.add_edge(hash(dep_obj), key)
                self._dag_g.edge(str(hash(dep_obj)), str(key), label=k)

    def __inspect_element(self, obj):
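        """Normalize `obj` into a (callable, qualified name, bound object)
        triple: plain functions and bound methods are used directly, Dataset
        and BaseLoader objects contribute `load()`, Fit objects `fit()`, and
        Transform objects `transform()`.
        """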
        from dasf.datasets.base import Dataset
        from dasf.ml.inference.loader.base import BaseLoader
        from dasf.transforms.base import Fit, Transform

        def generate_name(class_name, func_name):
            return ("%s.%s" % (class_name, func_name))

        if inspect.isfunction(obj) and callable(obj):
            return (obj,
                    obj.__qualname__,
                    None)
        elif inspect.ismethod(obj):
            return (obj,
                    generate_name(obj.__self__.__class__.__name__,
                                  obj.__name__),
                    obj.__self__)
        elif issubclass(obj.__class__, Dataset) and hasattr(obj, "load"):
            # (Disabled) Register dataset for reusability
            # obj = self.__register_dataset(obj)
            return (obj.load,
                    generate_name(obj.__class__.__name__, "load"),
                    obj)
        elif issubclass(obj.__class__, Fit) and hasattr(obj, "fit"):
            return (obj.fit,
                    generate_name(obj.__class__.__name__, "fit"),
                    obj)
        elif issubclass(obj.__class__, BaseLoader) and hasattr(obj, "load"):
            return (obj.load,
                    generate_name(obj.__class__.__name__, "load"),
                    obj)
        elif issubclass(obj.__class__, Transform) and hasattr(obj, "transform"):
            return (obj.transform,
                    generate_name(obj.__class__.__name__, "transform"),
                    obj)
        else:
            raise ValueError(
                f"This object {obj.__class__.__name__} is not a function, "
                "a method, or a transformer object."
            )

    def add(self, obj, **kwargs):
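        """Add `obj` to the pipeline. Each keyword argument names a parameter
        of `obj` and maps it to the upstream object that produces its value.
        Returns `self`, so calls can be chained.
        """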
        obj, func_name, objref = self.__inspect_element(obj)

        self.__add_into_dag(obj, func_name, kwargs, objref)

        return self

    def visualize(self, filename=None):
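        """Return the graphviz graph when running inside a notebook;
        otherwise render it to `filename` and open the default viewer.
        """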
        from dasf.utils.funcs import is_notebook

        if is_notebook():
            return self._dag_g
        return self._dag_g.view(filename)

    def __register_dataset(self, dataset):
        key = str(hash(dataset.load))
        kwargs = {key: dataset}

        if not self._executor.has_dataset(key):
            return self._executor.register_dataset(**kwargs)
        return self._executor.get_dataset(key)

    def __execute(self, func, params, name):
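        """Resolve each declared parameter to the stored return value of its
        upstream task, then submit the callable to the executor.
        """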
        ret = None
        new_params = dict()

        if params:
            for k, v in params.items():
                dep_obj, *_ = self.__inspect_element(v)
                req_key = hash(dep_obj)
                new_params[k] = self._dag_table[req_key]["ret"]

        if len(new_params) > 0:
            ret = self._executor.execute(fn=func, **new_params)
        else:
            ret = self._executor.execute(fn=func)

        return ret

    def get_result_from(self, obj):
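        """Return the stored result of a previously executed pipeline object,
        matched by its qualified name.
        """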
        _, obj_name, *_ = self.__inspect_element(obj)

        for key in self._dag_table:
            if self._dag_table[key]["name"] == obj_name:
                if self._dag_table[key]["ret"] is None:
                    raise Exception("Pipeline has not been executed yet.")
                return self._dag_table[key]["ret"]

        raise Exception(f"Function {obj_name} was not added to the pipeline.")

    def run(self):
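        """Execute every task in topological order through the executor.
        After the first failure the remaining tasks are skipped, and the
        registered callbacks are notified of each lifecycle event.
        """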
        if not nx.is_directed_acyclic_graph(self._dag):
            raise Exception("Pipeline is not a DAG.")

        if not hasattr(self._executor, "execute"):
            raise Exception(
                f"Executor {self._executor.__class__.__name__} has no "
                "execute() method."
            )

        if not self._executor.is_connected:
            raise Exception("Executor is not connected.")

        fn_keys = list(nx.topological_sort(self._dag))

        self._logger.info(f"Beginning pipeline run for '{self._name}'")

        self.execute_callbacks("on_pipeline_start", fn_keys)

        self._executor.pre_run(self)

        ret = None
        failed = False

        for fn_key in fn_keys:
            func = self._dag_table[fn_key]["fn"]
            params = self._dag_table[fn_key]["parameters"]
            name = self._dag_table[fn_key]["name"]

            if not failed:
                self._logger.info(f"Task '{name}': Starting task run...")
            else:
                self._logger.error(f"Task '{name}': Starting task run...")

            try:
                if not failed:
                    # Execute the DAG node only if no error has occurred so
                    # far during the run. Otherwise, skip it.
                    self.execute_callbacks("on_task_start", func=func,
                                           params=params, name=name)

                    result = self.__execute(func, params, name)

                    self._dag_table[fn_key]["ret"] = result

                    self.execute_callbacks("on_task_end", func=func,
                                           params=params, name=name,
                                           ret=result)
            except Exception as e:
                self.execute_callbacks("on_task_error", func=func,
                                       params=params, name=name, exception=e)

                failed = True
                failed_at = name
                err = str(e)
                self._logger.exception(f"Task '{name}': Failed with:\n{err}")

            if not failed:
                self._logger.info(f"Task '{name}': Finished task run")
            else:
                self._logger.error(f"Task '{name}': Finished task run")

        if failed:
            self._logger.info(f"Pipeline failed at '{failed_at}'")
        else:
            self._logger.info("Pipeline ran successfully")

        self._executor.post_run(self)

        self.execute_callbacks("on_pipeline_end")

        return ret
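

# Minimal usage sketch, assuming hypothetical `MyDataset` (a Dataset subclass
# whose load() produces data) and `MyTransform` (a Transform subclass whose
# transform(X) consumes it); with no executor given, LocalExecutor is used:
#
#     dataset = MyDataset()
#     transform = MyTransform()
#
#     pipeline = Pipeline("demo")
#     pipeline.add(transform, X=dataset)  # wires dataset.load -> transform.transform
#     pipeline.run()
#
#     result = pipeline.get_result_from(transform)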