From b84974d2e734f2ad9c2e52e2294ab507d4a2c16b Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Mon, 18 Mar 2024 20:08:51 +0530 Subject: [PATCH 01/19] Support for Safetensors --- .../huggingface/materializers/__init__.py | 3 + .../huggingface_pt_model_st_materializer.py | 94 +++++++++++++++++++ .../pytorch/materializers/__init__.py | 4 + .../base_pytorch_st_materializer.py | 57 +++++++++++ .../pytorch_module_st_materializer.py | 71 ++++++++++++++ .../materializers/__init__.py | 4 + .../pytorch_lightning_st_materializer.py | 33 +++++++ ...st_huggingface_pt_model_st_materializer.py | 31 ++++++ .../test_pytorch_module_st_materializer.py | 33 +++++++ 9 files changed, 330 insertions(+) create mode 100644 src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py create mode 100644 src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py create mode 100644 src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py create mode 100644 src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py create mode 100644 tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py create mode 100644 tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py diff --git a/src/zenml/integrations/huggingface/materializers/__init__.py b/src/zenml/integrations/huggingface/materializers/__init__.py index 6f2469644c7..92f40c69e3a 100644 --- a/src/zenml/integrations/huggingface/materializers/__init__.py +++ b/src/zenml/integrations/huggingface/materializers/__init__.py @@ -25,3 +25,6 @@ from zenml.integrations.huggingface.materializers.huggingface_tokenizer_materializer import ( HFTokenizerMaterializer, ) +from zenml.integrations.huggingface.materializers.huggingface_pt_model_st_materializer import ( + HFPTModelSTMaterializer +) diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py new file mode 100644 index 00000000000..205cc2a292c --- /dev/null +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -0,0 +1,94 @@ +# Copyright (c) ZenML GmbH 2021. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the Huggingface PyTorch model materializer using Safetensors.""" + +import importlib +import os +from tempfile import TemporaryDirectory +from typing import Any, ClassVar, Dict, Tuple, Type + +from safetensors.torch import load_model, save_model +from transformers import ( # type: ignore [import-untyped] + AutoConfig, + PreTrainedModel, +) + +from zenml.enums import ArtifactType +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.metadata.metadata_types import DType, MetadataType +from zenml.utils import io_utils + +DEFAULT_PT_MODEL_DIR = "hf_pt_model" + + +class HFPTModelSTMaterializer(BaseMaterializer): + """Materializer to read torch model to and from huggingface pretrained model.""" + + ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (PreTrainedModel,) + ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL + + def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: + """Reads HFModel. + + Args: + data_type: The type of the model to read. + + Returns: + The model read from the specified dir. + """ + temp_dir = TemporaryDirectory() + io_utils.copy_dir( + os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir.name + ) + + config = AutoConfig.from_pretrained(temp_dir.name) + architecture = config.architectures[0] + model_cls = getattr( + importlib.import_module("transformers"), architecture + ) + loaded_model = load_model(model_cls, temp_dir.name) + return loaded_model + + def save(self, model: PreTrainedModel) -> None: + """Writes a Model to the specified dir. + + Args: + model: The Torch Model to write. + """ + temp_dir = TemporaryDirectory() + save_model(model, temp_dir.name) + io_utils.copy_dir( + temp_dir.name, + os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), + ) + + def extract_metadata( + self, model: PreTrainedModel + ) -> Dict[str, "MetadataType"]: + """Extract metadata from the given `PreTrainedModel` object. + + Args: + model: The `PreTrainedModel` object to extract metadata from. + + Returns: + The extracted metadata as a dictionary. + """ + from zenml.integrations.pytorch.utils import count_module_params + + module_param_metadata = count_module_params(model) + return { + **module_param_metadata, + "dtype": DType(str(model.dtype)), + "device": str(model.device), + } diff --git a/src/zenml/integrations/pytorch/materializers/__init__.py b/src/zenml/integrations/pytorch/materializers/__init__.py index 36d2df9e89c..5a5115c7c0e 100644 --- a/src/zenml/integrations/pytorch/materializers/__init__.py +++ b/src/zenml/integrations/pytorch/materializers/__init__.py @@ -19,3 +19,7 @@ from zenml.integrations.pytorch.materializers.pytorch_module_materializer import ( # noqa PyTorchModuleMaterializer, ) + +from zenml.integrations.pytorch.materializers.pytorch_module_st_materializer import ( #noqa + PyTorchModuleSTMaterializer +) \ No newline at end of file diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py new file mode 100644 index 00000000000..e157d579d22 --- /dev/null +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -0,0 +1,57 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the PyTorch DataLoader materializer using Safetensors.""" + +import os +from typing import Any, ClassVar, Type + +from safetensors.torch import load_file, save_file + +from zenml.materializers.base_materializer import BaseMaterializer + +DEFAULT_FILENAME = "obj.safetensors" + + +class BasePyTorchSTMaterializer(BaseMaterializer): + """Base class for PyTorch materializers.""" + + FILENAME: ClassVar[str] = DEFAULT_FILENAME + SKIP_REGISTRATION: ClassVar[bool] = True + + def load(self, data_type: Type[Any]) -> Any: + """Uses `torch.load` to load a PyTorch object. + + Args: + data_type: The type of the object to load. + + Returns: + The loaded PyTorch object. + """ + filename = os.path.join(self.uri, self.FILENAME) + return load_file(filename) + + def save(self, obj: Any) -> None: + """Uses `torch.save` to save a PyTorch object. + + Args: + obj: The PyTorch object to save. + """ + filename = os.path.join(self.uri, self.FILENAME) + save_file(obj, filename) + + +# Alias for the BasePyTorchMaterializer class, allowing users that have already used +# the old name to continue using it without breaking their code. +# 'BasePyTorchMaterializer' or 'BasePyTorchMaterliazer' to refer to the same class. +BasePyTorchSTMaterliazer = BasePyTorchSTMaterializer diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py new file mode 100644 index 00000000000..bea979f81c0 --- /dev/null +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py @@ -0,0 +1,71 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the PyTorch Module materializer using Safetensors.""" + +import os +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type + +from safetensors.torch import save_model +from torch.nn import Module + +from zenml.enums import ArtifactType +from zenml.integrations.pytorch.materializers.base_pytorch_st_materializer import ( + BasePyTorchSTMaterializer, +) +from zenml.integrations.pytorch.utils import count_module_params + +if TYPE_CHECKING: + from zenml.metadata.metadata_types import MetadataType + +DEFAULT_FILENAME = "entire_model.safetensors" +CHECKPOINT_FILENAME = "checkpoint.safetensors" + + +class PyTorchModuleSTMaterializer(BasePyTorchSTMaterializer): + """Materializer to read/write Pytorch models. 
+ + Inspired by the guide: + https://pytorch.org/tutorials/beginner/saving_loading_models.html + """ + + ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,) + ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL + FILENAME: ClassVar[str] = DEFAULT_FILENAME + + def save(self, model: Module) -> None: + """Writes a PyTorch model, as a model and a checkpoint. + + Args: + model: A torch.nn.Module or a dict to pass into model.save + """ + # Save entire model to artifact directory, This is the default behavior + # for loading model in development phase (training, evaluation) + super().save(model) + + # Also save model checkpoint to artifact directory, + # This is the default behavior for loading model in production phase (inference) + if isinstance(model, Module): + filename = os.path.join(self.uri, CHECKPOINT_FILENAME) + save_model(model, filename) + + def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]: + """Extract metadata from the given `Model` object. + + Args: + model: The `Model` object to extract metadata from. + + Returns: + The extracted metadata as a dictionary. + """ + return {**count_module_params(model)} diff --git a/src/zenml/integrations/pytorch_lightning/materializers/__init__.py b/src/zenml/integrations/pytorch_lightning/materializers/__init__.py index 87eb2be8900..4c6159f31d4 100644 --- a/src/zenml/integrations/pytorch_lightning/materializers/__init__.py +++ b/src/zenml/integrations/pytorch_lightning/materializers/__init__.py @@ -16,3 +16,7 @@ from zenml.integrations.pytorch_lightning.materializers.pytorch_lightning_materializer import ( # noqa PyTorchLightningMaterializer, ) + +from zenml.integrations.pytorch_lightning.materializers.pytorch_lightning_st_materializer import ( #noqa + PyTorchLightningSTMaterializer +) \ No newline at end of file diff --git a/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py b/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py new file mode 100644 index 00000000000..b1f45be7597 --- /dev/null +++ b/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py @@ -0,0 +1,33 @@ +# Copyright (c) ZenML GmbH 2021. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the PyTorch Lightning Materializer using Safetensors.""" + +from typing import Any, ClassVar, Tuple, Type + +from torch.nn import Module + +from zenml.enums import ArtifactType +from zenml.integrations.pytorch.materializers.base_pytorch_st_materializer import ( + BasePyTorchSTMaterializer, +) + +CHECKPOINT_NAME = "final_checkpoint.safetensors" + + +class PyTorchLightningSTMaterializer(BasePyTorchSTMaterializer): + """Materializer to read/write PyTorch models.""" + + ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,) + ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL + FILENAME: ClassVar[str] = CHECKPOINT_NAME diff --git a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py new file mode 100644 index 00000000000..85bf0331ade --- /dev/null +++ b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py @@ -0,0 +1,31 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from transformers import RobertaConfig, RobertaModel + +from tests.unit.test_general import _test_materializer +from zenml.integrations.huggingface.materializers.huggingface_pt_model_st_materializer import ( + HFPTModelSTMaterializer, +) + + +def test_huggingface_pretrained_model_materializer(clean_client): + """Tests whether the steps work for the Huggingface Pretrained Model materializer using Safetensors.""" + model = _test_materializer( + step_output=RobertaModel(RobertaConfig()), + materializer_class=HFPTModelSTMaterializer, + expected_metadata_size=5, + ) + + assert model.config.model_type == "roberta" diff --git a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py new file mode 100644 index 00000000000..26ffc7892b4 --- /dev/null +++ b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py @@ -0,0 +1,33 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +from torch.nn import Linear + +from tests.unit.test_general import _test_materializer +from zenml.integrations.pytorch.materializers.pytorch_module_st_materializer import ( + PyTorchModuleSTMaterializer, +) + + +def test_pytorch_module_materializer(clean_client): + """Tests whether the steps work for the Sklearn materializer using Safetensors.""" + module = _test_materializer( + step_output=Linear(20, 20), + materializer_class=PyTorchModuleSTMaterializer, + expected_metadata_size=3, + ) + + assert module.in_features == 20 + assert module.out_features == 20 + assert module.bias is not None From 36982104fd8b0beb4e6bf854335e27db1f224769 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Tue, 19 Mar 2024 19:16:58 +0530 Subject: [PATCH 02/19] add documentation and few fixes --- .../handle-custom-data-types.md | 84 ++++++++++++++++++- pyproject.toml | 2 + .../huggingface_pt_model_st_materializer.py | 7 +- .../base_pytorch_st_materializer.py | 32 +++++-- .../pytorch_module_st_materializer.py | 4 +- ...st_huggingface_pt_model_st_materializer.py | 2 +- 6 files changed, 117 insertions(+), 14 deletions(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index 360fb2349f9..f8e466945a0 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -22,13 +22,93 @@ ZenML also provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.i In addition to the built-in materializers, ZenML also provides several integration-specific materializers that can be activated by installing the respective [integration](../../component-guide/integration-overview.md): -
| Integration | Materializer | Handled Data Types | Storage Format |
|---|---|---|---|
| bentoml | BentoMaterializer | bentoml.Bento | .bento |
| deepchecks | DeepchecksResultMaterializer | deepchecks.CheckResult, deepchecks.SuiteResult | .json |
| evidently | EvidentlyProfileMaterializer | evidently.Profile | .json |
| great_expectations | GreatExpectationsMaterializer | great_expectations.ExpectationSuite, great_expectations.CheckpointResult | .json |
| huggingface | HFDatasetMaterializer | datasets.Dataset, datasets.DatasetDict | Directory |
| huggingface | HFPTModelMaterializer | transformers.PreTrainedModel | Directory |
| huggingface | HFTFModelMaterializer | transformers.TFPreTrainedModel | Directory |
| huggingface | HFTokenizerMaterializer | transformers.PreTrainedTokenizerBase | Directory |
| lightgbm | LightGBMBoosterMaterializer | lgbm.Booster | .txt |
| lightgbm | LightGBMDatasetMaterializer | lgbm.Dataset | .binary |
| neural_prophet | NeuralProphetMaterializer | NeuralProphet | .pt |
| pillow | PillowImageMaterializer | Pillow.Image | .PNG |
| polars | PolarsMaterializer | pl.DataFrame, pl.Series | .parquet |
| pycaret | PyCaretMaterializer | Any sklearn, xgboost, lightgbm or catboost model | .pkl |
| pytorch | PyTorchDataLoaderMaterializer | torch.Dataset, torch.DataLoader | .pt |
| pytorch | PyTorchModuleMaterializer | torch.Module | .pt |
| scipy | SparseMaterializer | scipy.spmatrix | .npz |
| spark | SparkDataFrameMaterializer | pyspark.DataFrame | .parquet |
| spark | SparkModelMaterializer | pyspark.Transformer, pyspark.Estimator | |
| tensorflow | KerasMaterializer | tf.keras.Model | Directory |
| tensorflow | TensorflowDatasetMaterializer | tf.Dataset | Directory |
| whylogs | WhylogsMaterializer | whylogs.DatasetProfileView | .pb |
| xgboost | XgboostBoosterMaterializer | xgb.Booster | .json |
| xgboost | XgboostDMatrixMaterializer | xgb.DMatrix | .binary |
+
| Integration | Materializer | Handled Data Types | Storage Format |
|---|---|---|---|
| bentoml | BentoMaterializer | bentoml.Bento | .bento |
| deepchecks | DeepchecksResultMaterializer | deepchecks.CheckResult, deepchecks.SuiteResult | .json |
| evidently | EvidentlyProfileMaterializer | evidently.Profile | .json |
| great_expectations | GreatExpectationsMaterializer | great_expectations.ExpectationSuite, great_expectations.CheckpointResult | .json |
| huggingface | HFDatasetMaterializer | datasets.Dataset, datasets.DatasetDict | Directory |
| huggingface | HFPTModelMaterializer | transformers.PreTrainedModel | Directory |
| huggingface | HFTFModelMaterializer | transformers.TFPreTrainedModel | Directory |
| huggingface | HFTokenizerMaterializer | transformers.PreTrainedTokenizerBase | Directory |
| lightgbm | LightGBMBoosterMaterializer | lgbm.Booster | .txt |
| lightgbm | LightGBMDatasetMaterializer | lgbm.Dataset | .binary |
| neural_prophet | NeuralProphetMaterializer | NeuralProphet | .pt |
| pillow | PillowImageMaterializer | Pillow.Image | .PNG |
| polars | PolarsMaterializer | pl.DataFrame, pl.Series | .parquet |
| pycaret | PyCaretMaterializer | Any sklearn, xgboost, lightgbm or catboost model | .pkl |
| pytorch | PyTorchDataLoaderMaterializer | torch.Dataset, torch.DataLoader | .pt |
| pytorch | PyTorchModuleMaterializer | torch.Module | .pt |
| pytorch_lightning | PyTorchLightningMaterializer | torch.Module | .ckpt |
| scipy | SparseMaterializer | scipy.spmatrix | .npz |
| spark | SparkDataFrameMaterializer | pyspark.DataFrame | .parquet |
| spark | SparkModelMaterializer | pyspark.Transformer, pyspark.Estimator | |
| tensorflow | KerasMaterializer | tf.keras.Model | Directory |
| tensorflow | TensorflowDatasetMaterializer | tf.Dataset | Directory |
| whylogs | WhylogsMaterializer | whylogs.DatasetProfileView | .pb |
| xgboost | XgboostBoosterMaterializer | xgb.Booster | .json |
| xgboost | XgboostDMatrixMaterializer | xgb.DMatrix | .binary |
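
As a quick illustration, one of these integration materializers can also be pinned explicitly on a step output. This is only a sketch; it assumes the `pytorch` integration is installed and reuses the import path from this PR's `__init__.py` changes:

```python
from torch.nn import Linear, Module

from zenml.steps import step
from zenml.integrations.pytorch.materializers import PyTorchModuleMaterializer


@step(output_materializers=PyTorchModuleMaterializer)
def trainer() -> Module:
    """Step whose Module output is written with the PyTorch materializer."""
    # Stand-in model; any torch.nn.Module output is handled the same way.
    return Linear(20, 20)
```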
{% hint style="warning" %} If you are running pipelines with a Docker-based [orchestrator](../../component-guide/orchestrators/orchestrators.md), you need to specify the corresponding integration as `required_integrations` in the `DockerSettings` of your pipeline in order to have the integration materializer available inside your Docker container. See the [pipeline configuration documentation](../pipelining-features/pipeline-settings.md) for more information. {% endhint %} -## Custom materializers + +## Safetensor Materializers + + +In addition to the standard integration-specific materializers that employ `Pickle` for serialization, opting for `Safetensors` offers a faster and more secure approach to model serialization. Further details on `Safetensors` can be found [here](https://huggingface.co/docs/safetensors/en/index). + +
| Integration | Materializer | Handled Data Types | Storage Format |
|---|---|---|---|
| huggingface | HFPTModelSTMaterializer | transformers.PreTrainedModel | .safetensors |
| pytorch | PyTorchModuleSTMaterializer | torch.Module | .safetensors |
| pytorch_lightning | PyTorchLightningSTMaterializer | torch.Module | .safetensors |
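
For context, the materializers in this table wrap the low-level `safetensors.torch` helpers that this PR calls (`save_file`, `load_file`, `save_model`, `load_model`). A minimal, ZenML-independent sketch of that API, with the file name chosen purely for illustration:

```python
from torch.nn import Linear
from safetensors.torch import load_file, save_file

model = Linear(20, 20)

# Safetensors stores only tensors (here, the state dict); no arbitrary
# Python objects are pickled, which is what makes the format safer to load.
save_file(model.state_dict(), "linear.safetensors")

# Loading returns a plain dict of tensors that is applied to a freshly
# constructed module of the same architecture.
restored = Linear(20, 20)
restored.load_state_dict(load_file("linear.safetensors"))
```

The materializers above do essentially this, but read from and write to the artifact store URI instead of a local path.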
+ +### Here's an example showing how to use `PyTorchModuleSTMaterializer`: + + +Let's see how materialization using safetensors works with a basic example. Here we will use `resenet50` from pytorch to test the functionality: + +``` python +import logging +from torchvision.models import resnet50 + +from zenml.steps import step +from zenml.pipelines import pipeline +from zenml.integrations.pytorch.materializers.pytorch_module_st_materializer import PyTorchModuleSTMaterializer + +# initialize materializer, pre-trained and base model +materializer = PyTorchModuleSTMaterializer(uri="") +pretrained_model = resnet50() +base_model = resnet50(weights=None) +``` + +Create `pipeline` which includes steps to `save` and `load` model. + + +``` python +@step(enable_cache=False) +def my_first_step(): + """Step that saves the Pytorch model""" + + logging.info("Saving Model") + materializer.save(pretrained_model) + + +@step(enable_cache=False) +def my_second_step(): + """Step that loads the model and returns it""" + + logging.info("Loading Model") + materializer.load(base_model) + logging.info(f"Model path: {materializer.FILENAME}") + + +@pipeline +def first_pipeline(): + my_first_step() + my_second_step() + +first_pipeline() +``` + +By running pipeline it will yield the following output: + +```python +Initiating a new run for the pipeline: first_pipeline. +Registered new version: (version 12). +Executing a new run. +Using user: default +Using stack: default + artifact_store: default + orchestrator: default +Preventing execution of pipeline 'first_pipeline'. If this is not intended behavior, make sure to unset the environment variable 'ZENML_PREVENT_PIPELINE_EXECUTION'. +Caching disabled explicitly for step_1. +Step step_1 has started. +Saving Model +Step step_1 has finished in 0.173s. +Caching disabled explicitly for step_2. +Step step_2 has started. +Loading Model +Model path: entire_model.safetensors +Step step_2 has finished in 0.837s. +Pipeline run has finished in 1.754s. +You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml up. 
+``` + +## Custom Materializers ### Configuring a step/pipeline to use a custom materializer diff --git a/pyproject.toml b/pyproject.toml index 319cbebf3d4..9882b9fcb38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ pyyaml = ">=6.0.1" rich = { extras = ["jupyter"], version = ">=12.0.0" } sqlalchemy_utils = "0.38.3" sqlmodel = "0.0.8" +safetensors = "^0.4.2" importlib_metadata = { version = "<=7.0.0", python = "<3.10" } # Optional dependencies for the ZenServer @@ -444,6 +445,7 @@ module = [ "fastapi_utils.*", "sqlalchemy_utils.*", "sky.*", + "safetensors.*", "copier.*", "datasets.*", "pyngrok.*", diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index 205cc2a292c..e409666da0e 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -30,6 +30,7 @@ from zenml.utils import io_utils DEFAULT_PT_MODEL_DIR = "hf_pt_model" +DEFAULT_FILENAME = "model.safetensors" class HFPTModelSTMaterializer(BaseMaterializer): @@ -57,7 +58,8 @@ def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: model_cls = getattr( importlib.import_module("transformers"), architecture ) - loaded_model = load_model(model_cls, temp_dir.name) + filepath = os.path.join(temp_dir.name, DEFAULT_FILENAME) + loaded_model = load_model(model_cls, filepath) return loaded_model def save(self, model: PreTrainedModel) -> None: @@ -67,7 +69,8 @@ def save(self, model: PreTrainedModel) -> None: model: The Torch Model to write. """ temp_dir = TemporaryDirectory() - save_model(model, temp_dir.name) + filepath = os.path.join(temp_dir.name, DEFAULT_FILENAME) + save_model(model, filepath) io_utils.copy_dir( temp_dir.name, os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index e157d579d22..a14af491b6b 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -14,9 +14,10 @@ """Implementation of the PyTorch DataLoader materializer using Safetensors.""" import os -from typing import Any, ClassVar, Type +from typing import Any, ClassVar, Optional, Type -from safetensors.torch import load_file, save_file +from safetensors.torch import load_file, load_model, save_file, save_model +from torch.nn import Module from zenml.materializers.base_materializer import BaseMaterializer @@ -29,26 +30,43 @@ class BasePyTorchSTMaterializer(BaseMaterializer): FILENAME: ClassVar[str] = DEFAULT_FILENAME SKIP_REGISTRATION: ClassVar[bool] = True - def load(self, data_type: Type[Any]) -> Any: - """Uses `torch.load` to load a PyTorch object. + def load(self, obj: Any, data_type: Optional[Type[Any]] = None) -> Any: + """Uses `safetensors` to load a PyTorch object. Args: + obj: The model to load onto. data_type: The type of the object to load. Returns: The loaded PyTorch object. 
""" filename = os.path.join(self.uri, self.FILENAME) - return load_file(filename) + try: + if isinstance(obj, Module): + return load_model(obj, filename) + + return load_file(obj, filename) + except: + raise ValueError( + "data_type should be of type: nn.Module or Dict[str, torch.Tensor]" + ) def save(self, obj: Any) -> None: - """Uses `torch.save` to save a PyTorch object. + """Uses `safetensors` to save a PyTorch object. Args: obj: The PyTorch object to save. """ filename = os.path.join(self.uri, self.FILENAME) - save_file(obj, filename) + try: + if isinstance(obj, Module): + save_model(obj, filename) + else: + save_file(obj, filename) + except: + raise ValueError( + "data_type should be of type: nn.Module or Dict[str, torch.Tensor]" + ) # Alias for the BasePyTorchMaterializer class, allowing users that have already used diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py index bea979f81c0..e5787e0bc5f 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py @@ -16,7 +16,7 @@ import os from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type -from safetensors.torch import save_model +from safetensors.torch import save_file from torch.nn import Module from zenml.enums import ArtifactType @@ -57,7 +57,7 @@ def save(self, model: Module) -> None: # This is the default behavior for loading model in production phase (inference) if isinstance(model, Module): filename = os.path.join(self.uri, CHECKPOINT_FILENAME) - save_model(model, filename) + save_file(model.state_dict(), filename) def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]: """Extract metadata from the given `Model` object. 
diff --git a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py index 85bf0331ade..584d27c2667 100644 --- a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py +++ b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py @@ -20,7 +20,7 @@ ) -def test_huggingface_pretrained_model_materializer(clean_client): +def test_huggingface_pretrained_model_st_materializer(clean_client): """Tests whether the steps work for the Huggingface Pretrained Model materializer using Safetensors.""" model = _test_materializer( step_output=RobertaModel(RobertaConfig()), From df476cbe7c038f95831cdc2db599953135f6837a Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sat, 23 Mar 2024 08:13:23 +0530 Subject: [PATCH 03/19] lint and tests fix --- .../huggingface_pt_model_st_materializer.py | 16 ++++----- .../base_pytorch_st_materializer.py | 33 ++++++++++--------- ...st_huggingface_pt_model_st_materializer.py | 1 + .../test_pytorch_module_st_materializer.py | 3 +- tests/unit/test_general.py | 10 ++++-- 5 files changed, 35 insertions(+), 28 deletions(-) diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index e409666da0e..7f7081cf5ad 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -13,14 +13,12 @@ # permissions and limitations under the License. """Implementation of the Huggingface PyTorch model materializer using Safetensors.""" -import importlib import os from tempfile import TemporaryDirectory from typing import Any, ClassVar, Dict, Tuple, Type from safetensors.torch import load_model, save_model from transformers import ( # type: ignore [import-untyped] - AutoConfig, PreTrainedModel, ) @@ -39,10 +37,13 @@ class HFPTModelSTMaterializer(BaseMaterializer): ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (PreTrainedModel,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: + def load( + self, model: PreTrainedModel, data_type: Type[PreTrainedModel] + ) -> PreTrainedModel: """Reads HFModel. Args: + model: The model to load onto. data_type: The type of the model to read. Returns: @@ -53,14 +54,9 @@ def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir.name ) - config = AutoConfig.from_pretrained(temp_dir.name) - architecture = config.architectures[0] - model_cls = getattr( - importlib.import_module("transformers"), architecture - ) filepath = os.path.join(temp_dir.name, DEFAULT_FILENAME) - loaded_model = load_model(model_cls, filepath) - return loaded_model + load_model(model, filepath) + return model def save(self, model: PreTrainedModel) -> None: """Writes a Model to the specified dir. 
diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index a14af491b6b..cce07de22c4 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -17,7 +17,6 @@ from typing import Any, ClassVar, Optional, Type from safetensors.torch import load_file, load_model, save_file, save_model -from torch.nn import Module from zenml.materializers.base_materializer import BaseMaterializer @@ -37,36 +36,40 @@ def load(self, obj: Any, data_type: Optional[Type[Any]] = None) -> Any: obj: The model to load onto. data_type: The type of the object to load. + Raises: + ValueError: If the data_type is not a nn.Module or Dict[str, torch.Tensor] + Returns: The loaded PyTorch object. """ filename = os.path.join(self.uri, self.FILENAME) try: - if isinstance(obj, Module): - return load_model(obj, filename) + if isinstance(obj, dict): + return load_file(filename) - return load_file(obj, filename) - except: - raise ValueError( - "data_type should be of type: nn.Module or Dict[str, torch.Tensor]" - ) + load_model(obj, filename) + return obj + except Exception as e: + raise ValueError(f"Invalid data_type received: {e}") def save(self, obj: Any) -> None: """Uses `safetensors` to save a PyTorch object. Args: obj: The PyTorch object to save. + + Raises: + ValueError: If the data_type is not a nn.Module or Dict[str, torch.Tensor] + """ filename = os.path.join(self.uri, self.FILENAME) try: - if isinstance(obj, Module): - save_model(obj, filename) - else: + if isinstance(obj, dict): save_file(obj, filename) - except: - raise ValueError( - "data_type should be of type: nn.Module or Dict[str, torch.Tensor]" - ) + else: + save_model(obj, filename) + except Exception as e: + raise ValueError(f"Invalid data_type received: {e}") # Alias for the BasePyTorchMaterializer class, allowing users that have already used diff --git a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py index 584d27c2667..89880cfcdf2 100644 --- a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py +++ b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py @@ -26,6 +26,7 @@ def test_huggingface_pretrained_model_st_materializer(clean_client): step_output=RobertaModel(RobertaConfig()), materializer_class=HFPTModelSTMaterializer, expected_metadata_size=5, + pass_step_output=True, ) assert model.config.model_type == "roberta" diff --git a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py index 26ffc7892b4..6e68b6b7890 100644 --- a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py +++ b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py @@ -20,12 +20,13 @@ ) -def test_pytorch_module_materializer(clean_client): +def test_pytorch_module_st_materializer(clean_client): """Tests whether the steps work for the Sklearn materializer using Safetensors.""" module = _test_materializer( step_output=Linear(20, 20), materializer_class=PyTorchModuleSTMaterializer, expected_metadata_size=3, 
+ pass_step_output=True, ) assert module.in_features == 20 diff --git a/tests/unit/test_general.py b/tests/unit/test_general.py index 5176acb8acf..7f4425a55e7 100644 --- a/tests/unit/test_general.py +++ b/tests/unit/test_general.py @@ -36,6 +36,7 @@ def _test_materializer( validation_function: Optional[Callable[[str], Any]] = None, expected_metadata_size: Optional[int] = None, return_metadata: bool = False, + pass_step_output: bool = False, assert_data_exists: bool = True, assert_data_type: bool = True, assert_visualization_exists: bool = False, @@ -59,8 +60,10 @@ def _test_materializer( file exists or a certain number of files were written. expected_metadata_size: If provided, we assert that the metadata dict returned by `materializer.extract_full_metadata()` has this size. - return_metadata: If True, we return the metadata dict returned by + return_metadata: If `True`, we return the metadata dict returned by `materializer.extract_full_metadata()`. + pass_step_output: If `True`, we also pass step_output to safetensors + materializers. assert_data_exists: If `True`, we also assert that `materializer.save()` wrote something to disk. assert_data_type: If `True`, we also assert that `materializer.load()` @@ -109,7 +112,10 @@ def _test_materializer( assert isinstance(value, MetadataTypeTuple) # Assert that materializer loads the data with the correct type - loaded_data = materializer.load(step_output_type) + if pass_step_output: + loaded_data = materializer.load(step_output, step_output_type) + else: + loaded_data = materializer.load(step_output_type) if assert_data_type: assert isinstance(loaded_data, step_output_type) # correct type From ab2a99455e07fbc0ee0bb1ee3dd867a3a867cdc8 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Mon, 25 Mar 2024 18:56:44 +0530 Subject: [PATCH 04/19] lint fix --- .../materializers/huggingface_pt_model_st_materializer.py | 4 ++-- .../pytorch/materializers/base_pytorch_st_materializer.py | 4 ++-- .../spark/materializers/spark_model_materializer.py | 6 +++--- src/zenml/materializers/base_materializer.py | 4 +++- tests/unit/test_general.py | 2 +- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index 7f7081cf5ad..50287fc6c2f 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -38,13 +38,13 @@ class HFPTModelSTMaterializer(BaseMaterializer): ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL def load( - self, model: PreTrainedModel, data_type: Type[PreTrainedModel] + self, data_type: Type[PreTrainedModel], model: PreTrainedModel ) -> PreTrainedModel: """Reads HFModel. Args: - model: The model to load onto. data_type: The type of the model to read. + model: The model to load onto. Returns: The model read from the specified dir. 
diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index cce07de22c4..d317da08aa0 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -29,12 +29,12 @@ class BasePyTorchSTMaterializer(BaseMaterializer): FILENAME: ClassVar[str] = DEFAULT_FILENAME SKIP_REGISTRATION: ClassVar[bool] = True - def load(self, obj: Any, data_type: Optional[Type[Any]] = None) -> Any: + def load(self, data_type: Optional[Type[Any]], obj: Any) -> Any: """Uses `safetensors` to load a PyTorch object. Args: - obj: The model to load onto. data_type: The type of the object to load. + obj: The model to load onto. Raises: ValueError: If the data_type is not a nn.Module or Dict[str, torch.Tensor] diff --git a/src/zenml/integrations/spark/materializers/spark_model_materializer.py b/src/zenml/integrations/spark/materializers/spark_model_materializer.py index 0911d9ade3e..72248aa5387 100644 --- a/src/zenml/integrations/spark/materializers/spark_model_materializer.py +++ b/src/zenml/integrations/spark/materializers/spark_model_materializer.py @@ -35,18 +35,18 @@ class SparkModelMaterializer(BaseMaterializer): ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL def load( - self, model_type: Type[Any] + self, data_type: Type[Any] ) -> Union[Transformer, Estimator, Model]: # type: ignore[type-arg] """Reads and returns a Spark ML model. Args: - model_type: The type of the model to read. + data_type: The type of the model to read. Returns: A loaded spark model. """ path = os.path.join(self.uri, DEFAULT_FILEPATH) - return model_type.load(path) # type: ignore[no-any-return] + return data_type.load(path) # type: ignore[no-any-return] def save( self, diff --git a/src/zenml/materializers/base_materializer.py b/src/zenml/materializers/base_materializer.py index 9c6b8a1ed83..50982b65ba2 100644 --- a/src/zenml/materializers/base_materializer.py +++ b/src/zenml/materializers/base_materializer.py @@ -126,11 +126,13 @@ def __init__(self, uri: str): # Public Interface # ================ - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: Type[Any], *args, **kwargs) -> Any: """Write logic here to load the data of an artifact. Args: data_type: What type the artifact data should be loaded as. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. Returns: The data of the artifact. 
diff --git a/tests/unit/test_general.py b/tests/unit/test_general.py index 7f4425a55e7..f7f3f502f9f 100644 --- a/tests/unit/test_general.py +++ b/tests/unit/test_general.py @@ -113,7 +113,7 @@ def _test_materializer( # Assert that materializer loads the data with the correct type if pass_step_output: - loaded_data = materializer.load(step_output, step_output_type) + loaded_data = materializer.load(step_output_type, step_output) else: loaded_data = materializer.load(step_output_type) if assert_data_type: From 22f633925b672c1a474993b79efd3e15bea94053 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Mon, 25 Mar 2024 19:39:16 +0530 Subject: [PATCH 05/19] lint fix --- src/zenml/materializers/base_materializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/materializers/base_materializer.py b/src/zenml/materializers/base_materializer.py index 50982b65ba2..aaf1b125df5 100644 --- a/src/zenml/materializers/base_materializer.py +++ b/src/zenml/materializers/base_materializer.py @@ -126,7 +126,7 @@ def __init__(self, uri: str): # Public Interface # ================ - def load(self, data_type: Type[Any], *args, **kwargs) -> Any: + def load(self, data_type: Type[Any], *args: Any, **kwargs: Any) -> Any: """Write logic here to load the data of an artifact. Args: From c7c13b42c3d6825c97f1a17164e13055d84ac9cc Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:00:49 +0530 Subject: [PATCH 06/19] Update src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../pytorch/materializers/base_pytorch_st_materializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index d317da08aa0..bb29c6879a9 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From ef1169c68c8d9df909a304d9532452629a290eba Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:00:59 +0530 Subject: [PATCH 07/19] Update src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../materializers/huggingface_pt_model_st_materializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index 50287fc6c2f..a264f20134b 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2021. All Rights Reserved. +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From ee8a40f359a432636467df35ba7f3dffa61a5a98 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:01:09 +0530 Subject: [PATCH 08/19] Update src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../pytorch/materializers/pytorch_module_st_materializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py index e5787e0bc5f..058bfed280d 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 8e81f18ce041fb13027a65ada88ca3f3b4d0533d Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:01:16 +0530 Subject: [PATCH 09/19] Update src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../materializers/pytorch_lightning_st_materializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py b/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py index b1f45be7597..73451a6133a 100644 --- a/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py +++ b/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_st_materializer.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2021. All Rights Reserved. +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From f141142076e622556421b907728eb42e0ae47abb Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:02:12 +0530 Subject: [PATCH 10/19] Update docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../advanced-guide/data-management/handle-custom-data-types.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index de85b209985..98e6675eff0 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -49,7 +49,7 @@ from torchvision.models import resnet50 from zenml.steps import step from zenml.pipelines import pipeline -from zenml.integrations.pytorch.materializers.pytorch_module_st_materializer import PyTorchModuleSTMaterializer +from zenml.integrations.pytorch.materializers import PyTorchModuleSTMaterializer # initialize materializer, pre-trained and base model materializer = PyTorchModuleSTMaterializer(uri="") From 059c6d0577656e256c555c10df6fdd225e672023 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:02:20 +0530 Subject: [PATCH 11/19] Update docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../advanced-guide/data-management/handle-custom-data-types.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index 98e6675eff0..e41fdf518b2 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -41,7 +41,7 @@ In addition to the standard integration-specific materializers that employ `Pick ### Here's an example showing how to use `PyTorchModuleSTMaterializer`: -Let's see how materialization using safetensors works with a basic example. Here we will use `resenet50` from pytorch to test the functionality: +Let's see how materialization using safetensors works with a basic example. 
Here we will use `resnet50` from pytorch to test the functionality: ``` python import logging From 0c161b71b01aafdd3bc6e61388533d1a1169f623 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 28 Mar 2024 10:18:09 +0530 Subject: [PATCH 12/19] remove alias --- pyproject.toml | 1 - .../pytorch/materializers/base_pytorch_st_materializer.py | 6 ------ 2 files changed, 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a1afd0ba5d4..296b5677b79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ pyyaml = ">=6.0.1" rich = { extras = ["jupyter"], version = ">=12.0.0" } sqlalchemy_utils = "0.38.3" sqlmodel = "0.0.8" -safetensors = "^0.4.2" importlib_metadata = { version = "<=7.0.0", python = "<3.10" } # Optional dependencies for the ZenServer diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index bb29c6879a9..68c3034b715 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -70,9 +70,3 @@ def save(self, obj: Any) -> None: save_model(obj, filename) except Exception as e: raise ValueError(f"Invalid data_type received: {e}") - - -# Alias for the BasePyTorchMaterializer class, allowing users that have already used -# the old name to continue using it without breaking their code. -# 'BasePyTorchMaterializer' or 'BasePyTorchMaterliazer' to refer to the same class. -BasePyTorchSTMaterliazer = BasePyTorchSTMaterializer From 8f9f06cf422430c3f164e599934bf262fff58b7e Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sat, 30 Mar 2024 12:46:56 +0530 Subject: [PATCH 13/19] remove passing object as argument --- .../handle-custom-data-types.md | 64 +++++++++---------- .../huggingface_pt_model_st_materializer.py | 44 ++++++------- .../base_pytorch_st_materializer.py | 37 +++++++---- src/zenml/materializers/base_materializer.py | 4 +- ...st_huggingface_pt_model_st_materializer.py | 1 - .../test_pytorch_module_st_materializer.py | 1 - tests/unit/test_general.py | 8 +-- 7 files changed, 80 insertions(+), 79 deletions(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index e41fdf518b2..43779fc5e70 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -43,45 +43,39 @@ In addition to the standard integration-specific materializers that employ `Pick Let's see how materialization using safetensors works with a basic example. Here we will use `resnet50` from pytorch to test the functionality: + +Create `pipeline` which includes steps to `save` and `load` model. + + ``` python import logging -from torchvision.models import resnet50 from zenml.steps import step from zenml.pipelines import pipeline from zenml.integrations.pytorch.materializers import PyTorchModuleSTMaterializer -# initialize materializer, pre-trained and base model -materializer = PyTorchModuleSTMaterializer(uri="") -pretrained_model = resnet50() -base_model = resnet50(weights=None) -``` -Create `pipeline` which includes steps to `save` and `load` model. 
+@step(enable_cache=False, output_materializers=PyTorchModuleSTMaterializer) +def my_first_step() -> Module: + """Step that saves a Pytorch model""" + from torchvision.models import resnet50 + pretrained_model = resnet50() -``` python -@step(enable_cache=False) -def my_first_step(): - """Step that saves the Pytorch model""" - - logging.info("Saving Model") - materializer.save(pretrained_model) + return pretrained_model @step(enable_cache=False) -def my_second_step(): - """Step that loads the model and returns it""" - - logging.info("Loading Model") - materializer.load(base_model) - logging.info(f"Model path: {materializer.FILENAME}") +def my_second_step(model: Module): + """Step that loads the model.""" + logging.info("Model loaded correctly.") @pipeline def first_pipeline(): - my_first_step() - my_second_step() + model = my_first_step() + my_second_step(model) + first_pipeline() ``` @@ -90,24 +84,26 @@ By running pipeline it will yield the following output: ```python Initiating a new run for the pipeline: first_pipeline. -Registered new version: (version 12). +Migrating the ZenML global configuration from version 0.55.5 to version 0.56.2... +Backing up the database before migration. +Database successfully backed up to the '/Users/.../Library/Application Support/zenml/database_backup/zenml-backup.db' backup file. If something goes wrong with the upgrade, ZenML will attempt to restore the database from this backup automatically. +Successfully cleaned up database dump file /Users/darshit/Library/Application Support/zenml/database_backup/zenml-backup.db. +Registered new version: (version 13). Executing a new run. Using user: default Using stack: default artifact_store: default orchestrator: default -Preventing execution of pipeline 'first_pipeline'. If this is not intended behavior, make sure to unset the environment variable 'ZENML_PREVENT_PIPELINE_EXECUTION'. -Caching disabled explicitly for step_1. -Step step_1 has started. -Saving Model -Step step_1 has finished in 0.173s. -Caching disabled explicitly for step_2. -Step step_2 has started. -Loading Model -Model path: entire_model.safetensors -Step step_2 has finished in 0.837s. -Pipeline run has finished in 1.754s. You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml up. +Preventing execution of pipeline 'first_pipeline'. If this is not intended behavior, make sure to unset the environment variable 'ZENML_PREVENT_PIPELINE_EXECUTION'. +Caching disabled explicitly for my_first_step. +Step my_first_step has started. +Step my_first_step has finished in 1.650s. +Caching disabled explicitly for my_second_step. +Step my_second_step has started. +Model loaded correctly. +Step my_second_step has finished in 0.827s. +Pipeline run has finished in 2.522s. 
``` ## Custom Materializers diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index a264f20134b..7eaffa82527 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -14,10 +14,10 @@ """Implementation of the Huggingface PyTorch model materializer using Safetensors.""" import os -from tempfile import TemporaryDirectory from typing import Any, ClassVar, Dict, Tuple, Type -from safetensors.torch import load_model, save_model +import torch +from safetensors.torch import load_file, save_file from transformers import ( # type: ignore [import-untyped] PreTrainedModel, ) @@ -25,10 +25,9 @@ from zenml.enums import ArtifactType from zenml.materializers.base_materializer import BaseMaterializer from zenml.metadata.metadata_types import DType, MetadataType -from zenml.utils import io_utils -DEFAULT_PT_MODEL_DIR = "hf_pt_model" DEFAULT_FILENAME = "model.safetensors" +DEFAULT_MODEL_FILENAME = "model_architecture.json" class HFPTModelSTMaterializer(BaseMaterializer): @@ -37,26 +36,26 @@ class HFPTModelSTMaterializer(BaseMaterializer): ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (PreTrainedModel,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load( - self, data_type: Type[PreTrainedModel], model: PreTrainedModel - ) -> PreTrainedModel: + def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: """Reads HFModel. Args: data_type: The type of the model to read. - model: The model to load onto. Returns: The model read from the specified dir. """ - temp_dir = TemporaryDirectory() - io_utils.copy_dir( - os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir.name - ) + # Load model architecture + model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME) + model_arch = torch.load(model_filename) + _model = model_arch["model"] - filepath = os.path.join(temp_dir.name, DEFAULT_FILENAME) - load_model(model, filepath) - return model + # Load model weight + obj_filename = os.path.join(self.uri, DEFAULT_FILENAME) + weights = load_file(obj_filename) + _model.load_state_dict(weights) + + return _model def save(self, model: PreTrainedModel) -> None: """Writes a Model to the specified dir. @@ -64,13 +63,14 @@ def save(self, model: PreTrainedModel) -> None: Args: model: The Torch Model to write. 
""" - temp_dir = TemporaryDirectory() - filepath = os.path.join(temp_dir.name, DEFAULT_FILENAME) - save_model(model, filepath) - io_utils.copy_dir( - temp_dir.name, - os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), - ) + # Save model weights + obj_filename = os.path.join(self.uri, DEFAULT_FILENAME) + save_file(model.state_dict(), obj_filename) + + # Save model architecture + model_arch = {"model": model} + model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME) + torch.save(model_arch, model_filename) def extract_metadata( self, model: PreTrainedModel diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index 68c3034b715..045c1f5f250 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -16,11 +16,13 @@ import os from typing import Any, ClassVar, Optional, Type -from safetensors.torch import load_file, load_model, save_file, save_model +import torch +from safetensors.torch import load_file, save_file from zenml.materializers.base_materializer import BaseMaterializer DEFAULT_FILENAME = "obj.safetensors" +DEFAULT_MODEL_FILENAME = "model_architecture.json" class BasePyTorchSTMaterializer(BaseMaterializer): @@ -29,12 +31,11 @@ class BasePyTorchSTMaterializer(BaseMaterializer): FILENAME: ClassVar[str] = DEFAULT_FILENAME SKIP_REGISTRATION: ClassVar[bool] = True - def load(self, data_type: Optional[Type[Any]], obj: Any) -> Any: + def load(self, data_type: Optional[Type[Any]]) -> Any: """Uses `safetensors` to load a PyTorch object. Args: data_type: The type of the object to load. - obj: The model to load onto. Raises: ValueError: If the data_type is not a nn.Module or Dict[str, torch.Tensor] @@ -42,13 +43,21 @@ def load(self, data_type: Optional[Type[Any]], obj: Any) -> Any: Returns: The loaded PyTorch object. 
""" - filename = os.path.join(self.uri, self.FILENAME) + obj_filename = os.path.join(self.uri, self.FILENAME) try: - if isinstance(obj, dict): - return load_file(filename) + if isinstance(data_type, dict): + return load_file(obj_filename) + + # Load model architecture + model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME) + model_arch = torch.load(model_filename) + _model = model_arch["model"] - load_model(obj, filename) - return obj + # Load model weights + weights = load_file(obj_filename) + _model.load_state_dict(weights) + + return _model except Exception as e: raise ValueError(f"Invalid data_type received: {e}") @@ -62,11 +71,17 @@ def save(self, obj: Any) -> None: ValueError: If the data_type is not a nn.Module or Dict[str, torch.Tensor] """ - filename = os.path.join(self.uri, self.FILENAME) + obj_filename = os.path.join(self.uri, self.FILENAME) try: if isinstance(obj, dict): - save_file(obj, filename) + save_file(obj, obj_filename) else: - save_model(obj, filename) + # Save model weights + save_file(obj.state_dict(), obj_filename) + + # Save model architecture + model_arch = {"model": obj} + model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME) + torch.save(model_arch, model_filename) except Exception as e: raise ValueError(f"Invalid data_type received: {e}") diff --git a/src/zenml/materializers/base_materializer.py b/src/zenml/materializers/base_materializer.py index aaf1b125df5..9c6b8a1ed83 100644 --- a/src/zenml/materializers/base_materializer.py +++ b/src/zenml/materializers/base_materializer.py @@ -126,13 +126,11 @@ def __init__(self, uri: str): # Public Interface # ================ - def load(self, data_type: Type[Any], *args: Any, **kwargs: Any) -> Any: + def load(self, data_type: Type[Any]) -> Any: """Write logic here to load the data of an artifact. Args: data_type: What type the artifact data should be loaded as. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. Returns: The data of the artifact. 
diff --git a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py index 89880cfcdf2..584d27c2667 100644 --- a/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py +++ b/tests/integration/integrations/huggingface/materializers/test_huggingface_pt_model_st_materializer.py @@ -26,7 +26,6 @@ def test_huggingface_pretrained_model_st_materializer(clean_client): step_output=RobertaModel(RobertaConfig()), materializer_class=HFPTModelSTMaterializer, expected_metadata_size=5, - pass_step_output=True, ) assert model.config.model_type == "roberta" diff --git a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py index 6e68b6b7890..406fecff733 100644 --- a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py +++ b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py @@ -26,7 +26,6 @@ def test_pytorch_module_st_materializer(clean_client): step_output=Linear(20, 20), materializer_class=PyTorchModuleSTMaterializer, expected_metadata_size=3, - pass_step_output=True, ) assert module.in_features == 20 diff --git a/tests/unit/test_general.py b/tests/unit/test_general.py index f7f3f502f9f..c72509512f5 100644 --- a/tests/unit/test_general.py +++ b/tests/unit/test_general.py @@ -36,7 +36,6 @@ def _test_materializer( validation_function: Optional[Callable[[str], Any]] = None, expected_metadata_size: Optional[int] = None, return_metadata: bool = False, - pass_step_output: bool = False, assert_data_exists: bool = True, assert_data_type: bool = True, assert_visualization_exists: bool = False, @@ -62,8 +61,6 @@ def _test_materializer( returned by `materializer.extract_full_metadata()` has this size. return_metadata: If `True`, we return the metadata dict returned by `materializer.extract_full_metadata()`. - pass_step_output: If `True`, we also pass step_output to safetensors - materializers. assert_data_exists: If `True`, we also assert that `materializer.save()` wrote something to disk. 
assert_data_type: If `True`, we also assert that `materializer.load()` @@ -112,10 +109,7 @@ def _test_materializer( assert isinstance(value, MetadataTypeTuple) # Assert that materializer loads the data with the correct type - if pass_step_output: - loaded_data = materializer.load(step_output_type, step_output) - else: - loaded_data = materializer.load(step_output_type) + loaded_data = materializer.load(step_output_type) if assert_data_type: assert isinstance(loaded_data, step_output_type) # correct type From 26f7188e7884b26f057848b0e9d70d2837eddbdd Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 3 Apr 2024 08:07:30 +0530 Subject: [PATCH 14/19] Update docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../data-management/handle-custom-data-types.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index 43779fc5e70..fb82efcf472 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -50,8 +50,9 @@ Create `pipeline` which includes steps to `save` and `load` model. ``` python import logging -from zenml.steps import step -from zenml.pipelines import pipeline +from torch.nn import Module + +from zenml import step, pipeline from zenml.integrations.pytorch.materializers import PyTorchModuleSTMaterializer From 111aab14e59db99edfa2df9a6fbaa879bb6c416e Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 3 Apr 2024 08:08:03 +0530 Subject: [PATCH 15/19] Update docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../data-management/handle-custom-data-types.md | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index fb82efcf472..9d67abf29b2 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -85,26 +85,21 @@ By running pipeline it will yield the following output: ```python Initiating a new run for the pipeline: first_pipeline. -Migrating the ZenML global configuration from version 0.55.5 to version 0.56.2... -Backing up the database before migration. -Database successfully backed up to the '/Users/.../Library/Application Support/zenml/database_backup/zenml-backup.db' backup file. If something goes wrong with the upgrade, ZenML will attempt to restore the database from this backup automatically. -Successfully cleaned up database dump file /Users/darshit/Library/Application Support/zenml/database_backup/zenml-backup.db. -Registered new version: (version 13). +Reusing registered pipeline version: (version: 3). Executing a new run. 
Using user: default Using stack: default - artifact_store: default orchestrator: default + artifact_store: default You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml up. -Preventing execution of pipeline 'first_pipeline'. If this is not intended behavior, make sure to unset the environment variable 'ZENML_PREVENT_PIPELINE_EXECUTION'. Caching disabled explicitly for my_first_step. Step my_first_step has started. -Step my_first_step has finished in 1.650s. +Step my_first_step has finished in 1.159s. Caching disabled explicitly for my_second_step. Step my_second_step has started. Model loaded correctly. -Step my_second_step has finished in 0.827s. -Pipeline run has finished in 2.522s. +Step my_second_step has finished in 0.061s. +Pipeline run has finished in 1.266s. ``` ## Custom Materializers From a0e93701974778950016387d70088d5e95b53922 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 3 Apr 2024 08:09:11 +0530 Subject: [PATCH 16/19] Update docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com> --- .../advanced-guide/data-management/handle-custom-data-types.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index 9d67abf29b2..e4189535858 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -78,7 +78,8 @@ def first_pipeline(): my_second_step(model) -first_pipeline() +if __name__ == "__main__": + first_pipeline() ``` By running pipeline it will yield the following output: From 8c71c36721161fbb418ed9e539ac08f32b70d97b Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 3 Apr 2024 11:07:05 +0530 Subject: [PATCH 17/19] numpy_st materializer, pytorch_lightning tests and few fixes --- .../handle-custom-data-types.md | 2 +- .../huggingface_pt_model_st_materializer.py | 9 +- .../base_pytorch_st_materializer.py | 9 +- .../pytorch_module_st_materializer.py | 9 +- src/zenml/materializers/__init__.py | 2 + src/zenml/materializers/numpy_materializer.py | 2 +- .../materializers/numpy_st_materializer.py | 258 ++++++++++++++++++ .../test_pytorch_dataloader_materializer.py | 2 +- .../test_pytorch_module_materializer.py | 2 +- .../test_pytorch_module_st_materializer.py | 2 +- .../test_pytorch_lightning_materializer.py | 33 +++ .../test_pytorch_lightning_st_materializer.py | 33 +++ .../test_numpy_st_materializer.py | 48 ++++ 13 files changed, 403 insertions(+), 8 deletions(-) create mode 100644 src/zenml/materializers/numpy_st_materializer.py create mode 100644 tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_materializer.py create mode 100644 tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_st_materializer.py create mode 100644 tests/unit/materializers/test_numpy_st_materializer.py diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index e4189535858..a47e99e537e 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ 
b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -12,7 +12,7 @@ A materializer dictates how a given artifact can be written to and retrieved fro ZenML already includes built-in materializers for many common data types. These are always enabled and are used in the background without requiring any user interaction / activation: -
-| Materializer | Handled Data Types | Storage Format |
-| --- | --- | --- |
-| BuiltInMaterializer | bool, float, int, str, None | .json |
-| BytesInMaterializer | bytes | .txt |
-| BuiltInContainerMaterializer | dict, list, set, tuple | Directory |
-| NumpyMaterializer | np.ndarray | .npy |
-| PandasMaterializer | pd.DataFrame, pd.Series | .csv (or .gzip if parquet is installed) |
-| PydanticMaterializer | pydantic.BaseModel | .json |
-| ServiceMaterializer | zenml.services.service.BaseService | .json |
-| StructuredStringMaterializer | zenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString | .csv / .html / .md (depending on type) |
+| Materializer | Handled Data Types | Storage Format |
+| --- | --- | --- |
+| BuiltInMaterializer | bool, float, int, str, None | .json |
+| BytesInMaterializer | bytes | .txt |
+| BuiltInContainerMaterializer | dict, list, set, tuple | Directory |
+| NumpyMaterializer | np.ndarray | .npy |
+| NumpySTMaterializer | np.ndarray (bool, int, float dtypes) | .safetensors |
+| PandasMaterializer | pd.DataFrame, pd.Series | .csv (or .gzip if parquet is installed) |
+| PydanticMaterializer | pydantic.BaseModel | .json |
+| ServiceMaterializer | zenml.services.service.BaseService | .json |
+| StructuredStringMaterializer | zenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString | .csv / .html / .md (depending on type) |
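
For illustration only (this snippet is not part of the patched docs): the new `NumpySTMaterializer` can be selected explicitly on a step in the same way as the other safetensors materializers used elsewhere in this guide.

```python
import numpy as np

from zenml import pipeline, step
from zenml.materializers import NumpySTMaterializer


@step(output_materializers=NumpySTMaterializer)
def produce_array() -> np.ndarray:
    """Step whose ndarray output is stored as a `.safetensors` file."""
    return np.array([1.0, 2.0, 3.0])


@step
def consume_array(arr: np.ndarray) -> None:
    print(arr.mean())


@pipeline
def numpy_st_pipeline():
    consume_array(produce_array())


if __name__ == "__main__":
    numpy_st_pipeline()
```
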
{% hint style="warning" %} ZenML provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.io/latest/core\_code\_docs/core-materializers/#zenml.materializers.cloudpickle\_materializer.CloudpickleMaterializer) that can handle any object by saving it with [cloudpickle](https://github.com/cloudpipe/cloudpickle). However, this is not production-ready because the resulting artifacts cannot be loaded when running with a different Python version. In such cases, you should consider building a [custom Materializer](handle-custom-data-types.md#custom-materializers) to save your objects in a more robust and efficient format. diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index 7eaffa82527..33573505c41 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -17,11 +17,18 @@ from typing import Any, ClassVar, Dict, Tuple, Type import torch -from safetensors.torch import load_file, save_file from transformers import ( # type: ignore [import-untyped] PreTrainedModel, ) +try: + from safetensors.torch import load_file, save_file +except ImportError: + raise ImportError( + "You are using `HFMaterializer` with safetensors.", + "You can install `safetensors` by running `pip install safetensors`.", + ) + from zenml.enums import ArtifactType from zenml.materializers.base_materializer import BaseMaterializer from zenml.metadata.metadata_types import DType, MetadataType diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py index 045c1f5f250..8e08b5e4bcd 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_st_materializer.py @@ -17,7 +17,14 @@ from typing import Any, ClassVar, Optional, Type import torch -from safetensors.torch import load_file, save_file + +try: + from safetensors.torch import load_file, save_file +except ImportError: + raise ImportError( + "You are using `PytorchMaterializer` with safetensors.", + "You can install `safetensors` by running `pip install safetensors`.", + ) from zenml.materializers.base_materializer import BaseMaterializer diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py index 058bfed280d..f9d76af0492 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_st_materializer.py @@ -16,7 +16,14 @@ import os from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type -from safetensors.torch import save_file +try: + from safetensors.torch import save_file +except ImportError: + raise ImportError( + "You are using `PytorchMaterializer` with safetensors.", + "You can install `safetensors` by running `pip install safetensors`.", + ) + from torch.nn import Module from zenml.enums import ArtifactType diff --git a/src/zenml/materializers/__init__.py b/src/zenml/materializers/__init__.py index 41b0288b95c..516865b0a4a 100644 --- a/src/zenml/materializers/__init__.py +++ b/src/zenml/materializers/__init__.py @@ -30,6 +30,7 @@ StructuredStringMaterializer, ) from 
zenml.materializers.numpy_materializer import NumpyMaterializer +from zenml.materializers.numpy_st_materializer import NumpySTMaterializer from zenml.materializers.pandas_materializer import PandasMaterializer from zenml.materializers.pydantic_materializer import PydanticMaterializer from zenml.materializers.service_materializer import ServiceMaterializer @@ -41,6 +42,7 @@ "CloudpickleMaterializer", "StructuredStringMaterializer", "NumpyMaterializer", + "NumpySTMaterializer", "PandasMaterializer", "PydanticMaterializer", "ServiceMaterializer", diff --git a/src/zenml/materializers/numpy_materializer.py b/src/zenml/materializers/numpy_materializer.py index e2eef1935a4..b853012944b 100644 --- a/src/zenml/materializers/numpy_materializer.py +++ b/src/zenml/materializers/numpy_materializer.py @@ -39,7 +39,7 @@ class NumpyMaterializer(BaseMaterializer): - """Materializer to read data to and from pandas.""" + """Materializer to read data to and from numpy.""" ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (np.ndarray,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA diff --git a/src/zenml/materializers/numpy_st_materializer.py b/src/zenml/materializers/numpy_st_materializer.py new file mode 100644 index 00000000000..8423c12b616 --- /dev/null +++ b/src/zenml/materializers/numpy_st_materializer.py @@ -0,0 +1,258 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the ZenML NumPy materializer using safetensors.""" + +import os +from collections import Counter +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type + +import numpy as np + +try: + from safetensors.numpy import load_file, save_file +except ImportError: + raise ImportError( + "You are using `NumpyMaterializer` with safetensors.", + "You can install `safetensors` by running `pip install safetensors`.", + ) + +from zenml.client import Client +from zenml.enums import ArtifactType, VisualizationType +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.metadata.metadata_types import DType, MetadataType + +if TYPE_CHECKING: + from numpy.typing import NDArray + +logger = get_logger(__name__) + + +NUMPY_FILENAME = "data.safetensors" + +DATA_FILENAME = "data.parquet" +SHAPE_FILENAME = "shape.json" +DATA_VAR = "data_var" + + +class NumpySTMaterializer(BaseMaterializer): + """Materializer to read data to and from numpy using safetensors.""" + + ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (np.ndarray,) + ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA + + def load(self, data_type: Type[Any]) -> "Any": + """Reads a numpy array from a `.safetensors` file. + + Args: + data_type: The type of the data to read. + + + Raises: + ImportError: If pyarrow is not installed. + + Returns: + The numpy array. 
+ """ + artifact_store = Client().active_stack.artifact_store + numpy_file = os.path.join(self.uri, NUMPY_FILENAME) + + if artifact_store.exists(numpy_file): + arr = load_file(numpy_file) + return arr["ndarray"] + elif artifact_store.exists(os.path.join(self.uri, DATA_FILENAME)): + logger.warning( + "A legacy artifact was found. " + "This artifact was created with an older version of " + "ZenML. You can still use it, but it will be " + "converted to the new format on the next materialization." + ) + try: + # Import old materializer dependencies + import pyarrow as pa # type: ignore + import pyarrow.parquet as pq # type: ignore + + from zenml.utils import yaml_utils + + # Read numpy array from parquet file + shape_dict = yaml_utils.read_json( + os.path.join(self.uri, SHAPE_FILENAME) + ) + shape_tuple = tuple(shape_dict.values()) + with artifact_store.open( + os.path.join(self.uri, DATA_FILENAME), "rb" + ) as f: + input_stream = pa.input_stream(f) + data = pq.read_table(input_stream) + vals = getattr(data.to_pandas(), DATA_VAR).values + return np.reshape(vals, shape_tuple) + except ImportError: + raise ImportError( + "You have an old version of a `NumpyMaterializer` ", + "data artifact stored in the artifact store ", + "as a `.parquet` file, which requires `pyarrow` for reading. ", + "You can install `pyarrow` by running `pip install pyarrow`.", + ) + + def save(self, arr: "NDArray[Any]") -> None: + """Writes a np.ndarray to the artifact store as a `.safetensors` file. + + Args: + arr: The numpy array to write. + """ + filename = os.path.join(self.uri, NUMPY_FILENAME) + print(arr) + obj = {"ndarray": arr} + save_file(obj, filename) + + def save_visualizations( + self, arr: "NDArray[Any]" + ) -> Dict[str, VisualizationType]: + """Saves visualizations for a numpy array. + + If the array is 1D, a histogram is saved. If the array is 2D or 3D with + 3 or 4 channels, an image is saved. + + Args: + arr: The numpy array to visualize. + + Returns: + A dictionary of visualization URIs and their types. + """ + if not np.issubdtype(arr.dtype, np.number): + return {} + + try: + # Save histogram for 1D arrays + if len(arr.shape) == 1: + histogram_path = os.path.join(self.uri, "histogram.png") + histogram_path = histogram_path.replace("\\", "/") + self._save_histogram(histogram_path, arr) + return {histogram_path: VisualizationType.IMAGE} + + # Save as image for 3D arrays with 3 or 4 channels + if len(arr.shape) == 3 and arr.shape[2] in [3, 4]: + image_path = os.path.join(self.uri, "image.png") + image_path = image_path.replace("\\", "/") + self._save_image(image_path, arr) + return {image_path: VisualizationType.IMAGE} + + except ImportError: + logger.info( + "Skipping visualization of numpy array because matplotlib " + "is not installed. To install matplotlib, run " + "`pip install matplotlib`." + ) + + return {} + + def _save_histogram(self, output_path: str, arr: "NDArray[Any]") -> None: + """Saves a histogram of a numpy array. + + Args: + output_path: The path to save the histogram to. + arr: The numpy array of which to save the histogram. + """ + import matplotlib.pyplot as plt + + artifact_store = Client().active_stack.artifact_store + plt.hist(arr) + with artifact_store.open(output_path, "wb") as f: + plt.savefig(f) + plt.close() + + def _save_image(self, output_path: str, arr: "NDArray[Any]") -> None: + """Saves a numpy array as an image. + + Args: + output_path: The path to save the image to. + arr: The numpy array to save. 
+ """ + from matplotlib.image import imsave + + artifact_store = Client().active_stack.artifact_store + with artifact_store.open(output_path, "wb") as f: + imsave(f, arr) + + def extract_metadata( + self, arr: "NDArray[Any]" + ) -> Dict[str, "MetadataType"]: + """Extract metadata from the given numpy array. + + Args: + arr: The numpy array to extract metadata from. + + Returns: + The extracted metadata as a dictionary. + """ + if np.issubdtype(arr.dtype, np.number): + return self._extract_numeric_metadata(arr) + elif np.issubdtype(arr.dtype, np.unicode_) or np.issubdtype( + arr.dtype, np.object_ + ): + return self._extract_text_metadata(arr) + else: + return {} + + def _extract_numeric_metadata( + self, arr: "NDArray[Any]" + ) -> Dict[str, "MetadataType"]: + """Extracts numeric metadata from a numpy array. + + Args: + arr: The numpy array to extract metadata from. + + Returns: + A dictionary of metadata. + """ + min_val = np.min(arr).item() + max_val = np.max(arr).item() + + numpy_metadata: Dict[str, "MetadataType"] = { + "shape": tuple(arr.shape), + "dtype": DType(arr.dtype.type), + "mean": np.mean(arr).item(), + "std": np.std(arr).item(), + "min": min_val, + "max": max_val, + } + return numpy_metadata + + def _extract_text_metadata( + self, arr: "NDArray[Any]" + ) -> Dict[str, "MetadataType"]: + """Extracts text metadata from a numpy array. + + Args: + arr: The numpy array to extract metadata from. + + Returns: + A dictionary of metadata. + """ + text = " ".join(arr) + words = text.split() + word_counts = Counter(words) + unique_words = len(word_counts) + total_words = len(words) + most_common_word, most_common_count = word_counts.most_common(1)[0] + + text_metadata: Dict[str, "MetadataType"] = { + "shape": tuple(arr.shape), + "dtype": DType(arr.dtype.type), + "unique_words": unique_words, + "total_words": total_words, + "most_common_word": most_common_word, + "most_common_count": most_common_count, + } + return text_metadata diff --git a/tests/integration/integrations/pytorch/materializers/test_pytorch_dataloader_materializer.py b/tests/integration/integrations/pytorch/materializers/test_pytorch_dataloader_materializer.py index 8cd0cbc0f29..1f4a94e056d 100644 --- a/tests/integration/integrations/pytorch/materializers/test_pytorch_dataloader_materializer.py +++ b/tests/integration/integrations/pytorch/materializers/test_pytorch_dataloader_materializer.py @@ -23,7 +23,7 @@ def test_pytorch_dataloader_materializer(clean_client): - """Tests whether the steps work for the Sklearn materializer.""" + """Tests whether the steps work for the PyTorch Dataloader materializer.""" dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5])) dataloader = _test_materializer( step_output=DataLoader(dataset, batch_size=3, num_workers=7), diff --git a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_materializer.py b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_materializer.py index 30103a1dce1..026ca6abada 100644 --- a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_materializer.py +++ b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_materializer.py @@ -21,7 +21,7 @@ def test_pytorch_module_materializer(clean_client): - """Tests whether the steps work for the Sklearn materializer.""" + """Tests whether the steps work for the Pytorch Module materializer.""" module = _test_materializer( step_output=Linear(20, 20), materializer_class=PyTorchModuleMaterializer, diff --git 
a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py index 406fecff733..753deadc74e 100644 --- a/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py +++ b/tests/integration/integrations/pytorch/materializers/test_pytorch_module_st_materializer.py @@ -21,7 +21,7 @@ def test_pytorch_module_st_materializer(clean_client): - """Tests whether the steps work for the Sklearn materializer using Safetensors.""" + """Tests whether the steps work for the PyTorch Module materializer using Safetensors.""" module = _test_materializer( step_output=Linear(20, 20), materializer_class=PyTorchModuleSTMaterializer, diff --git a/tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_materializer.py b/tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_materializer.py new file mode 100644 index 00000000000..800fca16713 --- /dev/null +++ b/tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_materializer.py @@ -0,0 +1,33 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from tests.unit.test_general import _test_materializer +from torch.nn import Linear + +from zenml.integrations.pytorch_lightning.materializers.pytorch_lightning_materializer import ( + PyTorchLightningMaterializer, +) + + +def test_pytorch_lightning_materializer(clean_client): + """Tests whether the steps work for the PyTorch Lightning materializer.""" + module = _test_materializer( + step_output=Linear(20, 20), + materializer_class=PyTorchLightningMaterializer, + expected_metadata_size=1, + ) + + assert module.in_features == 20 + assert module.out_features == 20 + assert module.bias is not None diff --git a/tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_st_materializer.py b/tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_st_materializer.py new file mode 100644 index 00000000000..2e1a41b024d --- /dev/null +++ b/tests/integration/integrations/pytorch_lightning/materializers/test_pytorch_lightning_st_materializer.py @@ -0,0 +1,33 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +from tests.unit.test_general import _test_materializer +from torch.nn import Linear + +from zenml.integrations.pytorch_lightning.materializers.pytorch_lightning_st_materializer import ( + PyTorchLightningSTMaterializer, +) + + +def test_pytorch_lightning_materializer(clean_client): + """Tests whether the steps work for the PyTorch Lightning materializer using safetensors.""" + module = _test_materializer( + step_output=Linear(20, 20), + materializer_class=PyTorchLightningSTMaterializer, + expected_metadata_size=1, + ) + + assert module.in_features == 20 + assert module.out_features == 20 + assert module.bias is not None diff --git a/tests/unit/materializers/test_numpy_st_materializer.py b/tests/unit/materializers/test_numpy_st_materializer.py new file mode 100644 index 00000000000..7d4b8c31038 --- /dev/null +++ b/tests/unit/materializers/test_numpy_st_materializer.py @@ -0,0 +1,48 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import numpy as np + +from tests.unit.test_general import _test_materializer +from zenml.materializers.numpy_st_materializer import NumpySTMaterializer +from zenml.metadata.metadata_types import ( + DType, +) + + +def test_numpy_st_materializer(): + """Test the numpy materializer using safetensors with metadata extraction.""" + + numeric_array = np.array([1, 2, 3, -1, 0.4]) + + # Test the materializer with metadata extraction + numeric_result, numeric_metadata = _test_materializer( + step_output_type=np.ndarray, + materializer_class=NumpySTMaterializer, + step_output=numeric_array, + return_metadata=True, + expected_metadata_size=7, + assert_visualization_exists=True, + ) + + # Assert that the materialized array is correct + assert np.array_equal(numeric_array, numeric_result) + + # Assert that the extracted metadata is correct for numeric array + assert numeric_metadata["shape"] == (5,) + assert numeric_metadata["dtype"] == DType(numeric_array.dtype.type) + assert numeric_metadata["mean"] == 1.08 + assert numeric_metadata["std"] == 1.3658696863171098 + assert numeric_metadata["min"] == -1 + assert numeric_metadata["max"] == 3 From 66f6da5cc4898ac87dce3bc0d663cd33d56ba90d Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 3 Apr 2024 11:28:57 +0530 Subject: [PATCH 18/19] use tempprary dir for HF --- .../huggingface_pt_model_st_materializer.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py index 33573505c41..42049af366e 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_st_materializer.py @@ -29,9 +29,12 @@ "You can install `safetensors` by running `pip install safetensors`.", ) +from tempfile import TemporaryDirectory + from zenml.enums import ArtifactType from 
zenml.materializers.base_materializer import BaseMaterializer from zenml.metadata.metadata_types import DType, MetadataType +from zenml.utils import io_utils DEFAULT_FILENAME = "model.safetensors" DEFAULT_MODEL_FILENAME = "model_architecture.json" @@ -52,13 +55,15 @@ def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: Returns: The model read from the specified dir. """ + temp_dir = TemporaryDirectory() + io_utils.copy_dir(self.uri, temp_dir.name) # Load model architecture - model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME) + model_filename = os.path.join(temp_dir.name, DEFAULT_MODEL_FILENAME) model_arch = torch.load(model_filename) _model = model_arch["model"] # Load model weight - obj_filename = os.path.join(self.uri, DEFAULT_FILENAME) + obj_filename = os.path.join(temp_dir.name, DEFAULT_FILENAME) weights = load_file(obj_filename) _model.load_state_dict(weights) @@ -70,15 +75,22 @@ def save(self, model: PreTrainedModel) -> None: Args: model: The Torch Model to write. """ + temp_dir = TemporaryDirectory() + # Save model weights - obj_filename = os.path.join(self.uri, DEFAULT_FILENAME) + obj_filename = os.path.join(temp_dir.name, DEFAULT_FILENAME) save_file(model.state_dict(), obj_filename) # Save model architecture model_arch = {"model": model} - model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME) + model_filename = os.path.join(temp_dir.name, DEFAULT_MODEL_FILENAME) torch.save(model_arch, model_filename) + io_utils.copy_dir( + temp_dir.name, + self.uri, + ) + def extract_metadata( self, model: PreTrainedModel ) -> Dict[str, "MetadataType"]: From e937f053fe9a21fb93e317dff793b1ac97360677 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Mon, 8 Apr 2024 15:08:12 +0530 Subject: [PATCH 19/19] poetry fix --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5b2b7c7e785..37604b71d71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,9 @@ azure-mgmt-resource = { version = ">=21.0.0", optional = true } # Optional dependencies for the S3 artifact store s3fs = { version = ">=2022.11.0", optional = true } +# Optional dependencies for materializers using safetensors +safetensors = { version = "^0.4.2", optional = true } + # Optional dependencies for the GCS artifact store gcsfs = { version = ">=2022.11.0", optional = true } @@ -187,6 +190,7 @@ server = [ templates = ["copier", "jinja2-time", "ruff", "pyyaml-include"] terraform = ["python-terraform"] secrets-aws = ["boto3"] +safetensors = ["safetensors"] secrets-gcp = ["google-cloud-secret-manager"] secrets-azure = ["azure-identity", "azure-keyvault-secrets"] secrets-hashicorp = ["hvac"]
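
Because `safetensors` is declared as an optional dependency above, the materializer modules guard their imports (patch 17). Below is a minimal sketch of that guard pattern; it assumes only that the `safetensors` extra defined in this `pyproject.toml` change ships as-is, and formatting the message as a single string keeps the raised error readable.

```python
# Optional-import guard for the safetensors-based materializers (sketch).
try:
    from safetensors.torch import load_file, save_file  # noqa: F401
except ImportError as e:
    raise ImportError(
        "The safetensors-based materializers require the `safetensors` "
        "package. Install it with `pip install safetensors` or via the "
        "`zenml[safetensors]` extra."
    ) from e
```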