Source code for dasf.profile.plugins

import os
import socket
import time
from typing import Any

import nvtx
import pynvml
from dask.distributed.compatibility import PeriodicCallback
from dask.distributed.system_monitor import SystemMonitor
from distributed.diagnostics.plugin import WorkerPlugin
from pynvml import *

from dasf.profile.profiler import EventProfiler


class WorkerTaskPlugin(WorkerPlugin):
    """Worker plugin that records task transition events (compute duration,
    result size, shape and dtype, plus in-memory totals) into an
    EventProfiler database."""

    def __init__(
        self,
        name: str = "TracePlugin",
    ):
        self.name = name

    def setup(self, worker):
        self.worker = worker
        self.hostname = socket.gethostname()
        self.worker_id = f"worker-{self.hostname}-{self.worker.name}"
        self.database = EventProfiler(
            database_file=f"{self.name}-{self.hostname}.msgpack",
        )

    def transition(self, key, start, finish, *args, **kwargs):
        now = time.monotonic()
        if finish == "memory":
            # Get the last compute event
            startstops = next(
                (
                    x
                    for x in reversed(self.worker.state.tasks[key].startstops)
                    if x["action"] == "compute"
                ),
                None,
            )
            if startstops is not None:
                # Add information about the task execution
                shape = tuple()
                dtype = "unknown"
                if hasattr(self.worker.data[key], "shape"):
                    if isinstance(getattr(self.worker.data[key], "shape"), tuple):
                        shape = getattr(self.worker.data[key], "shape")
                if hasattr(self.worker.data[key], "dtype"):
                    dtype = str(getattr(self.worker.data[key], "dtype"))

                task = self.worker.state.tasks[key]
                nbytes = task.nbytes or 0

                self.database.record_complete_event(
                    name="Compute",
                    timestamp=now,
                    # TODO check startstop returning None
                    duration=startstops["stop"] - startstops["start"],
                    process_id=self.hostname,
                    thread_id=self.worker_id,
                    args={
                        "key": key,
                        "name": "-".join(key.split(",")[0][2:-1].split("-")[:-1]),
                        "state": finish,
                        "size": nbytes,
                        "shape": shape,
                        "dtype": dtype,
                        "type": str(type(self.worker.data[key])),
                        "dependencies": [dep.key for dep in task.dependencies],
                        "dependents": [dep.key for dep in task.dependents],
                    },
                )

        if finish == "memory" or finish == "erred":
            # Additionally add the total in-memory tasks
            self.database.record_instant_event(
                name="Managed Memory",
                timestamp=now,
                process_id=self.hostname,
                thread_id=self.worker_id,
                args={
                    "key": key,
                    "state": finish,
                    "size": self.worker.state.nbytes,
                    "tasks": len(self.worker.data),
                },
            )
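WorkerTaskPlugin is a standard distributed WorkerPlugin, so it can be registered through the client. A minimal usage sketch follows, assuming a LocalCluster; the worker count and plugin name are example values.

# Sketch: attach WorkerTaskPlugin to a Dask cluster.
from dask.distributed import Client, LocalCluster

from dasf.profile.plugins import WorkerTaskPlugin

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

# Each worker writes its trace to "<name>-<hostname>.msgpack".
client.register_worker_plugin(WorkerTaskPlugin(name="TracePlugin"))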
class ResourceMonitor:
    """Periodically samples distributed's SystemMonitor and records the
    readings as "Resource Usage" instant events."""

    def __init__(
        self,
        time=100,
        autostart: bool = True,
        name: str = "ResourceMonitor",
        **monitor_kwargs,
    ):
        self.time = time
        self.name = name
        self.hostname = socket.gethostname()
        self.database = EventProfiler(
            database_file=f"{self.name}-{self.hostname}.msgpack",
        )
        self.monitor = SystemMonitor(**monitor_kwargs)
        self.callback = PeriodicCallback(self.update, callback_time=self.time)
        if autostart:
            self.start()

    def __del__(self):
        self.stop()

    def update(self):
        res = self.monitor.update()
        self.database.record_instant_event(
            name="Resource Usage",
            timestamp=time.monotonic(),
            process_id=self.hostname,
            thread_id=None,
            args=res,
        )
        return res

    def start(self):
        self.callback.start()

    def stop(self):
        self.database.commit()
        self.callback.stop()
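ResourceMonitor runs outside the worker plugin machinery: it drives SystemMonitor through a Tornado PeriodicCallback, so it needs a running event loop (for example, the process hosting a Dask client or worker). A minimal sketch, with an example 500 ms sampling interval:

# Sketch: record host resource usage while a workload runs (interval is an
# example value; requires a running Tornado/asyncio event loop).
from dasf.profile.plugins import ResourceMonitor

monitor = ResourceMonitor(time=500, name="ResourceMonitor")

# ... run the Dask workload; each tick records a "Resource Usage" event ...

monitor.stop()  # flush pending events to the .msgpack database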
class GPUAnnotationPlugin(WorkerPlugin):
    """Worker plugin that opens an NVTX range when a task starts executing
    and closes it when the task leaves the executing state, so GPU profilers
    can correlate Dask tasks with device activity."""

    def __init__(
        self,
        name: str = "GPUAnnotationPlugin",
    ):
        self.name = name
        self.gpu_num = None
        self.marks = {}

    def setup(self, worker):
        self.worker = worker
        self.gpu_num = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0])
        print(f"Setting up GPU annotation plugin for worker {self.worker.name}. GPU: {self.gpu_num}")

    def transition(self, key, start, finish, *args, **kwargs):
        if finish == "executing":
            # Task is entering the executing state: open an NVTX range for it
            handle = pynvml.nvmlDeviceGetHandleByIndex(self.gpu_num)
            mark = nvtx.start_range(message=key, domain="compute")
            self.marks[key] = mark

        if start == "executing":
            # Task is leaving the executing state: close its NVTX range
            handle = pynvml.nvmlDeviceGetHandleByIndex(self.gpu_num)
            nvtx.end_range(self.marks[key])
            del self.marks[key]
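GPUAnnotationPlugin brackets each executing task with an NVTX range in the "compute" domain, which lets tools such as Nsight Systems line up Dask tasks with GPU activity. A minimal sketch, assuming one GPU per worker with CUDA_VISIBLE_DEVICES set per worker (for example via dask-cuda's LocalCUDACluster):

# Sketch: register the annotation plugin on a CUDA cluster, then profile the
# script with Nsight Systems (e.g. "nsys profile python my_script.py").
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

from dasf.profile.plugins import GPUAnnotationPlugin

cluster = LocalCUDACluster()
client = Client(cluster)
client.register_worker_plugin(GPUAnnotationPlugin())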