diff --git a/dev-requirements.txt b/dev-requirements.txt index bf2de90..f87e528 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,6 @@ aiobotocore boto3 +fsspec google-cloud-storage google-resumable-media numpy diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py new file mode 100644 index 0000000..92c31a5 --- /dev/null +++ b/tests/test_fsspec_storage_plugin.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import io +import logging +import os +import uuid + +import pytest +import torch + +import torchsnapshot +from torchsnapshot.io_types import ReadIO, WriteIO +from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin + +logger: logging.Logger = logging.getLogger(__name__) + +_TEST_BUCKET = "torchsnapshot-test" +_TENSOR_SZ = int(1_000_000 / 4) + + +@pytest.mark.s3_integration_test +@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") +@pytest.mark.usefixtures("s3_health_check") +def test_fsspec_s3_read_write_via_snapshot() -> None: + path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}" + logger.info(path) + + tensor = torch.rand((_TENSOR_SZ,)) + app_state = {"state": torchsnapshot.StateDict(tensor=tensor)} + snapshot = torchsnapshot.Snapshot.take(path=path, app_state=app_state) + + app_state["state"]["tensor"] = torch.rand((_TENSOR_SZ,)) + assert not torch.allclose(tensor, app_state["state"]["tensor"]) + + snapshot.restore(app_state) + assert torch.allclose(tensor, app_state["state"]["tensor"]) + + +@pytest.mark.s3_integration_test +@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") +@pytest.mark.usefixtures("s3_health_check") +@pytest.mark.asyncio +async def test_fsspec_s3_write_read_delete() -> None: + path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}" + logger.info(path) + plugin = FSSpecStoragePlugin(root=path) + + tensor = torch.rand((_TENSOR_SZ,)) + buf = io.BytesIO() + torch.save(tensor, buf) + write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) + + await plugin.write(write_io=write_io) + + read_io = ReadIO(path="tensor") + await plugin.read(read_io=read_io) + loaded = torch.load(read_io.buf) + assert torch.allclose(tensor, loaded) + + await plugin.delete(path="tensor") + await plugin.close() diff --git a/torchsnapshot/io_preparers/__init__.py b/torchsnapshot/io_preparers/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/torchsnapshot/io_preparers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index bb83715..b599d66 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -38,7 +38,7 @@ def url_to_storage_plugin( protocol, path = "fs", url_path if storage_options is None: - storage_options = {} + storage_options = dict() # Built-in storage plugins if protocol == "fs": @@ -49,6 +49,10 @@ def url_to_storage_plugin( from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin return GCSStoragePlugin(root=path, storage_options=storage_options) + elif protocol.startswith("fsspec-"): + from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin + + return FSSpecStoragePlugin(root=url_path, storage_options=storage_options) # Registered storage plugins eps = entry_points(group="storage_plugins") diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py new file mode 100644 index 0000000..e0fbab6 --- /dev/null +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import asyncio +import io +import os +from typing import Any, Dict, Optional + +import fsspec + +from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO + +__all__ = ["FSSpecStoragePlugin"] + + +class FSSpecStoragePlugin(StoragePlugin): + def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None: + root_items = root.split("://") + if len(root_items) != 2: + raise ValueError("only protocol://path is supported by fsspec plugin") + protocol, self.root = root_items + if not protocol.startswith("fsspec-"): + raise ValueError( + f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" + ) + self._protocol = protocol[len("fsspec-") :] + self._fs = None + self._session = None + self._lock = asyncio.Lock() + self._storage_options = storage_options + + async def _init_session(self) -> None: + async with self._lock: + if self._session is None: + self._fs = fsspec.filesystem( + protocol=self._protocol, **self._storage_options + ) + self._session = await self._fs.set_session(refresh=True) + + async def write(self, write_io: WriteIO) -> None: + await self._init_session() + path = os.path.join(self.root, write_io.path) + splits = path.split("/") + for i in range(len(splits)): + dir_path = "/".join(splits[:i]) + if dir_path and not await self._fs._exists(dir_path): + try: + await self._fs._mkdir(dir_path) + except AttributeError: + break + await self._fs._pipe_file(path, bytes(write_io.buf)) + + async def read(self, read_io: ReadIO) -> None: + await self._init_session() + path = os.path.join(self.root, read_io.path) + data = await self._fs._cat_file(path) + if read_io.byte_range is None: + read_io.buf = io.BytesIO(data) + else: + start, end = read_io.byte_range + read_io.buf = io.BytesIO(data[start:end]) + + async def delete(self, path: str) -> None: + await self._init_session() + path = os.path.join(self.root, path) + await self._fs._rm(path, recursive=True) + + async def close(self) -> None: + async with self._lock: + if self._session is not None: + await self._session.close() + self._session = None + self._fs = None diff --git a/torchsnapshot/storage_plugins/gcs.py b/torchsnapshot/storage_plugins/gcs.py index 8fe03d5..dbc5281 100644 --- a/torchsnapshot/storage_plugins/gcs.py +++ b/torchsnapshot/storage_plugins/gcs.py @@ -24,7 +24,6 @@ import urllib3.exceptions from google.auth import default # @manual from google.auth.transport.requests import AuthorizedSession # @manual - from google.resumable_media import common # @manual from google.resumable_media.requests import ChunkedDownload, ResumableUpload # @manual @@ -40,7 +39,6 @@ _DEFAULT_DEADLINE_SEC: int = 180 _DEFAULT_CHUNK_SIZE_BYTE: int = 100 * 1024 * 1024 - T = TypeVar("T")