From 92b42ff4816df363efe798d1ea5edf38a5d28c38 Mon Sep 17 00:00:00 2001
From: HollowPrincess
Date: Fri, 6 Sep 2024 14:41:03 +0000
Subject: [PATCH 1/3] add openvino and dirmonitor

---
 batchflow/models/torch/base.py | 152 +++++++++++++++++++++++++--------
 batchflow/monitor.py           |  31 +++++++
 2 files changed, 147 insertions(+), 36 deletions(-)

diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py
index dc7bad07d..d4dd41ed9 100755
--- a/batchflow/models/torch/base.py
+++ b/batchflow/models/torch/base.py
@@ -16,6 +16,9 @@
 from torch import nn
 from torch.optim.swa_utils import AveragedModel, SWALR
 
+import openvino as ov
+import shelve
+
 from sklearn.decomposition import PCA
 
 from ...utils_import import make_delayed_import
@@ -414,6 +417,7 @@ def __init__(self, config=None):
         self.model = None
         self._model_cpu_backup = None
         self._loaded_from_onnx = None
+        self._loaded_from_openvino = None
 
         # Leading device and list of all devices to use
         self.device = None
@@ -974,9 +978,9 @@ def model_to_device(self, model=None):
         model = model if model is not None else self.model
 
         if len(self.devices) > 1:
-            self.model = nn.DataParallel(self.model, self.devices)
+            model = nn.DataParallel(model, self.devices)
         else:
-            self.model.to(self.device)
+            model = model.to(self.device)
 
 
     # Apply model to train/predict on given data
@@ -1359,7 +1363,7 @@ def predict(self, inputs, targets=None, outputs=None, lock=True, microbatch_size
         >>> batchflow_model.predict(inputs=B.images, outputs='model.body.encoder["block-0"]')
         """
-        if self._loaded_from_onnx:
+        if self._loaded_from_onnx or self._loaded_from_openvino:
             microbatch_size = self.microbatch_size
             microbatch_pad_last = True
 
@@ -1641,8 +1645,8 @@ def convert_outputs(self, outputs):
 
 
     # Store model
-    def save(self, path, use_onnx=False, path_onnx=None, batch_size=None, opset_version=13,
-             pickle_module=dill, **kwargs):
+    def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
+             batch_size=None, opset_version=13, pickle_module=dill, **kwargs):
         """ Save underlying PyTorch model along with meta parameters (config, device spec, etc).
 
         If `use_onnx` is set to True, then the model is converted to ONNX format and stored in a separate file.
@@ -1658,7 +1662,13 @@ def save(self, path, use_onnx=False, path_onnx=None, batch_size=None, opset_vers
             Whether to store model in ONNX format.
         path_onnx : str, optional
             Used only if `use_onnx` is True.
-            If provided, then path to store the ONNX model; default `path_onnx` is `path` with added '_onnx' postfix.
+            If provided, then path to store the ONNX model; default `path_onnx` is `path` with '_onnx' postfix.
+        use_openvino : bool
+            Whether to store the model as an openvino XML file.
+        path_openvino : str, optional
+            Used only if `use_openvino` is True.
+            If provided, then path to store the openvino model; default `path_openvino` is `path` with '_openvino'
+            postfix.
         batch_size : int, optional
             Used only if `use_onnx` is True.
             Fixed batch size of the ONNX module. This is the only viable batch size for this model after loading.
@@ -1691,11 +1701,37 @@ def save(self, path, use_onnx=False, path_onnx=None, batch_size=None, opset_vers
             preserved_dict = {item: getattr(self, item) for item in preserved}
             torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
                        path, pickle_module=pickle_module, **kwargs)
+
+        elif use_openvino:
+            if batch_size is None:
+                raise ValueError('Specify a valid `batch_size` to be used for model inference!')
+
+            path_openvino = path_openvino or (path + '_openvino')
+            if os.path.splitext(path_openvino)[-1] == '':
+                path_openvino = f'{path_openvino}.xml'
+
+            # Save model
+            model = self.model.eval()
+
+            if not isinstance(self.model, ov.Model):
+                inputs = self.make_placeholder_data(batch_size=batch_size, unwrap=False)
+                model = ov.convert_model(model, example_input=inputs)
+
+            ov.save_model(model, output_model=path_openvino)
+
+            # Save the rest of parameters
+            preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
+            preserved_dict = {item: getattr(self, item) for item in preserved}
+            out_path_params = f'{os.path.splitext(path_openvino)[0]}_bf_params_db'
+
+            with shelve.open(out_path_params) as params_db:
+                params_db.update(preserved_dict)
+
         else:
             torch.save({item: getattr(self, item) for item in self.PRESERVE},
                        path, pickle_module=pickle_module, **kwargs)
 
-    def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
+    def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
         """ Load a torch model from a file.
 
         If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size
@@ -1705,49 +1741,69 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
         ----------
         file : str, PathLike, io.Bytes
             a file where a model is stored.
-        eval_mode : bool
-            Whether to switch the model to eval mode.
+        is_openvino : bool, default False
+            Whether to load the file as an openvino model instance.
         make_infrastructure : bool
             Whether to re-create model loss, optimizer, scaler and decay.
+        mode : str
+            Mode to switch the model to after loading, e.g. 'eval' or 'train'.
         pickle_module : module
             Module to use for pickling.
         kwargs : dict
             Other keyword arguments, passed directly to :func:`torch.load`.
""" self._parse_devices() - if self.device: - kwargs['map_location'] = self.device - kwargs['map_location'] = 'cpu' - # Load items from disk storage and set them as insance attributes - checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs) + if is_openvino: + device = kwargs.pop('device', None) or self.device or 'CPU' + self.device = device.lower() - # `load_config` is a reference to `self.external_config` used to update `config` - # It is required since `self.external_config` may be overwritten in the cycle below - load_config = self.external_config + model = OVModel(model_path=file, device=device, **kwargs) + self.model = model - for key, value in checkpoint.items(): - setattr(self, key, value) - self.config = self.config + load_config + # Load params + out_path_params = f'{os.path.splitext(file)[0]}_bf_params_db' + with shelve.open(out_path_params) as params_db: + params = {**params_db} - # Load model from onnx, if needed - if 'onnx' in checkpoint: - try: - from onnx2torch import convert #pylint: disable=import-outside-toplevel - except ImportError as e: - raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e + for key, value in params.items(): + setattr(self, key, value) - model = convert(checkpoint['path_onnx']).eval() - self.model = model - self.microbatch_size = checkpoint['onnx_batch_size'] - self._loaded_from_onnx = True + self._loaded_from_openvino = True self.disable_training = True + else: + kwargs['map_location'] = self.device if self.device else 'cpu' - self.model_to_device() - if make_infrastructure: - self.make_infrastructure() + # Load items from disk storage and set them as insance attributes + checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs) + + # `load_config` is a reference to `self.external_config` used to update `config` + # It is required since `self.external_config` may be overwritten in the cycle below + load_config = self.external_config + + for key, value in checkpoint.items(): + setattr(self, key, value) + self.config = self.config + load_config + + # Load model from onnx, if needed + if 'onnx' in checkpoint: + try: + from onnx2torch import convert #pylint: disable=import-outside-toplevel + except ImportError as e: + raise ImportError('Loading model, stored in ONNX format, requires `onnx2torch` library.') from e + + model = convert(checkpoint['path_onnx']).eval() + self.model = model + self.microbatch_size = checkpoint['onnx_batch_size'] + self._loaded_from_onnx = True + self.disable_training = True - self.set_model_mode(mode) + self.model_to_device() + + if make_infrastructure: + self.make_infrastructure() + + self.set_model_mode(mode) # Utilities to use when working with TorchModel @@ -1890,8 +1946,7 @@ def reduce_channels(array, normalize=True, n_components=3): """ array = array.transpose(0, 2, 3, 1) pca_instance = PCA(n_components=n_components) - - compressed_array= pca_instance.fit_transform(array.reshape(-1, array.shape[-1])) + compressed_array = pca_instance.fit_transform(array.reshape(-1, array.shape[-1])) compressed_array = compressed_array.reshape(*array.shape[:3], n_components) if normalize: normalizer = Normalizer(mode='minmax') @@ -1900,3 +1955,28 @@ def reduce_channels(array, normalize=True, n_components=3): explained_variance_ratio = pca_instance.explained_variance_ratio_ return compressed_array, explained_variance_ratio + +class OVModel: + def __init__(self, model_path, core_config=None, device='CPU', compile_config=None): + core = ov.Core() + + if core_config is not 
+            for name, kwargs_ in core_config.items():
+                core.set_property(name, kwargs_)
+
+        self.model = core.read_model(model=model_path)
+
+        if compile_config is None:
+            compile_config = {}
+        self.model = core.compile_model(self.model, device, config=compile_config)
+
+    def eval(self):
+        """ Placeholder for compatibility with :class:`~TorchModel` methods."""
+        return self  # mimic nn.Module.eval, which returns the module
+
+    def __call__(self, input_tensor):
+        """ Evaluate model on provided data. """
+        results = self.model(input_tensor)
+
+        results = torch.from_numpy(results[self.model.output(0)])
+        return results
diff --git a/batchflow/monitor.py b/batchflow/monitor.py
index 294610400..8cbc5a1d6 100644
--- a/batchflow/monitor.py
+++ b/batchflow/monitor.py
@@ -166,6 +166,36 @@ def get_usage(**kwargs):
         _ = kwargs
         return psutil.disk_usage('/').used / (1024 **3)
 
+class DirDiskMonitor(ResourceMonitor):
+    """ Track disk usage in the provided dir. """
+    UNIT = 'Mb'
+
+    @staticmethod
+    def get_size(path):
+        """ Get disk usage in the directory.
+
+        Under the hood, we recursively evaluate file sizes in all nested directories.
+        """
+        size = 0
+
+        for dir_or_file in os.scandir(path):
+            if os.path.isfile(dir_or_file):
+                size += os.path.getsize(dir_or_file)
+            else:
+                size += DirDiskMonitor.get_size(dir_or_file)
+
+        return size
+
+    @staticmethod
+    def get_usage(**kwargs):
+        """ Track disk usage in the provided dir. """
+        dir_path = kwargs.get('track_dir_path', None)
+
+        if dir_path is None:
+            return None
+
+        return DirDiskMonitor.get_size(path=dir_path) / (1024 ** 2)
+
 
 # Process resource monitors: track resources of a given process
 class ProcessResourceMonitor(ResourceMonitor):
@@ -356,6 +386,7 @@ def get_usage(gpu_handles=None, **kwargs):
     # System-wide monitors
     TotalCPUMonitor: ['total_cpu'],
     TotalMemoryMonitor: ['total_memory', 'total_rss'],
+    DirDiskMonitor: ['dir_disk'],
     TotalDiskMonitor: ['total_disk'],
 
     # Process monitors

From 24085ddbe10cc1c8a8e25c7ac2fb38f11a5e47f9 Mon Sep 17 00:00:00 2001
From: HollowPrincess
Date: Mon, 9 Sep 2024 14:40:57 +0000
Subject: [PATCH 2/3] fix PR comments

---
 batchflow/models/torch/base.py | 96 ++++++++++++++++++++--------------
 1 file changed, 56 insertions(+), 40 deletions(-)

diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py
index d4dd41ed9..a8f094144 100755
--- a/batchflow/models/torch/base.py
+++ b/batchflow/models/torch/base.py
@@ -16,8 +16,6 @@
 from torch import nn
 from torch.optim.swa_utils import AveragedModel, SWALR
 
-import openvino as ov
-import shelve
 
 from sklearn.decomposition import PCA
 
@@ -1703,8 +1701,7 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
                        path, pickle_module=pickle_module, **kwargs)
 
         elif use_openvino:
-            if batch_size is None:
-                raise ValueError('Specify a valid `batch_size` to be used for model inference!')
+            import openvino as ov
 
             path_openvino = path_openvino or (path + '_openvino')
             if os.path.splitext(path_openvino)[-1] == '':
@@ -1722,16 +1719,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
             # Save the rest of parameters
             preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
             preserved_dict = {item: getattr(self, item) for item in preserved}
-            out_path_params = f'{os.path.splitext(path_openvino)[0]}_bf_params_db'
-
-            with shelve.open(out_path_params) as params_db:
-                params_db.update(preserved_dict)
+            torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
+                       path, pickle_module=pickle_module, **kwargs)
 
         else:
             torch.save({item: getattr(self, item) for item in self.PRESERVE},
                        path, pickle_module=pickle_module, **kwargs)
 
-    def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
+    def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
         """ Load a torch model from a file.
 
         If the model was saved in ONNX format (refer to :meth:`.save` for more info), we fix the microbatch size
@@ -1741,8 +1736,6 @@ def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval',
         ----------
         file : str, PathLike, io.Bytes
             a file where a model is stored.
-        is_openvino : bool, default False
-            Whether to load the file as an openvino model instance.
         make_infrastructure : bool
             Whether to re-create model loss, optimizer, scaler and decay.
         mode : str
             Mode to switch the model to after loading, e.g. 'eval' or 'train'.
         pickle_module : module
             Module to use for pickling.
@@ -1752,39 +1745,40 @@ def load(self, file, is_openvino=False, make_infrastructure=False, mode='eval',
         kwargs : dict
             Other keyword arguments, passed directly to :func:`torch.load`.
         """
-        self._parse_devices()
+        model_load_kwargs = kwargs.pop('model_load_kwargs', {})
 
-        if is_openvino:
-            device = kwargs.pop('device', None) or self.device or 'CPU'
-            self.device = device.lower()
+        device = kwargs.pop('device', None)
 
-            model = OVModel(model_path=file, device=device, **kwargs)
-            self.model = model
+        if device is not None:
+            self.device = device
 
-            # Load params
-            out_path_params = f'{os.path.splitext(file)[0]}_bf_params_db'
-            with shelve.open(out_path_params) as params_db:
-                params = {**params_db}
+        if (self.device == 'cpu') or ((not isinstance(self.device, str)) and (self.device.type == 'cpu')):
+            self.amp = False
+        else:
+            self._parse_devices()
 
-            for key, value in params.items():
-                setattr(self, key, value)
+        kwargs['map_location'] = self.device
 
-            self._loaded_from_openvino = True
-            self.disable_training = True
-        else:
-            kwargs['map_location'] = self.device if self.device else 'cpu'
+        # Load items from disk storage and set them as instance attributes
+        checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)
 
-            # Load items from disk storage and set them as instance attributes
-            checkpoint = torch.load(file, pickle_module=pickle_module, **kwargs)
+        # `load_config` is a reference to `self.external_config` used to update `config`
+        # It is required since `self.external_config` may be overwritten in the cycle below
+        load_config = self.external_config
 
-            # `load_config` is a reference to `self.external_config` used to update `config`
-            # It is required since `self.external_config` may be overwritten in the cycle below
-            load_config = self.external_config
+        for key, value in checkpoint.items():
+            setattr(self, key, value)
+        self.config = self.config + load_config
 
-            for key, value in checkpoint.items():
-                setattr(self, key, value)
-            self.config = self.config + load_config
+        if 'openvino' in checkpoint:
+            # Load openvino model
+            model = OVModel(model_path=checkpoint['path_openvino'], **model_load_kwargs)
+            self.model = model
+            self._loaded_from_openvino = True
+            self.disable_training = True
+
+        else:
             # Load model from onnx, if needed
             if 'onnx' in checkpoint:
                 try:
@@ -1957,25 +1951,47 @@ def reduce_channels(array, normalize=True, n_components=3):
     return compressed_array, explained_variance_ratio
 
 class OVModel:
-    def __init__(self, model_path, core_config=None, device='CPU', compile_config=None):
+    """ Wrapper class for openvino models, allowing interaction with them through the :class:`~.TorchModel` interface.
+
+    Note that openvino models are loaded on 'cpu' only.
+
+    Parameters
+    ----------
+    model_path : str
+        Path to the compiled openvino model.
+    core_config : tuple or dict, optional
+        Openvino core properties.
+        If you want to set properties globally, provide them as a tuple: `('CPU', {name: value})`.
+        For local properties, just provide a `{name: value}` dict.
+        For more details, see the documentation:
+        https://docs.openvino.ai/2023.3/openvino_docs_OV_UG_query_api.html#setting-properties-globally
+    compile_config : dict, optional
+        Openvino model compilation config.
+    """
+    def __init__(self, model_path, core_config=None, compile_config=None):
+        import openvino as ov
+
         core = ov.Core()
 
         if core_config is not None:
-            for name, kwargs_ in core_config.items():
-                core.set_property(name, kwargs_)
+            if isinstance(core_config, tuple):
+                core.set_property(core_config[0], core_config[1])
+            else:
+                core.set_property(core_config)
 
         self.model = core.read_model(model=model_path)
 
         if compile_config is None:
             compile_config = {}
-        self.model = core.compile_model(self.model, device, config=compile_config)
+
+        self.model = core.compile_model(self.model, 'CPU', config=compile_config)
 
     def eval(self):
         """ Placeholder for compatibility with :class:`~TorchModel` methods."""
         return self  # mimic nn.Module.eval, which returns the module
 
     def __call__(self, input_tensor):
-        """ Evaluate model on provided data. """
+        """ Evaluate model on the provided data. """
         results = self.model(input_tensor)
 
         results = torch.from_numpy(results[self.model.output(0)])
         return results

From 890fce07af87bea9219b11c5415afe9e769b228c Mon Sep 17 00:00:00 2001
From: HollowPrincess
Date: Fri, 13 Sep 2024 07:56:02 +0000
Subject: [PATCH 3/3] rename dir_path

---
 batchflow/monitor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/batchflow/monitor.py b/batchflow/monitor.py
index 8cbc5a1d6..b27baaa9a 100644
--- a/batchflow/monitor.py
+++ b/batchflow/monitor.py
@@ -189,7 +189,7 @@ def get_size(path):
     @staticmethod
     def get_usage(**kwargs):
         """ Track disk usage in the provided dir. """
-        dir_path = kwargs.get('track_dir_path', None)
+        dir_path = kwargs.get('dir_path', None)
 
         if dir_path is None:
             return None
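
A minimal usage sketch of the API as it stands after PATCH 2/3 (not part of the patches). Here `model` stands for an already-trained `TorchModel` instance, and the paths, batch size, and the `INFERENCE_NUM_THREADS` property value are illustrative assumptions rather than values taken from the diffs:

    from batchflow.models.torch import TorchModel

    # `model` is a hypothetical trained TorchModel instance.
    # Saving converts the torch module with ov.convert_model, writes the
    # compiled graph to '<path>_openvino.xml', and stores the remaining
    # PRESERVE attributes (plus {'openvino': True, 'path_openvino': ...})
    # in the regular torch.save checkpoint at `path`.
    model.save('/tmp/segmenter', use_openvino=True, batch_size=16)

    # Loading detects the 'openvino' key in the checkpoint and wraps the
    # saved XML in OVModel: training is disabled, inference runs on CPU,
    # and the microbatch size is pinned inside `predict`.
    restored = TorchModel()
    restored.load('/tmp/segmenter',
                  model_load_kwargs={'core_config': ('CPU', {'INFERENCE_NUM_THREADS': 4})})
    predictions = restored.predict(inputs=images)  # `images`: hypothetical input batch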
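The new DirDiskMonitor can be exercised in the same spirit (again a sketch, not part of the patches); it relies on the post-PATCH-3 kwarg name `dir_path`, and the temporary directory with a 3 Mb payload is made up for the example:

    import os
    import tempfile

    from batchflow.monitor import DirDiskMonitor

    with tempfile.TemporaryDirectory() as tmp_dir:
        with open(os.path.join(tmp_dir, 'blob.bin'), 'wb') as f:
            f.write(b'\0' * (3 * 1024 * 1024))               # 3 Mb of zeros

        print(DirDiskMonitor.get_size(tmp_dir))              # 3145728, in bytes
        print(DirDiskMonitor.get_usage(dir_path=tmp_dir))    # 3.0, in Mb (cf. UNIT)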