Merge pull request #24 from discovery-unicamp/feature/patch-inference
Feature/Patch Inference

Signed-off-by: Julio Faracco <[email protected]>
Showing 9 changed files with 789 additions and 8 deletions.
Empty file.
Empty file.
@@ -0,0 +1,102 @@
from dask.distributed import Worker

from dasf.utils.funcs import get_dask_running_client
from dasf.utils.decorators import task_handler


class BaseLoader:
    """
    BaseLoader for DL models. When running in a Dask cluster, it
    instantiates one model per worker, which is reused on every
    subsequent prediction task.
    """

    def __init__(self):
        self.model_instances = {}

    def load_model(self):
        """
        The Load Model method is specific to each framework/model.
        """
        raise NotImplementedError("Load Model must be implemented")

    def load_model_distributed(self, **kwargs):
        """
        Distributed model instantiation.
        """
        try:
            # Cache the model on the Dask worker so later tasks reuse it.
            Worker.model = self.load_model(**kwargs)
            return "UP"
        except Exception:
            return "DOWN"

    def _lazy_load(self, **kwargs):
        client = get_dask_running_client()
        self.model_instances = {}
        if client:
            worker_addresses = list(client.scheduler_info()["workers"].keys())
            self.model_instances = client.run(
                self.load_model_distributed, **kwargs, workers=worker_addresses
            )

    def _load(self, **kwargs):
        self.model_instances = {"local": self.load_model(**kwargs)}

    def _lazy_load_cpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "cpu"
        self._lazy_load(**kwargs)

    def _lazy_load_gpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "gpu"
        self._lazy_load(**kwargs)

    def _load_cpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "cpu"
        self._load(**kwargs)

    def _load_gpu(self, **kwargs):
        if not (hasattr(self, "device") and self.device):
            self.device = "gpu"
        self._load(**kwargs)

    @task_handler
    def load(self, **kwargs):
        ...

    def predict(self, data):
        """
        Predict method, called on prediction tasks.
        """
        if not self.model_instances:
            raise RuntimeError(
                "Models have not been loaded. The load method must be "
                "executed beforehand."
            )
        if "local" in self.model_instances:
            model = self.model_instances["local"]
        else:
            model = Worker.model
        data = self.preprocessing(data)
        output = self.inference(model, data)
        return self.postprocessing(output)

    def preprocessing(self, data):
        """
        Preprocessing stage, called before inference.
        """
        return data

    def inference(self, model, data):
        """
        Inference method; receives the model and the input data.
        """
        raise NotImplementedError("Inference must be implemented")

    def postprocessing(self, data):
        """
        Postprocessing stage, called after inference.
        """
        return data
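For context, a minimal sketch of the subclass contract BaseLoader expects. The IdentityLoader class and its callable stand-in for a model are hypothetical illustrations, not part of this commit; load() is assumed to dispatch to one of the _load*/_lazy_load* variants through @task_handler, so the sketch calls the single-machine _load() path directly.

# Hypothetical subclass of the BaseLoader defined above, for illustration
# only: the "model" is a plain callable that echoes its input, standing in
# for a real DL model.
class IdentityLoader(BaseLoader):
    def load_model(self, **kwargs):
        return lambda x: x

    def inference(self, model, data):
        return model(data)


loader = IdentityLoader()
loader._load()                    # single-machine path; no Dask client needed
print(loader.predict([1, 2, 3]))  # -> [1, 2, 3]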
@@ -0,0 +1,56 @@
import inspect
import os

import torch

from .base import BaseLoader


class TorchLoader(BaseLoader):
    """
    Model Loader for Torch models.
    """

    def __init__(
        self, model_class_or_file, dtype=torch.float32, checkpoint=None, device=None
    ):
        """
        model_class_or_file: class or file with the model definition
        dtype: data type of the model input
        checkpoint: model checkpoint file
        device: device to place the model on ("cpu" or "gpu")
        """
        super().__init__()
        self.model_class_or_file = model_class_or_file
        self.dtype = dtype
        self.checkpoint = checkpoint
        self.device = device

    def load_model(self, **kwargs):
        device = torch.device("cuda" if self.device == "gpu" else "cpu")
        if inspect.isclass(self.model_class_or_file):
            model = self.model_class_or_file(**kwargs)
            if self.checkpoint:
                state_dict = torch.load(self.checkpoint, map_location=device)
                state_dict = (
                    state_dict["state_dict"]
                    if "state_dict" in state_dict
                    else state_dict
                )  # Unwrap checkpoints that nest the weights (e.g., PyTorch Lightning)
                model.load_state_dict(state_dict)
        elif os.path.isfile(self.model_class_or_file):
            # Full pickled model; map to the target device in case it was
            # saved on a different one.
            model = torch.load(self.model_class_or_file, map_location=device)
        else:
            raise ValueError(
                "model_class_or_file must be a model class or a path to a model file"
            )
        model.to(device=device)
        model.eval()
        return model

    def inference(self, model, data):
        data = torch.from_numpy(data)
        device = torch.device("cuda" if self.device == "gpu" else "cpu")
        data = data.to(device, dtype=self.dtype)
        with torch.no_grad():
            output = model(data)
        return output.cpu().numpy() if self.device == "gpu" else output.numpy()
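And a hedged usage sketch for TorchLoader; the Net module, the input shape, and the checkpoint-free construction are illustrative assumptions, not part of this commit:

import numpy as np
import torch.nn as nn

# Hypothetical model class, for illustration only.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


loader = TorchLoader(Net, device="cpu")  # no checkpoint: randomly initialized weights
loader._load()                           # or load(), which dispatches via @task_handler
out = loader.predict(np.random.rand(8, 4).astype(np.float32))
print(out.shape)                         # (8, 2)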