Support for safetensors materializers #2539

Closed
wants to merge 30 commits
b84974d
Support for Safetensors
Dev-Khant Mar 18, 2024
3698210
add documentation and few fixes
Dev-Khant Mar 19, 2024
bd7dfa8
Merge branch 'develop' into support-for-safetensors
Dev-Khant Mar 20, 2024
dd1fac7
Merge branch 'develop' into support-for-safetensors
Dev-Khant Mar 21, 2024
c815a42
Merge branch 'develop' into support-for-safetensors
strickvl Mar 22, 2024
df476cb
lint and tests fix
Dev-Khant Mar 23, 2024
c610440
Merge branch 'develop' into support-for-safetensors
strickvl Mar 25, 2024
1a79019
Merge branch 'develop' into support-for-safetensors
Dev-Khant Mar 25, 2024
ab2a994
lint fix
Dev-Khant Mar 25, 2024
b99a465
Merge branch 'develop' into support-for-safetensors
strickvl Mar 25, 2024
22f6339
lint fix
Dev-Khant Mar 25, 2024
c7c13b4
Update src/zenml/integrations/pytorch/materializers/base_pytorch_st_m…
Dev-Khant Mar 28, 2024
ef1169c
Update src/zenml/integrations/huggingface/materializers/huggingface_p…
Dev-Khant Mar 28, 2024
ee8a40f
Update src/zenml/integrations/pytorch/materializers/pytorch_module_st…
Dev-Khant Mar 28, 2024
8e81f18
Update src/zenml/integrations/pytorch_lightning/materializers/pytorch…
Dev-Khant Mar 28, 2024
f141142
Update docs/book/user-guide/advanced-guide/data-management/handle-cus…
Dev-Khant Mar 28, 2024
059c6d0
Update docs/book/user-guide/advanced-guide/data-management/handle-cus…
Dev-Khant Mar 28, 2024
574a4d3
Merge branch 'develop' into support-for-safetensors
Dev-Khant Mar 28, 2024
0c161b7
remove alias
Dev-Khant Mar 28, 2024
6c74334
Merge branch 'develop' into support-for-safetensors
Dev-Khant Mar 30, 2024
8f9f06c
remove passing object as argument
Dev-Khant Mar 30, 2024
26f7188
Update docs/book/user-guide/advanced-guide/data-management/handle-cus…
Dev-Khant Apr 3, 2024
111aab1
Update docs/book/user-guide/advanced-guide/data-management/handle-cus…
Dev-Khant Apr 3, 2024
a0e9370
Update docs/book/user-guide/advanced-guide/data-management/handle-cus…
Dev-Khant Apr 3, 2024
8c71c36
numpy_st materializer, pytorch_lightning tests and few fixes
Dev-Khant Apr 3, 2024
8615da8
Merge branch 'develop' into support-for-safetensors
Dev-Khant Apr 3, 2024
66f6da5
use tempprary dir for HF
Dev-Khant Apr 3, 2024
083ad9f
Merge branch 'develop' into support-for-safetensors
avishniakov Apr 8, 2024
e937f05
poetry fix
Dev-Khant Apr 8, 2024
e423484
Merge branch 'develop' into support-for-safetensors
Dev-Khant Apr 18, 2024
@@ -12,7 +12,7 @@ A materializer dictates how a given artifact can be written to and retrieved fro

ZenML already includes built-in materializers for many common data types. These are always enabled and are used in the background without requiring any user interaction / activation:

<table data-full-width="true"><thead><tr><th>Materializer</th><th>Handled Data Types</th><th>Storage Format</th></tr></thead><tbody><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.built_in_materializer.BuiltInMaterializer">BuiltInMaterializer</a></td><td><code>bool</code>, <code>float</code>, <code>int</code>, <code>str</code>, <code>None</code></td><td><code>.json</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.built_in_materializer.BytesMaterializer">BytesInMaterializer</a></td><td><code>bytes</code></td><td><code>.txt</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.built_in_materializer.BuiltInContainerMaterializer">BuiltInContainerMaterializer</a></td><td><code>dict</code>, <code>list</code>, <code>set</code>, <code>tuple</code></td><td>Directory</td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.numpy_materializer.NumpyMaterializer">NumpyMaterializer</a></td><td><code>np.ndarray</code></td><td><code>.npy</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.pandas_materializer.PandasMaterializer">PandasMaterializer</a></td><td><code>pd.DataFrame</code>, <code>pd.Series</code></td><td><code>.csv</code> (or <code>.gzip</code> if <code>parquet</code> is installed)</td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.pydantic_materializer.PydanticMaterializer">PydanticMaterializer</a></td><td><code>pydantic.BaseModel</code></td><td><code>.json</code></td></tr><tr><td><a 
href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.service_materializer.ServiceMaterializer">ServiceMaterializer</a></td><td><code>zenml.services.service.BaseService</code></td><td><code>.json</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.structured_string_materializer.StructuredStringMaterializer">StructuredStringMaterializer</a></td><td><code>zenml.types.CSVString</code>, <code>zenml.types.HTMLString</code>, <code>zenml.types.MarkdownString</code></td><td><code>.csv</code> / <code>.html</code> / <code>.md</code> (depending on type)</td></tr></tbody></table>
<table data-full-width="true"><thead><tr><th>Materializer</th><th>Handled Data Types</th><th>Storage Format</th></tr></thead><tbody><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.built_in_materializer.BuiltInMaterializer">BuiltInMaterializer</a></td><td><code>bool</code>, <code>float</code>, <code>int</code>, <code>str</code>, <code>None</code></td><td><code>.json</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.built_in_materializer.BytesMaterializer">BytesInMaterializer</a></td><td><code>bytes</code></td><td><code>.txt</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.built_in_materializer.BuiltInContainerMaterializer">BuiltInContainerMaterializer</a></td><td><code>dict</code>, <code>list</code>, <code>set</code>, <code>tuple</code></td><td>Directory</td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.numpy_materializer.NumpyMaterializer">NumpyMaterializer</a></td><td><code>np.ndarray</code></td><td><code>.npy</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.numpy_st_materializer.NumpySTMaterializer">NumpySTMaterializer</a></td><td><code>bool</code>, <code>float</code>, <code>int</code></td><td><code>.safetensors</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.pandas_materializer.PandasMaterializer">PandasMaterializer</a></td><td><code>pd.DataFrame</code>, <code>pd.Series</code></td><td><code>.csv</code> (or <code>.gzip</code> if <code>parquet</code> is installed)</td></tr><tr><td><a 
href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.pydantic_materializer.PydanticMaterializer">PydanticMaterializer</a></td><td><code>pydantic.BaseModel</code></td><td><code>.json</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.service_materializer.ServiceMaterializer">ServiceMaterializer</a></td><td><code>zenml.services.service.BaseService</code></td><td><code>.json</code></td></tr><tr><td><a href="https://sdkdocs.zenml.io/latest/core_code_docs/core-materializers/#zenml.materializers.structured_string_materializer.StructuredStringMaterializer">StructuredStringMaterializer</a></td><td><code>zenml.types.CSVString</code>, <code>zenml.types.HTMLString</code>, <code>zenml.types.MarkdownString</code></td><td><code>.csv</code> / <code>.html</code> / <code>.md</code> (depending on type)</td></tr></tbody></table>

{% hint style="warning" %}
ZenML provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.io/latest/core\_code\_docs/core-materializers/#zenml.materializers.cloudpickle\_materializer.CloudpickleMaterializer) that can handle any object by saving it with [cloudpickle](https://github.com/cloudpipe/cloudpickle). However, this is not production-ready because the resulting artifacts cannot be loaded when running with a different Python version. In such cases, you should consider building a [custom Materializer](handle-custom-data-types.md#custom-materializers) to save your objects in a more robust and efficient format.
@@ -30,7 +30,80 @@ In addition to the built-in materializers, ZenML also provides several integrati
If you are running pipelines with a Docker-based [orchestrator](../../component-guide/orchestrators/orchestrators.md), you need to specify the corresponding integration as `required_integrations` in the `DockerSettings` of your pipeline in order to have the integration materializer available inside your Docker container. See the [pipeline configuration documentation](../pipelining-features/pipeline-settings.md) for more information.
{% endhint %}

## Custom materializers

## Safetensors Materializers

In addition to the standard integration-specific materializers, which rely on `pickle` for serialization, `Safetensors` offers a faster and more secure approach to model serialization. Further details on `Safetensors` can be found [here](https://huggingface.co/docs/safetensors/en/index).
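The on-disk layout behind that speed is simple. As a rough illustration (a toy reader/writer written against the published Safetensors file layout; real code should use the `safetensors` library itself), a file is an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and byte offsets, and a flat byte buffer:

```python
import json
import struct


def write_safetensors(path, tensors):
    """Toy writer: `tensors` maps name -> (dtype_str, shape, raw_bytes)."""
    header, buffers, offset = {}, [], 0
    for name, (dtype, shape, data) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": shape,
            "data_offsets": [offset, offset + len(data)],
        }
        buffers.append(data)
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte LE header size
        f.write(header_bytes)
        f.write(b"".join(buffers))


def read_safetensors(path):
    """Toy reader: returns name -> raw_bytes, sliced via the JSON header."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        buf = f.read()
    return {
        name: buf[info["data_offsets"][0]:info["data_offsets"][1]]
        for name, info in header.items()
    }
```

Because the header records exact byte offsets, a real reader can memory-map the file and load individual tensors lazily, which is part of why Safetensors loads faster than unpickling.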

<table data-full-width="true"><thead><tr><th width="199.5">Integration</th><th width="271">Materializer</th><th width="390">Handled Data Types</th><th width="200">Storage Format</th></tr></thead><tbody><tr><td>huggingface</td><td><a href="https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-huggingface/#zenml.integrations.huggingface.materializers.huggingface_pt_model_st_materializer.HFPTModelSTMaterializer">HFPTModelSTMaterializer</a></td><td><code>transformers.PreTrainedModel</code></td><td><code>.safetensors</code></td></tr><tr><td>pytorch</td><td><a href="https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-pytorch/#zenml.integrations.pytorch.materializers.pytorch_module_st_materializer.PyTorchModuleSTMaterializer">PyTorchModuleSTMaterializer</a></td><td><code>torch.Module</code></td><td><code>.safetensors</code></td></tr><tr><td>pytorch_lightning</td><td><a href="https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-pytorch_lightning/#zenml.integrations.pytorch_lightning.materializers.pytorch_lightning_st_materializer.PyTorchLightningSTMaterializer">PyTorchLightningSTMaterializer</a></td><td><code>torch.Module</code></td><td><code>.safetensors</code></td></tr></tbody></table>

### Example: using `PyTorchModuleSTMaterializer`

Let's see how materialization with Safetensors works in a basic example, using `resnet50` from `torchvision`. Create a pipeline with steps that save and load the model:


```python
import logging

from torch.nn import Module

from zenml import step, pipeline
from zenml.integrations.pytorch.materializers import PyTorchModuleSTMaterializer


@step(enable_cache=False, output_materializers=PyTorchModuleSTMaterializer)
def my_first_step() -> Module:
    """Step that saves a PyTorch model."""
    from torchvision.models import resnet50

    pretrained_model = resnet50()

    return pretrained_model


@step(enable_cache=False)
def my_second_step(model: Module) -> None:
    """Step that loads the model."""
    logging.info("Model loaded correctly.")


@pipeline
def first_pipeline():
    model = my_first_step()
    my_second_step(model)


if __name__ == "__main__":
    first_pipeline()
```

Running the pipeline yields output like the following:

```
Initiating a new run for the pipeline: first_pipeline.
Reusing registered pipeline version: (version: 3).
Executing a new run.
Using user: default
Using stack: default
orchestrator: default
artifact_store: default
You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml up.
Caching disabled explicitly for my_first_step.
Step my_first_step has started.
Step my_first_step has finished in 1.159s.
Caching disabled explicitly for my_second_step.
Step my_second_step has started.
Model loaded correctly.
Step my_second_step has finished in 0.061s.
Pipeline run has finished in 1.266s.
```

## Custom Materializers

### Configuring a step/pipeline to use a custom materializer

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -116,6 +116,9 @@ azure-mgmt-resource = { version = ">=21.0.0", optional = true }
# Optional dependencies for the S3 artifact store
s3fs = { version = ">=2022.11.0", optional = true }

# Optional dependencies for materializers using safetensors
safetensors = { version = "^0.4.2", optional = true }

# Optional dependencies for the GCS artifact store
gcsfs = { version = ">=2022.11.0", optional = true }

@@ -186,6 +189,7 @@ server = [
templates = ["copier", "jinja2-time", "ruff", "pyyaml-include"]
terraform = ["python-terraform"]
secrets-aws = ["boto3"]
safetensors = ["safetensors"]
secrets-gcp = ["google-cloud-secret-manager"]
secrets-azure = ["azure-identity", "azure-keyvault-secrets"]
secrets-hashicorp = ["hvac"]
@@ -451,6 +455,7 @@ module = [
"fastapi_utils.*",
"sqlalchemy_utils.*",
"sky.*",
"safetensors.*",
"copier.*",
"datasets.*",
"pyngrok.*",
3 changes: 3 additions & 0 deletions src/zenml/integrations/huggingface/materializers/__init__.py
@@ -25,3 +25,6 @@
from zenml.integrations.huggingface.materializers.huggingface_tokenizer_materializer import (
    HFTokenizerMaterializer,
)
from zenml.integrations.huggingface.materializers.huggingface_pt_model_st_materializer import (
    HFPTModelSTMaterializer,
)
@@ -0,0 +1,112 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Implementation of the Huggingface PyTorch model materializer using Safetensors."""

import os
from typing import Any, ClassVar, Dict, Tuple, Type

import torch
from transformers import (  # type: ignore [import-untyped]
    PreTrainedModel,
)

try:
    from safetensors.torch import load_file, save_file
except ImportError:
    raise ImportError(
        "`HFPTModelSTMaterializer` requires `safetensors`. "
        "Install it by running `pip install safetensors`."
    )

from tempfile import TemporaryDirectory

from zenml.enums import ArtifactType
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.metadata.metadata_types import DType, MetadataType
from zenml.utils import io_utils

DEFAULT_FILENAME = "model.safetensors"
DEFAULT_MODEL_FILENAME = "model_architecture.json"


class HFPTModelSTMaterializer(BaseMaterializer):
    """Materializer to read and write Hugging Face `PreTrainedModel` objects using Safetensors."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (PreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL

    def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel:
        """Reads an HF model.

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        temp_dir = TemporaryDirectory()
        io_utils.copy_dir(self.uri, temp_dir.name)

        # Load the model architecture
        model_filename = os.path.join(temp_dir.name, DEFAULT_MODEL_FILENAME)
        model_arch = torch.load(model_filename)
        _model = model_arch["model"]

        # Load the model weights
        obj_filename = os.path.join(temp_dir.name, DEFAULT_FILENAME)
        weights = load_file(obj_filename)
        _model.load_state_dict(weights)

        return _model

    def save(self, model: PreTrainedModel) -> None:
        """Writes a model to the specified dir.

        Args:
            model: The Torch model to write.
        """
        temp_dir = TemporaryDirectory()

        # Save the model weights
        obj_filename = os.path.join(temp_dir.name, DEFAULT_FILENAME)
        save_file(model.state_dict(), obj_filename)

        # Save the model architecture
        model_arch = {"model": model}
        model_filename = os.path.join(temp_dir.name, DEFAULT_MODEL_FILENAME)
        torch.save(model_arch, model_filename)
bcdurak (Contributor):
I think you are on the right path here; however, there is an issue:

This is a pattern I see in all of the new materializers: AFAIK, if you do `torch.save(...)`, it saves not only the model architecture but also the weights.

You can see this in play in the example mentioned above. If you check your artifacts in your local artifact store manually, both `entire_model.safetensors` and `model_architecture.json` are present, and both are roughly 100 MB. Basically, it is saving the model twice in two different ways. We need to modify the `torch.save` and `torch.load` calls to handle only the architecture, without the weights.

Dev-Khant (Author):

@bcdurak I could not find a method in PyTorch to store just the architecture; there does not seem to be one. What would you recommend here?

bcdurak (Contributor):

This is a tough question, but the current approach is really inefficient.

It feels like we need to go back to the version where you used the `save_model` and `load_model` calls, and we somehow need to figure out how to save the model type in the `save` method. If I can think of anything, I will share it here.

Dev-Khant (Author):

Sure @bcdurak. Let me know when I should switch back to the previous method.
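One stdlib sketch of the direction discussed here (the helper names are illustrative, not part of this PR): persist only an import path plus constructor kwargs as the "architecture", so the weights live solely in the `.safetensors` file and the full module is never serialized twice:

```python
import importlib
import json


def dump_architecture(obj, init_kwargs):
    """Record how to rebuild an object: import path plus constructor kwargs."""
    return json.dumps({
        "module": type(obj).__module__,
        "name": type(obj).__qualname__,  # assumes a top-level class
        "kwargs": init_kwargs,
    })


def load_architecture(spec):
    """Re-import the class and instantiate it; weights are loaded separately."""
    info = json.loads(spec)
    cls = getattr(importlib.import_module(info["module"]), info["name"])
    return cls(**info["kwargs"])
```

For `transformers` models specifically, `save_pretrained`/`from_pretrained` already implement a richer version of this split (a config JSON plus weight files).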


        io_utils.copy_dir(
            temp_dir.name,
            self.uri,
        )

    def extract_metadata(
        self, model: PreTrainedModel
    ) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given `PreTrainedModel` object.

        Args:
            model: The `PreTrainedModel` object to extract metadata from.

        Returns:
            The extracted metadata as a dictionary.
        """
        from zenml.integrations.pytorch.utils import count_module_params

        module_param_metadata = count_module_params(model)
        return {
            **module_param_metadata,
            "dtype": DType(str(model.dtype)),
            "device": str(model.device),
        }
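A side note on the `TemporaryDirectory()` usage in this materializer: created without a `with` block, the directory is only removed when the object is finalized. The context-manager form guarantees cleanup as soon as the work is done, as this stdlib sketch shows:

```python
import os
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.safetensors")
    with open(path, "wb") as f:
        f.write(b"\x00" * 8)  # stand-in for a real serialized model
    assert os.path.exists(path)

# The directory and everything in it are removed when the block exits.
assert not os.path.exists(tmp)
```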
4 changes: 4 additions & 0 deletions src/zenml/integrations/pytorch/materializers/__init__.py
@@ -19,3 +19,7 @@
from zenml.integrations.pytorch.materializers.pytorch_module_materializer import (  # noqa
    PyTorchModuleMaterializer,
)

from zenml.integrations.pytorch.materializers.pytorch_module_st_materializer import (  # noqa
    PyTorchModuleSTMaterializer,
)
@@ -0,0 +1,94 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Implementation of the base PyTorch materializer using Safetensors."""

import os
from typing import Any, ClassVar, Optional, Type

import torch

try:
    from safetensors.torch import load_file, save_file
except ImportError:
    raise ImportError(
        "`PyTorchModuleSTMaterializer` requires `safetensors`. "
        "Install it by running `pip install safetensors`."
    )

from zenml.materializers.base_materializer import BaseMaterializer

DEFAULT_FILENAME = "obj.safetensors"
DEFAULT_MODEL_FILENAME = "model_architecture.json"


class BasePyTorchSTMaterializer(BaseMaterializer):
    """Base class for PyTorch materializers using Safetensors."""

    FILENAME: ClassVar[str] = DEFAULT_FILENAME
    SKIP_REGISTRATION: ClassVar[bool] = True

    def load(self, data_type: Optional[Type[Any]]) -> Any:
        """Uses `safetensors` to load a PyTorch object.

        Args:
            data_type: The type of the object to load.

        Raises:
            ValueError: If the object is neither an `nn.Module` nor a
                `Dict[str, torch.Tensor]`.

        Returns:
            The loaded PyTorch object.
        """
        obj_filename = os.path.join(self.uri, self.FILENAME)
        try:
            # `data_type` is a type object (possibly a `Dict[...]` generic
            # alias), so unwrap any alias and compare with `issubclass`;
            # `isinstance(data_type, dict)` would always be False here.
            origin = getattr(data_type, "__origin__", data_type)
            if isinstance(origin, type) and issubclass(origin, dict):
                return load_file(obj_filename)

            # Load the model architecture
            model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME)
            model_arch = torch.load(model_filename)
            _model = model_arch["model"]

            # Load the model weights
            weights = load_file(obj_filename)
            _model.load_state_dict(weights)

            return _model
        except Exception as e:
            raise ValueError(f"Invalid data_type received: {e}")

    def save(self, obj: Any) -> None:
        """Uses `safetensors` to save a PyTorch object.

        Args:
            obj: The PyTorch object to save.

        Raises:
            ValueError: If the object is neither an `nn.Module` nor a
                `Dict[str, torch.Tensor]`.
        """
        obj_filename = os.path.join(self.uri, self.FILENAME)
        try:
            if isinstance(obj, dict):
                save_file(obj, obj_filename)
            else:
                # Save the model weights
                save_file(obj.state_dict(), obj_filename)

                # Save the model architecture
                model_arch = {"model": obj}
                model_filename = os.path.join(self.uri, DEFAULT_MODEL_FILENAME)
                torch.save(model_arch, model_filename)
        except Exception as e:
            raise ValueError(f"Invalid data_type received: {e}")
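A note on the type dispatch in `load` above: `data_type` arrives as a type object, not an instance, and for a type object `isinstance(data_type, dict)` is always `False`. A minimal illustration of the distinction (the helper name is illustrative):

```python
def requests_dict(data_type):
    """True when the requested type is dict or a Dict[...] generic alias."""
    origin = getattr(data_type, "__origin__", data_type)  # unwrap Dict[...]
    return isinstance(origin, type) and issubclass(origin, dict)


# A class object is not an *instance* of dict...
assert not isinstance(dict, dict)
# ...so dispatching on a requested type needs issubclass, not isinstance.
assert issubclass(dict, dict)

assert requests_dict(dict)
assert not requests_dict(list)
assert not requests_dict({})  # guards against being handed an instance
```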