Feature/Patch Inference #24

Merged
merged 6 commits on Feb 29, 2024
Empty file added dasf/ml/inference/__init__.py
102 changes: 102 additions & 0 deletions dasf/ml/inference/loader/base.py
@@ -0,0 +1,102 @@
from dask.distributed import Worker

from dasf.utils.funcs import get_dask_running_client
from dasf.utils.decorators import task_handler


class BaseLoader:
"""
BaseLoader for DL models. When running in a Dask Cluster instantiates a model per worker that will be reused on every subsequent prediction task.
"""

def __init__(self):
self.model_instances = {}

def load_model(self):
"""
        Loading a model is specific to each framework/model and must be
        implemented by subclasses.
"""
raise NotImplementedError("Load Model must be implemented")

def load_model_distributed(self, **kwargs):
"""
        Instantiate the model on a Dask worker and report its status.
"""
        try:
            Worker.model = self.load_model(**kwargs)
            return "UP"
        except Exception:
            return "DOWN"

def _lazy_load(self, **kwargs):
client = get_dask_running_client()
self.model_instances = {}
if client:
worker_addresses = list(client.scheduler_info()["workers"].keys())
self.model_instances = client.run(
self.load_model_distributed, **kwargs, workers=worker_addresses
)

def _load(self, **kwargs):
self.model_instances = {"local": self.load_model(**kwargs)}

def _lazy_load_cpu(self, **kwargs):
if not (hasattr(self, "device") and self.device):
self.device = "cpu"
self._lazy_load(**kwargs)

def _lazy_load_gpu(self, **kwargs):
if not (hasattr(self, "device") and self.device):
self.device = "gpu"
self._lazy_load(**kwargs)

def _load_cpu(self, **kwargs):
if not (hasattr(self, "device") and self.device):
self.device = "cpu"
self._load(**kwargs)

def _load_gpu(self, **kwargs):
if not (hasattr(self, "device") and self.device):
self.device = "gpu"
self._load(**kwargs)

@task_handler
def load(self, **kwargs):
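        """
        Load the model. The task_handler decorator dispatches this call to one
        of the _load_* / _lazy_load_* variants above, depending on whether a
        Dask client is running and whether a GPU is available.
        """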
...

def predict(self, data):
"""
        Predict method, called on prediction tasks: runs preprocessing,
        inference, and postprocessing in sequence.
"""
if not self.model_instances:
raise RuntimeError(
"Models have not been loaded. load method must be executed beforehand."
)
if "local" in self.model_instances:
model = self.model_instances["local"]
else:
model = Worker.model
data = self.preprocessing(data)
output = self.inference(model, data)
return self.postprocessing(output)

def preprocessing(self, data):
"""
        Preprocessing stage, called before inference. Defaults to the identity.
"""
return data

def inference(self, model, data):
"""
        Inference method; receives the model and the input data.
"""
raise NotImplementedError("Inference must be implemented")

def postprocessing(self, data):
"""
        Postprocessing stage, called after inference. Defaults to the identity.
"""
return data
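
Taken together, a framework-specific loader only needs to implement load_model and inference; preprocessing and postprocessing default to the identity. A minimal sketch of a hypothetical subclass (the class, "model", and data below are illustrative assumptions, not part of this PR):

import numpy as np

from dasf.ml.inference.loader.base import BaseLoader


class DoublerLoader(BaseLoader):
    """Hypothetical loader whose "model" simply doubles its input."""

    def load_model(self, **kwargs):
        # Real subclasses would build a network and load weights here.
        return lambda x: x * 2

    def inference(self, model, data):
        return model(data)


loader = DoublerLoader()
loader.load()                        # local load; one copy per worker under Dask
print(loader.predict(np.arange(4)))  # -> [0 2 4 6]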
56 changes: 56 additions & 0 deletions dasf/ml/inference/loader/torch.py
@@ -0,0 +1,56 @@
import inspect
import os
import torch

from .base import BaseLoader


class TorchLoader(BaseLoader):
"""
    Model loader for PyTorch models.
"""

def __init__(
self, model_class_or_file, dtype=torch.float32, checkpoint=None, device=None
):
"""
        model_class_or_file: a model class, or the path to a serialized model file
        dtype: data type of the model input
        checkpoint: model checkpoint file
        device: device to place the model on ("cpu" or "gpu")
"""
super().__init__()
self.model_class_or_file = model_class_or_file
self.dtype = dtype
self.checkpoint = checkpoint
self.device = device

def load_model(self, **kwargs):
device = torch.device("cuda" if self.device == "gpu" else "cpu")
if inspect.isclass(self.model_class_or_file):
model = self.model_class_or_file(**kwargs)
if self.checkpoint:
state_dict = torch.load(self.checkpoint, map_location=device)
                # Unwrap nested weights: frameworks such as PyTorch Lightning
                # save checkpoints with the weights under a "state_dict" key.
                state_dict = (
                    state_dict["state_dict"]
                    if "state_dict" in state_dict
                    else state_dict
                )
model.load_state_dict(state_dict)
elif os.path.isfile(self.model_class_or_file):
            # map_location lets a model saved on GPU load on a CPU-only machine
            model = torch.load(self.model_class_or_file, map_location=device)
else:
raise ValueError(
"model_class_or_file must be a model class or path to model file"
)
model.to(device=device)
model.eval()
return model

def inference(self, model, data):
data = torch.from_numpy(data)
device = torch.device("cuda" if self.device == "gpu" else "cpu")
data = data.to(device, dtype=self.dtype)
with torch.no_grad():
output = model(data)
return output.cpu().numpy() if self.device == "gpu" else output.numpy()
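
Usage, sketched under illustrative assumptions (MyNet and the input shape are made up for the example; they are not part of this PR):

import numpy as np
import torch

from dasf.ml.inference.loader.torch import TorchLoader


class MyNet(torch.nn.Module):
    """Illustrative model class."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)


loader = TorchLoader(MyNet, dtype=torch.float32, device="cpu")
loader.load()                                    # instantiates MyNet().eval() on CPU
batch = np.random.rand(4, 8).astype(np.float32)
print(loader.predict(batch).shape)               # -> (4, 2)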
6 changes: 6 additions & 0 deletions dasf/pipeline/pipeline.py
@@ -93,6 +93,7 @@ def __add_into_dag(self, obj, func_name, parameters=None, itself=None):
def __inspect_element(self, obj):
from dasf.datasets.base import Dataset
from dasf.transforms.base import Transform, Fit
from dasf.ml.inference.loader.base import BaseLoader

def generate_name(class_name, func_name):
return ("%s.%s" % (class_name, func_name))
@@ -119,6 +120,11 @@ def generate_name(class_name, func_name):
generate_name(obj.__class__.__name__,
"fit"),
obj)
elif issubclass(obj.__class__, BaseLoader) and hasattr(obj, "load"):
return (obj.load,
generate_name(obj.__class__.__name__,
"load"),
obj)
elif issubclass(obj.__class__, Transform) and hasattr(obj, "transform"):
return (obj.transform,
generate_name(obj.__class__.__name__,
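
With this change, adding a BaseLoader object to a pipeline registers its load method as the DAG node. A hypothetical wiring, assuming the usual Pipeline add/run API (the loader arguments and checkpoint path are illustrative):

from dasf.pipeline.pipeline import Pipeline
from dasf.ml.inference.loader.torch import TorchLoader

# MyNet and the checkpoint path are illustrative (see the sketch above).
loader = TorchLoader(MyNet, checkpoint="weights.ckpt", device="gpu")

pipeline = Pipeline("patch inference")
pipeline.add(loader)   # __inspect_element maps this node to loader.load
pipeline.run()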