Merge pull request #24 from discovery-unicamp/feature/patch-inference
Feature/Patch Inference

Signed-off-by: Julio Faracco <[email protected]>
jcfaracco authored Feb 29, 2024
2 parents 4e3915d + e0a09e5 commit a167057
Showing 9 changed files with 789 additions and 8 deletions.
Empty file added dasf/ml/inference/__init__.py
Empty file.
102 changes: 102 additions & 0 deletions dasf/ml/inference/loader/base.py
@@ -0,0 +1,102 @@
from dask.distributed import Worker

from dasf.utils.funcs import get_dask_running_client
from dasf.utils.decorators import task_handler


class BaseLoader:
    """
    BaseLoader for DL models. When running in a Dask cluster, it instantiates
    one model per worker, and that model is reused on every subsequent
    prediction task.
    """

    def __init__(self):
        self.model_instances = {}

    def load_model(self, **kwargs):
        """
        Load Model method is specific for each framework/model.
        """
        raise NotImplementedError("Load Model must be implemented")

    def load_model_distributed(self, **kwargs):
        """
        Distributed model instantiation.
        """
        try:
            Worker.model = self.load_model(**kwargs)
            return "UP"
        except Exception:
            return "DOWN"

    def _lazy_load(self, **kwargs):
        client = get_dask_running_client()
        self.model_instances = {}
        if client:
            worker_addresses = list(client.scheduler_info()["workers"].keys())
            self.model_instances = client.run(
                self.load_model_distributed, **kwargs, workers=worker_addresses
            )

    def _load(self, **kwargs):
        self.model_instances = {"local": self.load_model(**kwargs)}

    def _lazy_load_cpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "cpu"
        self._lazy_load(**kwargs)

    def _lazy_load_gpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "gpu"
        self._lazy_load(**kwargs)

    def _load_cpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "cpu"
        self._load(**kwargs)

    def _load_gpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "gpu"
        self._load(**kwargs)

    @task_handler
    def load(self, **kwargs):
        ...

    def predict(self, data):
        """
        Predict method called on prediction tasks.
        """
        if not self.model_instances:
            raise RuntimeError(
                "Models have not been loaded. The load method must be "
                "executed beforehand."
            )
        if "local" in self.model_instances:
            model = self.model_instances["local"]
        else:
            model = Worker.model
        data = self.preprocessing(data)
        output = self.inference(model, data)
        return self.postprocessing(output)

    def preprocessing(self, data):
        """
        Preprocessing stage, called before inference.
        """
        return data

    def inference(self, model, data):
        """
        Inference method; receives the model and the input data.
        """
        raise NotImplementedError("Inference must be implemented")

    def postprocessing(self, data):
        """
        Postprocessing stage, called after inference.
        """
        return data
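
For context (not part of this diff): a concrete loader only needs to supply load_model and inference; preprocessing and postprocessing default to pass-throughs. A minimal sketch of the local (non-cluster) path, where IdentityLoader and its callable stand-in model are hypothetical:

    from dasf.ml.inference.loader.base import BaseLoader

    class IdentityLoader(BaseLoader):
        def load_model(self, **kwargs):
            # Any callable can stand in for a real model here.
            return lambda x: x

        def inference(self, model, data):
            return model(data)

    loader = IdentityLoader()
    loader._load()   # local path; load() dispatches here via @task_handler
    assert loader.predict([1, 2, 3]) == [1, 2, 3]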
56 changes: 56 additions & 0 deletions dasf/ml/inference/loader/torch.py
@@ -0,0 +1,56 @@
import inspect
import os
import torch

from .base import BaseLoader


class TorchLoader(BaseLoader):
    """
    Model Loader for Torch models.
    """

    def __init__(
        self, model_class_or_file, dtype=torch.float32, checkpoint=None, device=None
    ):
        """
        model_class_or_file: class or file with the model definition
        dtype: data type of the model input
        checkpoint: model checkpoint file
        device: device to place the model on ("cpu" or "gpu")
        """
        super().__init__()
        self.model_class_or_file = model_class_or_file
        self.dtype = dtype
        self.checkpoint = checkpoint
        self.device = device

    def load_model(self, **kwargs):
        device = torch.device("cuda" if self.device == "gpu" else "cpu")
        if inspect.isclass(self.model_class_or_file):
            model = self.model_class_or_file(**kwargs)
            if self.checkpoint:
                state_dict = torch.load(self.checkpoint, map_location=device)
                state_dict = (
                    state_dict["state_dict"]
                    if "state_dict" in state_dict
                    else state_dict
                )  # Unwrap checkpoints that nest weights under a "state_dict" key
                model.load_state_dict(state_dict)
        elif os.path.isfile(self.model_class_or_file):
            model = torch.load(self.model_class_or_file)
        else:
            raise ValueError(
                "model_class_or_file must be a model class or a path to a model file"
            )
        model.to(device=device)
        model.eval()
        return model

    def inference(self, model, data):
        data = torch.from_numpy(data)
        device = torch.device("cuda" if self.device == "gpu" else "cpu")
        data = data.to(device, dtype=self.dtype)
        with torch.no_grad():
            output = model(data)
        return output.cpu().numpy() if self.device == "gpu" else output.numpy()
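
For illustration (not part of this diff): loading a Torch model class and running a local prediction. TinyNet is a hypothetical stand-in, and load() is assumed to dispatch to the CPU variant when no Dask cluster is running:

    import numpy as np
    import torch

    from dasf.ml.inference.loader.torch import TorchLoader

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.linear(x)

    loader = TorchLoader(TinyNet, dtype=torch.float32, device="cpu")
    loader._load_cpu()                  # or loader.load() under @task_handler
    batch = np.random.rand(4, 8).astype(np.float32)
    out = loader.predict(batch)         # numpy in, numpy out; shape (4, 2)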
6 changes: 6 additions & 0 deletions dasf/pipeline/pipeline.py
@@ -93,6 +93,7 @@ def __add_into_dag(self, obj, func_name, parameters=None, itself=None):
    def __inspect_element(self, obj):
        from dasf.datasets.base import Dataset
        from dasf.transforms.base import Transform, Fit
        from dasf.ml.inference.loader.base import BaseLoader

        def generate_name(class_name, func_name):
            return ("%s.%s" % (class_name, func_name))
@@ -119,6 +120,11 @@ def generate_name(class_name, func_name):
                    generate_name(obj.__class__.__name__,
                                  "fit"),
                    obj)
        elif issubclass(obj.__class__, BaseLoader) and hasattr(obj, "load"):
            return (obj.load,
                    generate_name(obj.__class__.__name__,
                                  "load"),
                    obj)
        elif issubclass(obj.__class__, Transform) and hasattr(obj, "transform"):
            return (obj.transform,
                    generate_name(obj.__class__.__name__,
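
With this change, a loader added to a pipeline contributes its load method as the DAG node, just as datasets contribute load and fit-capable transforms contribute fit. A hedged sketch of the wiring (Pipeline's add/run methods assumed from dasf's existing API; loader is a hypothetical TorchLoader instance):

    from dasf.pipeline import Pipeline

    pipeline = Pipeline("patch-inference")
    pipeline.add(loader)    # __inspect_element now maps it to loader.load
    pipeline.run()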
