Add openvino and DirDiskMonitor #763

Merged · 3 commits · Oct 3, 2024
152 changes: 116 additions & 36 deletions batchflow/models/torch/base.py
@@ -16,6 +16,9 @@
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

import openvino as ov
import shelve

from sklearn.decomposition import PCA

from ...utils_import import make_delayed_import
@@ -414,6 +417,7 @@ def __init__(self, config=None):
self.model = None
self._model_cpu_backup = None
self._loaded_from_onnx = None
self._loaded_from_openvino = None

# Leading device and list of all devices to use
self.device = None
@@ -974,9 +978,9 @@ def model_to_device(self, model=None):
model = model if model is not None else self.model

if len(self.devices) > 1:
self.model = nn.DataParallel(self.model, self.devices)
model = nn.DataParallel(model, self.devices)
else:
self.model.to(self.device)
model = model.to(self.device)


# Apply model to train/predict on given data
@@ -1359,7 +1363,7 @@ def predict(self, inputs, targets=None, outputs=None, lock=True, microbatch_size

>>> batchflow_model.predict(inputs=B.images, outputs='model.body.encoder["block-0"]')
"""
if self._loaded_from_onnx:
if self._loaded_from_onnx or self._loaded_from_openvino:
microbatch_size = self.microbatch_size
microbatch_pad_last = True

@@ -1641,8 +1645,8 @@ def convert_outputs(self, outputs):


# Store model
def save(self, path, use_onnx=False, path_onnx=None, batch_size=None, opset_version=13,
pickle_module=dill, **kwargs):
def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
batch_size=None, opset_version=13, pickle_module=dill, **kwargs):
""" Save underlying PyTorch model along with meta parameters (config, device spec, etc).

If `use_onnx` is set to True, then the model is converted to ONNX format and stored in a separate file.
@@ -1658,7 +1662,13 @@ def save(self, path, use_onnx=False, path_onnx=None, batch_size=None, opset_vers
Whether to store model in ONNX format.
path_onnx : str, optional
Used only if `use_onnx` is True.
If provided, then path to store the ONNX model; default `path_onnx` is `path` with added '_onnx' postfix.
If provided, then path to store the ONNX model; default `path_onnx` is `path` with '_onnx' postfix.
use_openvino : bool
Whether to store the model as an OpenVINO XML file.
path_openvino : str, optional
Used only if `use_openvino` is True.
If provided, then path to store the OpenVINO model; default `path_openvino` is `path` with '_openvino'
postfix.
batch_size : int, optional
Used only if `use_onnx` is True.
Fixed batch size of the ONNX module. This is the only viable batch size for this model after loading.
@@ -1691,11 +1701,37 @@
preserved_dict = {item: getattr(self, item) for item in preserved}
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)

elif use_openvino:
if batch_size is None:
raise ValueError('Specify a valid `batch_size` to be used for model inference!')

path_openvino = path_openvino or (path + '_openvino')
if os.path.splitext(path_openvino)[-1] == '':
path_openvino = f'{path_openvino}.xml'

# Save model
model = self.model.eval()

if not isinstance(self.model, ov.Model):
inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
model = ov.convert_model(model, example_input=inputs)

ov.save_model(model, output_model=path_openvino)

# Save the rest of parameters
preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
preserved_dict = {item: getattr(self, item) for item in preserved}
out_path_params = f'{os.path.splitext(path_openvino)[0]}_bf_params_db'

with shelve.open(out_path_params) as params_db:
params_db.update(preserved_dict)

else:
torch.save({item: getattr(self, item) for item in self.PRESERVE},
path, pickle_module=pickle_module, **kwargs)
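
For orientation, a minimal usage sketch of the new OpenVINO export path; `model` and the paths below are illustrative, not from this PR:

# Export a trained TorchModel to OpenVINO; `batch_size` is required because
# the graph is traced on fixed-size placeholder inputs.
model.save('/tmp/my_model', use_openvino=True, batch_size=8)
# Expected artifacts: '/tmp/my_model_openvino.xml' (plus the weights file
# OpenVINO writes alongside it), and a '/tmp/my_model_openvino_bf_params_db'
# shelve database holding the preserved non-model attributes.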

def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
""" Load a torch model from a file.

If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size
@@ -1705,49 +1741,69 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
----------
file : str, PathLike, io.Bytes
a file where a model is stored.
eval_mode : bool
Whether to switch the model to eval mode.
is_openvino : bool, default False
Whether to load the file as an OpenVINO model instance.
make_infrastructure : bool
Whether to re-create model loss, optimizer, scaler and decay.
mode : str
Model mode.
pickle_module : module
Module to use for pickling.
kwargs : dict
Other keyword arguments, passed directly to :func:`torch.save`.
"""
self._parse_devices()
if self.device:
kwargs['map_location'] = self.device
kwargs['map_location'] = 'cpu'

# Load items from disk storage and set them as instance attributes
checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)
if is_openvino:
device = kwargs.pop('device', None) or self.device or 'CPU'
self.device = device.lower()

# `load_config` is a reference to `self.external_config` used to update `config`
# It is required since `self.external_config` may be overwritten in the cycle below
load_config = self.external_config
model = OVModel(model_path=file, device=device, **kwargs)
self.model = model

for key, value in checkpoint.items():
setattr(self, key, value)
self.config = self.config + load_config
# Load params
out_path_params = f'{os.path.splitext(file)[0]}_bf_params_db'
with shelve.open(out_path_params) as params_db:
params = {**params_db}

# Load model from onnx, if needed
if 'onnx' in checkpoint:
try:
from onnx2torch import convert #pylint: disable=import-outside-toplevel
except ImportError as e:
raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e
for key, value in params.items():
setattr(self, key, value)

model = convert(checkpoint['path_onnx']).eval()
self.model = model
self.microbatch_size = checkpoint['onnx_batch_size']
self._loaded_from_onnx = True
self._loaded_from_openvino = True
self.disable_training = True
else:
kwargs['map_location'] = self.device if self.device else 'cpu'

self.model_to_device()
if make_infrastructure:
self.make_infrastructure()
# Load items from disk storage and set them as instance attributes
checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)

# `load_config` is a reference to `self.external_config` used to update `config`
# It is required since `self.external_config` may be overwritten in the cycle below
load_config = self.external_config

for key, value in checkpoint.items():
setattr(self, key, value)
self.config = self.config + load_config

# Load model from onnx, if needed
if 'onnx' in checkpoint:
try:
from onnx2torch import convert #pylint: disable=import-outside-toplevel
except ImportError as e:
raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e

model = convert(checkpoint['path_onnx']).eval()
self.model = model
self.microbatch_size = checkpoint['onnx_batch_size']
self._loaded_from_onnx = True
self.disable_training = True

self.set_model_mode(mode)
self.model_to_device()

if make_infrastructure:
self.make_infrastructure()

self.set_model_mode(mode)
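
The matching load path, sketched with the same illustrative names; `images_array` stands in for real input data:

# Restore the model from the OpenVINO XML; the shelve database with the
# preserved attributes is located next to it via the '_bf_params_db' suffix.
model = TorchModel()
model.load('/tmp/my_model_openvino.xml', is_openvino=True)
# Training is disabled after loading, and `predict` pins the microbatch size
# to the restored `microbatch_size`, padding the last incomplete microbatch.
predictions = model.predict(inputs=images_array)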


# Utilities to use when working with TorchModel
@@ -1890,8 +1946,7 @@ def reduce_channels(array, normalize=True, n_components=3):
"""
array = array.transpose(0, 2, 3, 1)
pca_instance = PCA(n_components=n_components)

compressed_array= pca_instance.fit_transform(array.reshape(-1, array.shape[-1]))
compressed_array = pca_instance.fit_transform(array.reshape(-1, array.shape[-1]))
compressed_array = compressed_array.reshape(*array.shape[:3], n_components)
if normalize:
normalizer = Normalizer(mode='minmax')
@@ -1900,3 +1955,28 @@ def reduce_channels(array, normalize=True, n_components=3):
explained_variance_ratio = pca_instance.explained_variance_ratio_

return compressed_array, explained_variance_ratio

class OVModel:
""" Thin wrapper around a compiled OpenVINO model, exposing the minimal interface used by :class:`~TorchModel`. """
def __init__(self, model_path, core_config=None, device='CPU', compile_config=None):
core = ov.Core()

if core_config is not None:
for name, kwargs_ in core_config.items():
core.set_property(name, kwargs_)

self.model = core.read_model(model=model_path)

if compile_config is None:
compile_config = {}
self.model = core.compile_model(self.model, device, config=compile_config)

def eval(self):
""" Placeholder for compatibility with :class:`~TorchModel` methods."""
pass

def __call__(self, input_tensor):
""" Evaluate model on provided data. """
results = self.model(input_tensor)

results = torch.from_numpy(results[self.model.output(0)])
return results
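
A short sketch of exercising the wrapper directly; the path is illustrative, and the `PERFORMANCE_HINT` entry is only an example of an OpenVINO compile option:

import numpy as np

# Compile the stored model for CPU inference.
ov_model = OVModel(model_path='/tmp/my_model_openvino.xml', device='CPU',
                   compile_config={'PERFORMANCE_HINT': 'LATENCY'})
# The input shape must match the placeholder the model was traced with.
batch = np.zeros((8, 3, 64, 64), dtype=np.float32)
output = ov_model(batch)  # torch.Tensor built from the first model output
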
31 changes: 31 additions & 0 deletions batchflow/monitor.py
@@ -166,6 +166,36 @@ def get_usage(**kwargs):
_ = kwargs
return psutil.disk_usage('/').used / (1024 **3)

class DirDiskMonitor(ResourceMonitor):
""" Track disk usage in the provided dir. """
UNIT = 'Mb'

@staticmethod
def get_size(path):
""" Get disk usage in the directory.

Under the hood, we recursively sum the file sizes in all nested directories.
"""
size = 0

for dir_or_file in os.scandir(path):
if os.path.isfile(dir_or_file):
size += os.path.getsize(dir_or_file)
else:
size += DirDiskMonitor.get_size(dir_or_file)

return size

@staticmethod
def get_usage(**kwargs):
""" Track disk usage in the provided dir. """
dir_path = kwargs.get('track_dir_path', None)

if dir_path is None:
return None

return DirDiskMonitor.get_size(path=dir_path) / (1024 **2)
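
A minimal usage sketch of the helper above; the directory path is illustrative:

# Current size of a directory tree, in megabytes; returns None when
# `track_dir_path` is not supplied.
size_mb = DirDiskMonitor.get_usage(track_dir_path='/tmp/checkpoints')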


# Process resource monitors: track resources of a given process
class ProcessResourceMonitor(ResourceMonitor):
@@ -356,6 +386,7 @@ def get_usage(gpu_handles=None, **kwargs):
# System-wide monitors
TotalCPUMonitor: ['total_cpu'],
TotalMemoryMonitor: ['total_memory', 'total_rss'],
DirDiskMonitor: ['dir_disk'],
TotalDiskMonitor: ['total_disk'],

# Process monitors