From 1134492fe154543097f46c44ce32785d386e24fe Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 20 Oct 2022 13:40:03 +0800 Subject: [PATCH 01/26] Add `storage_kwargs` to Snapshot.take --- torchsnapshot/snapshot.py | 6 +++++- torchsnapshot/storage_plugin.py | 16 ++++++++-------- torchsnapshot/storage_plugins/s3.py | 3 ++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py index 8244205..d6477e7 100644 --- a/torchsnapshot/snapshot.py +++ b/torchsnapshot/snapshot.py @@ -179,6 +179,7 @@ def take( app_state: AppState, pg: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, + storage_kwargs: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[ Callable[[str, torch.Tensor, bool], torch.Tensor] ] = None, @@ -196,6 +197,7 @@ def take( replicated: A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored. + storage_kwargs: The StoragePlugin's extra keyword arguments. See each StoragePlugin for doc. Returns: The newly taken snapshot. @@ -212,8 +214,10 @@ def take( app_state=app_state, replicated=replicated or [], ) + if storage_kwargs is None: + storage_kwargs = dict() storage = url_to_storage_plugin_in_event_loop( - url_path=path, event_loop=event_loop + url_path=path, event_loop=event_loop, **storage_kwargs, ) pending_io_work, metadata = cls._take_impl( path=path, diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index f567dc5..a9666aa 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -14,7 +14,7 @@ from .storage_plugins.s3 import S3StoragePlugin -def url_to_storage_plugin(url_path: str) -> StoragePlugin: +def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: """ Initialize storage plugin from url path. @@ -34,13 +34,13 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: # Built-in storage plugins if protocol == "fs": - return FSStoragePlugin(root=path) + return FSStoragePlugin(root=path, **kwargs) elif protocol == "s3": - return S3StoragePlugin(root=path) + return S3StoragePlugin(root=path, **kwargs) elif protocol == "gs": from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin - return GCSStoragePlugin(root=path) + return GCSStoragePlugin(root=path, **kwargs) # Registered storage plugins eps = entry_points(group="storage_plugins") @@ -60,9 +60,9 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: def url_to_storage_plugin_in_event_loop( - url_path: str, event_loop: asyncio.AbstractEventLoop + url_path: str, event_loop: asyncio.AbstractEventLoop, **kwargs ) -> StoragePlugin: - async def _url_to_storage_plugin(url_path: str) -> StoragePlugin: - return url_to_storage_plugin(url_path=url_path) + async def _url_to_storage_plugin(url_path: str, **_kwargs) -> StoragePlugin: + return url_to_storage_plugin(url_path=url_path, **_kwargs) - return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path)) + return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path, **kwargs)) diff --git a/torchsnapshot/storage_plugins/s3.py b/torchsnapshot/storage_plugins/s3.py index 8ed07e4..17472f1 100644 --- a/torchsnapshot/storage_plugins/s3.py +++ b/torchsnapshot/storage_plugins/s3.py @@ -13,7 +13,7 @@ class S3StoragePlugin(StoragePlugin): - def __init__(self, root: str) -> None: + def __init__(self, root: str, **kwargs) -> None: try: from aiobotocore.session import get_session # @manual except ImportError: @@ -30,6 +30,7 @@ def __init__(self, root: str) -> None: self.bucket: str = components[0] self.root: str = "/".join(components[1:]) # pyre-ignore + # TODO: read AWS tokens from **kwargs? self.session = get_session() async def write(self, write_io: WriteIO) -> None: From 88dd7a6b188f7feb2486afb25ae99b77bc084da1 Mon Sep 17 00:00:00 2001 From: chengcshi Date: Thu, 20 Oct 2022 16:20:40 +0800 Subject: [PATCH 02/26] add fsspec plugin --- requirements.txt | 1 + tests/test_fsspec_storage_plugin.py | 48 +++++++++++++++++++++++++ torchsnapshot/storage_plugins/fsspec.py | 43 ++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 tests/test_fsspec_storage_plugin.py create mode 100644 torchsnapshot/storage_plugins/fsspec.py diff --git a/requirements.txt b/requirements.txt index 4ca65ce..62b3927 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ nest_asyncio psutil torch typing-extensions +fsspec diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py new file mode 100644 index 0000000..04871f4 --- /dev/null +++ b/tests/test_fsspec_storage_plugin.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import io +import logging +import os +import tempfile +import unittest + +import torch + +from torchsnapshot.io_types import ReadIO, WriteIO +from torchsnapshot.storage_plugins.fsspec import FSSpecPlugin +from torchsnapshot.test_utils import async_test + +logger: logging.Logger = logging.getLogger(__name__) + +_TENSOR_SZ = int(100_000_000 / 4) + + +class FSSpecStoragePluginTest(unittest.TestCase): + @async_test + async def test_write_read_delete(self) -> None: + with tempfile.TemporaryDirectory() as path: + logger.info(path) + plugin = FSSpecPlugin(root=path, protocol="file") + + tensor = torch.rand((_TENSOR_SZ,)) + tensor_path = os.path.join(path, "tensor") + buf = io.BytesIO() + torch.save(tensor, buf) + write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) + + await plugin.write(write_io=write_io) + self.assertTrue(os.path.exists(tensor_path)) + + read_io = ReadIO(path="tensor") + await plugin.read(read_io=read_io) + loaded = torch.load(read_io.buf) + self.assertTrue(torch.allclose(tensor, loaded)) + + await plugin.delete(path="tensor") + self.assertFalse(os.path.exists(tensor_path)) + await plugin.close() diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py new file mode 100644 index 0000000..4e02ce9 --- /dev/null +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import io +import os + +import fsspec + +from torchsnapshot.io_types import StoragePlugin, ReadIO, WriteIO + + +class FSSpecPlugin(StoragePlugin): + def __init__(self, root: str, protocol: str, **storage_options) -> None: + self.root = root + self.fs = fsspec.filesystem(protocol, **storage_options) + + async def write(self, write_io: WriteIO) -> None: + path = os.path.join(self.root, write_io.path) + with self.fs.open(path, 'wb+') as f: + f.write(write_io.buf) + + async def read(self, read_io: ReadIO) -> None: + path = os.path.join(self.root, read_io.path) + byte_range = read_io.byte_range + + with self.fs.open(path, 'rb') as f: + if byte_range is None: + read_io.buf = io.BytesIO(f.read()) + else: + offset = byte_range[0] + size = byte_range[1] - byte_range[0] + await f.seek(offset) + read_io.buf = io.BytesIO(f.read(size)) + + async def delete(self, path: str) -> None: + path = os.path.join(self.root, path) + self.fs.delete(path) + + async def close(self) -> None: + pass From c28f15009e9941bdc638090a93f9211b938b994b Mon Sep 17 00:00:00 2001 From: chengcshi Date: Thu, 20 Oct 2022 17:35:34 +0800 Subject: [PATCH 03/26] tiny polish --- tests/test_fsspec_storage_plugin.py | 44 +++++++++++-------------- torchsnapshot/storage_plugins/fsspec.py | 20 +++++++---- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index 04871f4..1a1e156 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -4,12 +4,10 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import io import logging -import os -import tempfile import unittest +import uuid import torch @@ -19,30 +17,28 @@ logger: logging.Logger = logging.getLogger(__name__) +_TEST_BUCKET = "torchsnapshot-test" _TENSOR_SZ = int(100_000_000 / 4) class FSSpecStoragePluginTest(unittest.TestCase): @async_test async def test_write_read_delete(self) -> None: - with tempfile.TemporaryDirectory() as path: - logger.info(path) - plugin = FSSpecPlugin(root=path, protocol="file") - - tensor = torch.rand((_TENSOR_SZ,)) - tensor_path = os.path.join(path, "tensor") - buf = io.BytesIO() - torch.save(tensor, buf) - write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) - - await plugin.write(write_io=write_io) - self.assertTrue(os.path.exists(tensor_path)) - - read_io = ReadIO(path="tensor") - await plugin.read(read_io=read_io) - loaded = torch.load(read_io.buf) - self.assertTrue(torch.allclose(tensor, loaded)) - - await plugin.delete(path="tensor") - self.assertFalse(os.path.exists(tensor_path)) - await plugin.close() + path = f"{_TEST_BUCKET}/{uuid.uuid4()}" + logger.info(path) + plugin = FSSpecPlugin(root=path, protocol="s3") + + tensor = torch.rand((_TENSOR_SZ,)) + buf = io.BytesIO() + torch.save(tensor, buf) + write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) + + await plugin.write(write_io=write_io) + + read_io = ReadIO(path="tensor") + await plugin.read(read_io=read_io) + loaded = torch.load(read_io.buf) + assert torch.allclose(tensor, loaded) + + await plugin.delete(path="tensor") + await plugin.close() diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 4e02ce9..41b9553 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -15,29 +15,37 @@ class FSSpecPlugin(StoragePlugin): def __init__(self, root: str, protocol: str, **storage_options) -> None: self.root = root - self.fs = fsspec.filesystem(protocol, **storage_options) + if protocol not in ["http", "s3"]: + raise ValueError(f"Protocol {protocol} does not support async") + self.fs = fsspec.filesystem(protocol, asynchronous=True, **storage_options) async def write(self, write_io: WriteIO) -> None: + session = await self.fs.set_session() path = os.path.join(self.root, write_io.path) - with self.fs.open(path, 'wb+') as f: - f.write(write_io.buf) + with self.fs.open(path, 'wb') as f: + await f.write(write_io.buf) + await session.close() async def read(self, read_io: ReadIO) -> None: + session = await self.fs.set_session() path = os.path.join(self.root, read_io.path) byte_range = read_io.byte_range with self.fs.open(path, 'rb') as f: if byte_range is None: - read_io.buf = io.BytesIO(f.read()) + read_io.buf = io.BytesIO(await f.read()) else: offset = byte_range[0] size = byte_range[1] - byte_range[0] await f.seek(offset) - read_io.buf = io.BytesIO(f.read(size)) + read_io.buf = io.BytesIO(await f.read(size)) + await session.close() async def delete(self, path: str) -> None: + session = await self.fs.set_session() path = os.path.join(self.root, path) - self.fs.delete(path) + await self.fs.delete(path) + await session.close() async def close(self) -> None: pass From 6bd7e9593ca2a23a0a1c68ca99cb7dd9703763fd Mon Sep 17 00:00:00 2001 From: chengcshi Date: Thu, 20 Oct 2022 19:55:21 +0800 Subject: [PATCH 04/26] tiny polish --- torchsnapshot/storage_plugins/fsspec.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 41b9553..e40302b 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -20,17 +20,17 @@ def __init__(self, root: str, protocol: str, **storage_options) -> None: self.fs = fsspec.filesystem(protocol, asynchronous=True, **storage_options) async def write(self, write_io: WriteIO) -> None: - session = await self.fs.set_session() path = os.path.join(self.root, write_io.path) + session = await self.fs.set_session() with self.fs.open(path, 'wb') as f: await f.write(write_io.buf) await session.close() async def read(self, read_io: ReadIO) -> None: - session = await self.fs.set_session() path = os.path.join(self.root, read_io.path) byte_range = read_io.byte_range + session = await self.fs.set_session() with self.fs.open(path, 'rb') as f: if byte_range is None: read_io.buf = io.BytesIO(await f.read()) @@ -42,8 +42,8 @@ async def read(self, read_io: ReadIO) -> None: await session.close() async def delete(self, path: str) -> None: - session = await self.fs.set_session() path = os.path.join(self.root, path) + session = await self.fs.set_session() await self.fs.delete(path) await session.close() From 54f749c55e4eb95c82b3d1074c0df2267f42886f Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 21 Oct 2022 11:27:34 +0800 Subject: [PATCH 05/26] Follow comments --- torchsnapshot/snapshot.py | 40 ++++++++++++++++++++-------- torchsnapshot/storage_plugin.py | 26 ++++++++++++------ torchsnapshot/storage_plugins/fs.py | 4 +-- torchsnapshot/storage_plugins/gcs.py | 4 +-- torchsnapshot/storage_plugins/s3.py | 5 ++-- 5 files changed, 54 insertions(+), 25 deletions(-) diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py index d6477e7..ef2fcbb 100644 --- a/torchsnapshot/snapshot.py +++ b/torchsnapshot/snapshot.py @@ -157,6 +157,7 @@ def __init__( self, path: str, pg: Optional[dist.ProcessGroup] = None, + storage_options: Optional[Dict[str, Any]] = None, ) -> None: """ Initializes the reference to an existing snapshot. @@ -167,10 +168,13 @@ def __init__( When unspecified: - If distributed is initialized, the global process group will be used. - If distributed is not initialized, single process is assumed. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. """ self.path: str = path self.pg: Optional[dist.ProcessGroup] = pg self._metadata: Optional[SnapshotMetadata] = None + self._storage_options = storage_options @classmethod def take( @@ -179,7 +183,7 @@ def take( app_state: AppState, pg: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, - storage_kwargs: Optional[Dict[str, Any]] = None, + storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[ Callable[[str, torch.Tensor, bool], torch.Tensor] ] = None, @@ -197,7 +201,8 @@ def take( replicated: A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored. - storage_kwargs: The StoragePlugin's extra keyword arguments. See each StoragePlugin for doc. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. Returns: The newly taken snapshot. @@ -214,10 +219,8 @@ def take( app_state=app_state, replicated=replicated or [], ) - if storage_kwargs is None: - storage_kwargs = dict() storage = url_to_storage_plugin_in_event_loop( - url_path=path, event_loop=event_loop, **storage_kwargs, + url_path=path, event_loop=event_loop, storage_options=storage_options ) pending_io_work, metadata = cls._take_impl( path=path, @@ -242,7 +245,7 @@ def take( storage.sync_close(event_loop=event_loop) event_loop.close() - snapshot = cls(path=path, pg=pg) + snapshot = cls(path=path, pg=pg, storage_options=storage_options) snapshot._metadata = metadata return snapshot @@ -253,6 +256,7 @@ def async_take( app_state: AppState, pg: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, + storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[ Callable[[str, torch.Tensor, bool], torch.Tensor] ] = None, @@ -275,6 +279,8 @@ def async_take( replicated: A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. Returns: A handle with which the newly taken snapshot can be obtained via @@ -294,7 +300,7 @@ def async_take( replicated=replicated or [], ) storage = url_to_storage_plugin_in_event_loop( - url_path=path, event_loop=event_loop + url_path=path, event_loop=event_loop, storage_options=storage_options ) pending_io_work, metadata = cls._take_impl( @@ -315,6 +321,7 @@ def async_take( metadata=metadata, storage=storage, event_loop=event_loop, + storage_options=storage_options, ) @classmethod @@ -449,6 +456,7 @@ def restore(self, app_state: AppState) -> None: Args: app_state: The program state to restore from the snapshot. + """ torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore") self._validate_app_state(app_state) @@ -457,7 +465,9 @@ def restore(self, app_state: AppState) -> None: pg_wrapper = PGWrapper(self.pg) rank = pg_wrapper.get_rank() storage = url_to_storage_plugin_in_event_loop( - url_path=self.path, event_loop=event_loop + url_path=self.path, + event_loop=event_loop, + storage_options=self._storage_options, ) app_state = app_state.copy() @@ -499,7 +509,9 @@ def metadata(self) -> SnapshotMetadata: if self._metadata is None: event_loop = asyncio.new_event_loop() storage = url_to_storage_plugin_in_event_loop( - url_path=self.path, event_loop=event_loop + url_path=self.path, + event_loop=event_loop, + storage_options=self._storage_options, ) self._metadata = self._read_snapshot_metadata( storage=storage, event_loop=event_loop @@ -569,7 +581,9 @@ def read_object( event_loop = asyncio.new_event_loop() pg_wrapper = PGWrapper(self.pg) storage = url_to_storage_plugin_in_event_loop( - url_path=self.path, event_loop=event_loop + url_path=self.path, + event_loop=event_loop, + storage_options=self._storage_options, ) entry = manifest[unranked_path] if isinstance(entry, PrimitiveEntry): @@ -916,12 +930,14 @@ def __init__( metadata: SnapshotMetadata, storage: StoragePlugin, event_loop: asyncio.AbstractEventLoop, + storage_options: Optional[Dict[str, Any]] = None, ) -> None: self.path = path self.pg: Optional[dist.ProcessGroup] = pg_wrapper.pg # pyre-ignore self.exc_info: Optional[Any] = None self._done = False + self._storage_options = storage_options self.thread = Thread( target=self._complete_snapshot, @@ -989,7 +1005,9 @@ def wait(self) -> Snapshot: raise RuntimeError( f"Encountered exception while taking snapshot asynchronously:\n{formatted}" ) - return Snapshot(path=self.path, pg=self.pg) + return Snapshot( + path=self.path, pg=self.pg, storage_options=self._storage_options + ) def done(self) -> bool: return self._done diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index a9666aa..c17ba72 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import asyncio +from typing import Any, Dict, Optional from importlib_metadata import entry_points @@ -14,13 +15,17 @@ from .storage_plugins.s3 import S3StoragePlugin -def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: +def url_to_storage_plugin( + url_path: str, storage_options: Optional[Dict[str, Any]] = None +) -> StoragePlugin: """ Initialize storage plugin from url path. Args: url_path: The url path following the pattern [protocol]://[path]. The protocol defaults to `fs` if unspecified. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. Returns: The initialized storage plugin. @@ -32,15 +37,18 @@ def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: else: protocol, path = "fs", url_path + if storage_options is None: + storage_options = dict() + # Built-in storage plugins if protocol == "fs": - return FSStoragePlugin(root=path, **kwargs) + return FSStoragePlugin(root=path, storage_options=storage_options) elif protocol == "s3": - return S3StoragePlugin(root=path, **kwargs) + return S3StoragePlugin(root=path, storage_options=storage_options) elif protocol == "gs": from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin - return GCSStoragePlugin(root=path, **kwargs) + return GCSStoragePlugin(root=path, storage_options=storage_options) # Registered storage plugins eps = entry_points(group="storage_plugins") @@ -60,9 +68,11 @@ def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: def url_to_storage_plugin_in_event_loop( - url_path: str, event_loop: asyncio.AbstractEventLoop, **kwargs + url_path: str, + event_loop: asyncio.AbstractEventLoop, + storage_options: Optional[Dict[str, Any]] = None, ) -> StoragePlugin: - async def _url_to_storage_plugin(url_path: str, **_kwargs) -> StoragePlugin: - return url_to_storage_plugin(url_path=url_path, **_kwargs) + async def _url_to_storage_plugin() -> StoragePlugin: + return url_to_storage_plugin(url_path=url_path, storage_options=storage_options) - return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path, **kwargs)) + return event_loop.run_until_complete(_url_to_storage_plugin()) diff --git a/torchsnapshot/storage_plugins/fs.py b/torchsnapshot/storage_plugins/fs.py index 500c8f1..f65af0f 100644 --- a/torchsnapshot/storage_plugins/fs.py +++ b/torchsnapshot/storage_plugins/fs.py @@ -8,7 +8,7 @@ import io import os import pathlib -from typing import Set +from typing import Any, Dict, Set import aiofiles import aiofiles.os @@ -17,7 +17,7 @@ class FSStoragePlugin(StoragePlugin): - def __init__(self, root: str) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: self.root = root self._dir_cache: Set[pathlib.Path] = set() diff --git a/torchsnapshot/storage_plugins/gcs.py b/torchsnapshot/storage_plugins/gcs.py index 9778fad..4166b54 100644 --- a/torchsnapshot/storage_plugins/gcs.py +++ b/torchsnapshot/storage_plugins/gcs.py @@ -16,7 +16,7 @@ import random import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Awaitable, Callable, Optional, TypeVar +from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar from urllib.parse import quote import google.auth.exceptions # @manual @@ -59,7 +59,7 @@ class GCSStoragePlugin(StoragePlugin): "{bucket}/o/{blob_name}?alt=media" ) - def __init__(self, root: str) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: components = root.split("/") if len(components) < 2: raise RuntimeError( diff --git a/torchsnapshot/storage_plugins/s3.py b/torchsnapshot/storage_plugins/s3.py index 17472f1..e41a7fa 100644 --- a/torchsnapshot/storage_plugins/s3.py +++ b/torchsnapshot/storage_plugins/s3.py @@ -7,13 +7,14 @@ import io import os +from typing import Any, Dict from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO from torchsnapshot.memoryview_stream import MemoryviewStream class S3StoragePlugin(StoragePlugin): - def __init__(self, root: str, **kwargs) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: try: from aiobotocore.session import get_session # @manual except ImportError: @@ -30,7 +31,7 @@ def __init__(self, root: str, **kwargs) -> None: self.bucket: str = components[0] self.root: str = "/".join(components[1:]) # pyre-ignore - # TODO: read AWS tokens from **kwargs? + # TODO: read AWS tokens from storage_options? self.session = get_session() async def write(self, write_io: WriteIO) -> None: From 5edff085663f17b8197756214500cc937a4d59f0 Mon Sep 17 00:00:00 2001 From: chengcshi Date: Fri, 21 Oct 2022 16:43:22 +0800 Subject: [PATCH 06/26] tiny polish --- tests/test_fsspec_storage_plugin.py | 43 ++++++++++++------------- torchsnapshot/storage_plugins/fsspec.py | 25 +++----------- 2 files changed, 25 insertions(+), 43 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index 1a1e156..34a0c5a 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -6,39 +6,38 @@ # LICENSE file in the root directory of this source tree. import io import logging -import unittest import uuid +import pytest import torch -from torchsnapshot.io_types import ReadIO, WriteIO +from torchsnapshot.io_types import WriteIO, ReadIO from torchsnapshot.storage_plugins.fsspec import FSSpecPlugin -from torchsnapshot.test_utils import async_test logger: logging.Logger = logging.getLogger(__name__) -_TEST_BUCKET = "torchsnapshot-test" -_TENSOR_SZ = int(100_000_000 / 4) +_TEST_BUCKET = "chengcshi" +# _TENSOR_SZ = int(100_000_000 / 4) +_TENSOR_SZ = 10 -class FSSpecStoragePluginTest(unittest.TestCase): - @async_test - async def test_write_read_delete(self) -> None: - path = f"{_TEST_BUCKET}/{uuid.uuid4()}" - logger.info(path) - plugin = FSSpecPlugin(root=path, protocol="s3") +@pytest.mark.asyncio +async def test_fsspec_s3_write_read_delete() -> None: + path = f"{_TEST_BUCKET}/{uuid.uuid4()}" + logger.info(path) + plugin = FSSpecPlugin(root=path, protocol="s3") - tensor = torch.rand((_TENSOR_SZ,)) - buf = io.BytesIO() - torch.save(tensor, buf) - write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) + tensor = torch.rand((_TENSOR_SZ,)) + buf = io.BytesIO() + torch.save(tensor, buf) + write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) - await plugin.write(write_io=write_io) + await plugin.write(write_io=write_io) - read_io = ReadIO(path="tensor") - await plugin.read(read_io=read_io) - loaded = torch.load(read_io.buf) - assert torch.allclose(tensor, loaded) + read_io = ReadIO(path="tensor") + await plugin.read(read_io=read_io) + loaded = torch.load(read_io.buf) + assert torch.allclose(tensor, loaded) - await plugin.delete(path="tensor") - await plugin.close() + await plugin.delete(path="tensor") + await plugin.close() diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index e40302b..23fba00 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -15,37 +15,20 @@ class FSSpecPlugin(StoragePlugin): def __init__(self, root: str, protocol: str, **storage_options) -> None: self.root = root - if protocol not in ["http", "s3"]: - raise ValueError(f"Protocol {protocol} does not support async") self.fs = fsspec.filesystem(protocol, asynchronous=True, **storage_options) async def write(self, write_io: WriteIO) -> None: path = os.path.join(self.root, write_io.path) - session = await self.fs.set_session() - with self.fs.open(path, 'wb') as f: - await f.write(write_io.buf) - await session.close() + await self.fs._pipe_file(path, bytes(write_io.buf)) async def read(self, read_io: ReadIO) -> None: path = os.path.join(self.root, read_io.path) - byte_range = read_io.byte_range - - session = await self.fs.set_session() - with self.fs.open(path, 'rb') as f: - if byte_range is None: - read_io.buf = io.BytesIO(await f.read()) - else: - offset = byte_range[0] - size = byte_range[1] - byte_range[0] - await f.seek(offset) - read_io.buf = io.BytesIO(await f.read(size)) - await session.close() + result = await self.fs._cat_file(path) + read_io.buf = io.BytesIO(result) async def delete(self, path: str) -> None: path = os.path.join(self.root, path) - session = await self.fs.set_session() - await self.fs.delete(path) - await session.close() + await self.fs._rm_file(path) async def close(self) -> None: pass From 98030751fb5f480d1c8ce3ec00f75bad5a350a0b Mon Sep 17 00:00:00 2001 From: shicheng0829 <523656402@qq.com> Date: Fri, 21 Oct 2022 17:13:15 +0800 Subject: [PATCH 07/26] add session init and close in fsspec plugin --- tests/test_fsspec_storage_plugin.py | 13 ++++++++----- torchsnapshot/storage_plugins/fsspec.py | 25 ++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index 34a0c5a..74df8b1 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import io import logging +import os import uuid import pytest @@ -16,16 +17,18 @@ logger: logging.Logger = logging.getLogger(__name__) -_TEST_BUCKET = "chengcshi" -# _TENSOR_SZ = int(100_000_000 / 4) -_TENSOR_SZ = 10 +_TEST_BUCKET = "torchsnapshot-test" +_TENSOR_SZ = int(100_000_000 / 4) +@pytest.mark.s3_integration_test +@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") +@pytest.mark.usefixtures("s3_health_check") @pytest.mark.asyncio async def test_fsspec_s3_write_read_delete() -> None: - path = f"{_TEST_BUCKET}/{uuid.uuid4()}" + path = f"s3://{_TEST_BUCKET}/{uuid.uuid4()}" logger.info(path) - plugin = FSSpecPlugin(root=path, protocol="s3") + plugin = FSSpecPlugin(root=path) tensor = torch.rand((_TENSOR_SZ,)) buf = io.BytesIO() diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 23fba00..7b18bdd 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -4,6 +4,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import io import os @@ -13,22 +14,40 @@ class FSSpecPlugin(StoragePlugin): - def __init__(self, root: str, protocol: str, **storage_options) -> None: + def __init__(self, root: str, **storage_options) -> None: self.root = root - self.fs = fsspec.filesystem(protocol, asynchronous=True, **storage_options) + self.fs = fsspec.filesystem(protocol=self.root.split("://")[0], **storage_options) + self._session = None + + async def _init_session(self) -> None: + lock = asyncio.Lock() + async with lock: + if self._session is None: + self._session = await self.fs.set_session() async def write(self, write_io: WriteIO) -> None: + await self._init_session() path = os.path.join(self.root, write_io.path) await self.fs._pipe_file(path, bytes(write_io.buf)) async def read(self, read_io: ReadIO) -> None: + await self._init_session() path = os.path.join(self.root, read_io.path) result = await self.fs._cat_file(path) read_io.buf = io.BytesIO(result) async def delete(self, path: str) -> None: + await self._init_session() path = os.path.join(self.root, path) await self.fs._rm_file(path) async def close(self) -> None: - pass + lock = asyncio.Lock() + async with lock: + if self._session is not None: + try: + await self._session.close() + except AttributeError: + # bug in aiobotocore 1.4.1 + await self._session._endpoint.http_session._session.close() + self._session = None From a57c62e10ab1f9eae3be4b10f986b38486d4d26c Mon Sep 17 00:00:00 2001 From: shicheng0829 <523656402@qq.com> Date: Fri, 21 Oct 2022 17:32:52 +0800 Subject: [PATCH 08/26] polish fsspec plugin lock scope --- tests/test_fsspec_storage_plugin.py | 18 ++++++++++-------- torchsnapshot/storage_plugins/fsspec.py | 19 ++++++++----------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index 74df8b1..919c20f 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -6,7 +6,6 @@ # LICENSE file in the root directory of this source tree. import io import logging -import os import uuid import pytest @@ -17,18 +16,21 @@ logger: logging.Logger = logging.getLogger(__name__) -_TEST_BUCKET = "torchsnapshot-test" -_TENSOR_SZ = int(100_000_000 / 4) +# _TEST_BUCKET = "torchsnapshot-test" +_TEST_BUCKET = "chengcshi" +# _TENSOR_SZ = int(100_000_000 / 4) +_TENSOR_SZ = 10 -@pytest.mark.s3_integration_test -@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") -@pytest.mark.usefixtures("s3_health_check") +# @pytest.mark.s3_integration_test +# @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") +# @pytest.mark.usefixtures("s3_health_check") @pytest.mark.asyncio async def test_fsspec_s3_write_read_delete() -> None: - path = f"s3://{_TEST_BUCKET}/{uuid.uuid4()}" + path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}" logger.info(path) - plugin = FSSpecPlugin(root=path) + plugin = FSSpecPlugin(root=path, key="AKIA34KDUMSNPNTFPTSA", + secret="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") tensor = torch.rand((_TENSOR_SZ,)) buf = io.BytesIO() diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 7b18bdd..b5938e5 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -15,13 +15,15 @@ class FSSpecPlugin(StoragePlugin): def __init__(self, root: str, **storage_options) -> None: - self.root = root - self.fs = fsspec.filesystem(protocol=self.root.split("://")[0], **storage_options) + protocol, self.root = root.split("://") + if not protocol.startswith("fsspec-"): + raise ValueError(f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported") + self.fs = fsspec.filesystem(protocol=protocol.removeprefix("fsspec-"), **storage_options) self._session = None + self._lock = asyncio.Lock() async def _init_session(self) -> None: - lock = asyncio.Lock() - async with lock: + async with self._lock: if self._session is None: self._session = await self.fs.set_session() @@ -42,12 +44,7 @@ async def delete(self, path: str) -> None: await self.fs._rm_file(path) async def close(self) -> None: - lock = asyncio.Lock() - async with lock: + async with self._lock: if self._session is not None: - try: - await self._session.close() - except AttributeError: - # bug in aiobotocore 1.4.1 - await self._session._endpoint.http_session._session.close() + await self._session.close() self._session = None From ac3b0833dad00673f3a85cbad34b72338f4b329a Mon Sep 17 00:00:00 2001 From: shicheng0829 <523656402@qq.com> Date: Fri, 21 Oct 2022 17:35:58 +0800 Subject: [PATCH 09/26] polish fsspec ut --- tests/test_fsspec_storage_plugin.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index 919c20f..ccedf99 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import io import logging +import os import uuid import pytest @@ -16,21 +17,18 @@ logger: logging.Logger = logging.getLogger(__name__) -# _TEST_BUCKET = "torchsnapshot-test" -_TEST_BUCKET = "chengcshi" -# _TENSOR_SZ = int(100_000_000 / 4) -_TENSOR_SZ = 10 +_TEST_BUCKET = "torchsnapshot-test" +_TENSOR_SZ = int(100_000_000 / 4) -# @pytest.mark.s3_integration_test -# @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") -# @pytest.mark.usefixtures("s3_health_check") +@pytest.mark.s3_integration_test +@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") +@pytest.mark.usefixtures("s3_health_check") @pytest.mark.asyncio async def test_fsspec_s3_write_read_delete() -> None: path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}" logger.info(path) - plugin = FSSpecPlugin(root=path, key="AKIA34KDUMSNPNTFPTSA", - secret="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") + plugin = FSSpecPlugin(root=path) tensor = torch.rand((_TENSOR_SZ,)) buf = io.BytesIO() From 6ac994e0d5ae8dc0dba722d2d01ef537757fbd07 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 21 Oct 2022 17:47:51 +0800 Subject: [PATCH 10/26] Tiny polish code --- dev-requirements.txt | 1 + requirements.txt | 1 - torchsnapshot/storage_plugin.py | 4 ++++ torchsnapshot/storage_plugins/fsspec.py | 12 ++++++++---- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index bf2de90..2468d07 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -8,3 +8,4 @@ pytest pytest-asyncio pytest-cov pytest-timeout +fsspec diff --git a/requirements.txt b/requirements.txt index 62b3927..4ca65ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,3 @@ nest_asyncio psutil torch typing-extensions -fsspec diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index f567dc5..452675d 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -41,6 +41,10 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin return GCSStoragePlugin(root=path) + elif protocol.startswith("fsspec+"): + from torchsnapshot.storage_plugins.fsspec import FSSpecPlugin + + return FSSpecPlugin(root=path) # Registered storage plugins eps = entry_points(group="storage_plugins") diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index b5938e5..f9a0dce 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -10,15 +10,19 @@ import fsspec -from torchsnapshot.io_types import StoragePlugin, ReadIO, WriteIO +from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO + +__all__ = ["FSSpecPlugin"] class FSSpecPlugin(StoragePlugin): - def __init__(self, root: str, **storage_options) -> None: + def __init__(self, root: str) -> None: protocol, self.root = root.split("://") if not protocol.startswith("fsspec-"): - raise ValueError(f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported") - self.fs = fsspec.filesystem(protocol=protocol.removeprefix("fsspec-"), **storage_options) + raise ValueError( + f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" + ) + self.fs = fsspec.filesystem(protocol=protocol.removeprefix("fsspec-")) self._session = None self._lock = asyncio.Lock() From fd7126c98e83ee937d0be070f0453bd8f52e5fa5 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 21 Oct 2022 17:49:24 +0800 Subject: [PATCH 11/26] Fix typo --- torchsnapshot/storage_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index 452675d..aafef0c 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -41,7 +41,7 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin return GCSStoragePlugin(root=path) - elif protocol.startswith("fsspec+"): + elif protocol.startswith("fsspec-"): from torchsnapshot.storage_plugins.fsspec import FSSpecPlugin return FSSpecPlugin(root=path) From 8f499002fdee416ea3e0636f944d87f1b7aafe4a Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 21 Oct 2022 17:52:21 +0800 Subject: [PATCH 12/26] Tiny polish --- tests/test_fsspec_storage_plugin.py | 6 +++--- torchsnapshot/storage_plugin.py | 4 ++-- torchsnapshot/storage_plugins/fsspec.py | 9 ++++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index ccedf99..100626b 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -12,8 +12,8 @@ import pytest import torch -from torchsnapshot.io_types import WriteIO, ReadIO -from torchsnapshot.storage_plugins.fsspec import FSSpecPlugin +from torchsnapshot.io_types import ReadIO, WriteIO +from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin logger: logging.Logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ async def test_fsspec_s3_write_read_delete() -> None: path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}" logger.info(path) - plugin = FSSpecPlugin(root=path) + plugin = FSSpecStoragePlugin(root=path) tensor = torch.rand((_TENSOR_SZ,)) buf = io.BytesIO() diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index aafef0c..673c637 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -42,9 +42,9 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: return GCSStoragePlugin(root=path) elif protocol.startswith("fsspec-"): - from torchsnapshot.storage_plugins.fsspec import FSSpecPlugin + from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin - return FSSpecPlugin(root=path) + return FSSpecStoragePlugin(root=path) # Registered storage plugins eps = entry_points(group="storage_plugins") diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index f9a0dce..1d4c2d5 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -12,12 +12,15 @@ from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO -__all__ = ["FSSpecPlugin"] +__all__ = ["FSSpecStoragePlugin"] -class FSSpecPlugin(StoragePlugin): +class FSSpecStoragePlugin(StoragePlugin): def __init__(self, root: str) -> None: - protocol, self.root = root.split("://") + root_items = root.split("://") + if len(root_items) != 2: + raise ValueError("only protocol://path is supported by fsspec plugin") + protocol, self.root = root_items if not protocol.startswith("fsspec-"): raise ValueError( f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" From fb6cf1b4312a13906728a08c7c7417efd76ec946 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Mon, 24 Oct 2022 14:10:01 +0800 Subject: [PATCH 13/26] create dir before write and delete recursively --- dev-requirements.txt | 2 +- torchsnapshot/storage_plugins/fsspec.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 2468d07..f87e528 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,6 @@ aiobotocore boto3 +fsspec google-cloud-storage google-resumable-media numpy @@ -8,4 +9,3 @@ pytest pytest-asyncio pytest-cov pytest-timeout -fsspec diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 1d4c2d5..70704a0 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -16,7 +16,7 @@ class FSSpecStoragePlugin(StoragePlugin): - def __init__(self, root: str) -> None: + def __init__(self, root: str, **storage_options) -> None: root_items = root.split("://") if len(root_items) != 2: raise ValueError("only protocol://path is supported by fsspec plugin") @@ -25,7 +25,7 @@ def __init__(self, root: str) -> None: raise ValueError( f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" ) - self.fs = fsspec.filesystem(protocol=protocol.removeprefix("fsspec-")) + self.fs = fsspec.filesystem(protocol=protocol.removeprefix("fsspec-"), **storage_options) self._session = None self._lock = asyncio.Lock() @@ -37,18 +37,22 @@ async def _init_session(self) -> None: async def write(self, write_io: WriteIO) -> None: await self._init_session() path = os.path.join(self.root, write_io.path) + splits = path.split("/") + for i in range(len(splits)): + dir_path = "/".join(splits[:i]) + if dir_path and not await self.fs._exists(dir_path): + await self.fs._mkdir(dir_path) await self.fs._pipe_file(path, bytes(write_io.buf)) async def read(self, read_io: ReadIO) -> None: await self._init_session() path = os.path.join(self.root, read_io.path) - result = await self.fs._cat_file(path) - read_io.buf = io.BytesIO(result) + read_io.buf = io.BytesIO(await self.fs._cat_file(path)) async def delete(self, path: str) -> None: await self._init_session() path = os.path.join(self.root, path) - await self.fs._rm_file(path) + await self.fs._rm(path, recursive=True) async def close(self) -> None: async with self._lock: From 19b08c2ab0008531f69c094952258e66d95f26b2 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 20 Oct 2022 13:40:03 +0800 Subject: [PATCH 14/26] Add `storage_kwargs` to Snapshot.take --- torchsnapshot/snapshot.py | 6 +++++- torchsnapshot/storage_plugin.py | 16 ++++++++-------- torchsnapshot/storage_plugins/s3.py | 3 ++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py index cdf5102..bab24d1 100644 --- a/torchsnapshot/snapshot.py +++ b/torchsnapshot/snapshot.py @@ -177,6 +177,7 @@ def take( app_state: AppState, pg: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, + storage_kwargs: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[ Callable[[str, torch.Tensor, bool], torch.Tensor] ] = None, @@ -194,6 +195,7 @@ def take( replicated: A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored. + storage_kwargs: The StoragePlugin's extra keyword arguments. See each StoragePlugin for doc. Returns: The newly taken snapshot. @@ -210,8 +212,10 @@ def take( app_state=app_state, replicated=replicated or [], ) + if storage_kwargs is None: + storage_kwargs = dict() storage = url_to_storage_plugin_in_event_loop( - url_path=path, event_loop=event_loop + url_path=path, event_loop=event_loop, **storage_kwargs, ) pending_io_work, metadata = cls._take_impl( path=path, diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index f567dc5..a9666aa 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -14,7 +14,7 @@ from .storage_plugins.s3 import S3StoragePlugin -def url_to_storage_plugin(url_path: str) -> StoragePlugin: +def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: """ Initialize storage plugin from url path. @@ -34,13 +34,13 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: # Built-in storage plugins if protocol == "fs": - return FSStoragePlugin(root=path) + return FSStoragePlugin(root=path, **kwargs) elif protocol == "s3": - return S3StoragePlugin(root=path) + return S3StoragePlugin(root=path, **kwargs) elif protocol == "gs": from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin - return GCSStoragePlugin(root=path) + return GCSStoragePlugin(root=path, **kwargs) # Registered storage plugins eps = entry_points(group="storage_plugins") @@ -60,9 +60,9 @@ def url_to_storage_plugin(url_path: str) -> StoragePlugin: def url_to_storage_plugin_in_event_loop( - url_path: str, event_loop: asyncio.AbstractEventLoop + url_path: str, event_loop: asyncio.AbstractEventLoop, **kwargs ) -> StoragePlugin: - async def _url_to_storage_plugin(url_path: str) -> StoragePlugin: - return url_to_storage_plugin(url_path=url_path) + async def _url_to_storage_plugin(url_path: str, **_kwargs) -> StoragePlugin: + return url_to_storage_plugin(url_path=url_path, **_kwargs) - return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path)) + return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path, **kwargs)) diff --git a/torchsnapshot/storage_plugins/s3.py b/torchsnapshot/storage_plugins/s3.py index 8ed07e4..17472f1 100644 --- a/torchsnapshot/storage_plugins/s3.py +++ b/torchsnapshot/storage_plugins/s3.py @@ -13,7 +13,7 @@ class S3StoragePlugin(StoragePlugin): - def __init__(self, root: str) -> None: + def __init__(self, root: str, **kwargs) -> None: try: from aiobotocore.session import get_session # @manual except ImportError: @@ -30,6 +30,7 @@ def __init__(self, root: str) -> None: self.bucket: str = components[0] self.root: str = "/".join(components[1:]) # pyre-ignore + # TODO: read AWS tokens from **kwargs? self.session = get_session() async def write(self, write_io: WriteIO) -> None: From 378afcc11163b4ff82c0812cc8e2cde153dace36 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 21 Oct 2022 11:27:34 +0800 Subject: [PATCH 15/26] Follow comments --- torchsnapshot/snapshot.py | 40 ++++++++++++++++++++-------- torchsnapshot/storage_plugin.py | 26 ++++++++++++------ torchsnapshot/storage_plugins/fs.py | 4 +-- torchsnapshot/storage_plugins/gcs.py | 4 +-- torchsnapshot/storage_plugins/s3.py | 5 ++-- 5 files changed, 54 insertions(+), 25 deletions(-) diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py index bab24d1..30bd399 100644 --- a/torchsnapshot/snapshot.py +++ b/torchsnapshot/snapshot.py @@ -155,6 +155,7 @@ def __init__( self, path: str, pg: Optional[dist.ProcessGroup] = None, + storage_options: Optional[Dict[str, Any]] = None, ) -> None: """ Initializes the reference to an existing snapshot. @@ -165,10 +166,13 @@ def __init__( When unspecified: - If distributed is initialized, the global process group will be used. - If distributed is not initialized, single process is assumed. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. """ self.path: str = path self.pg: Optional[dist.ProcessGroup] = pg self._metadata: Optional[SnapshotMetadata] = None + self._storage_options = storage_options @classmethod def take( @@ -177,7 +181,7 @@ def take( app_state: AppState, pg: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, - storage_kwargs: Optional[Dict[str, Any]] = None, + storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[ Callable[[str, torch.Tensor, bool], torch.Tensor] ] = None, @@ -195,7 +199,8 @@ def take( replicated: A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored. - storage_kwargs: The StoragePlugin's extra keyword arguments. See each StoragePlugin for doc. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. Returns: The newly taken snapshot. @@ -212,10 +217,8 @@ def take( app_state=app_state, replicated=replicated or [], ) - if storage_kwargs is None: - storage_kwargs = dict() storage = url_to_storage_plugin_in_event_loop( - url_path=path, event_loop=event_loop, **storage_kwargs, + url_path=path, event_loop=event_loop, storage_options=storage_options ) pending_io_work, metadata = cls._take_impl( path=path, @@ -240,7 +243,7 @@ def take( storage.sync_close(event_loop=event_loop) event_loop.close() - snapshot = cls(path=path, pg=pg) + snapshot = cls(path=path, pg=pg, storage_options=storage_options) snapshot._metadata = metadata return snapshot @@ -251,6 +254,7 @@ def async_take( app_state: AppState, pg: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, + storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[ Callable[[str, torch.Tensor, bool], torch.Tensor] ] = None, @@ -273,6 +277,8 @@ def async_take( replicated: A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. Returns: A handle with which the newly taken snapshot can be obtained via @@ -292,7 +298,7 @@ def async_take( replicated=replicated or [], ) storage = url_to_storage_plugin_in_event_loop( - url_path=path, event_loop=event_loop + url_path=path, event_loop=event_loop, storage_options=storage_options ) pending_io_work, metadata = cls._take_impl( @@ -313,6 +319,7 @@ def async_take( metadata=metadata, storage=storage, event_loop=event_loop, + storage_options=storage_options, ) @classmethod @@ -441,6 +448,7 @@ def restore(self, app_state: AppState) -> None: Args: app_state: The program state to restore from the snapshot. + """ torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore") self._validate_app_state(app_state) @@ -449,7 +457,9 @@ def restore(self, app_state: AppState) -> None: pg_wrapper = PGWrapper(self.pg) rank = pg_wrapper.get_rank() storage = url_to_storage_plugin_in_event_loop( - url_path=self.path, event_loop=event_loop + url_path=self.path, + event_loop=event_loop, + storage_options=self._storage_options, ) app_state = app_state.copy() @@ -491,7 +501,9 @@ def metadata(self) -> SnapshotMetadata: if self._metadata is None: event_loop = asyncio.new_event_loop() storage = url_to_storage_plugin_in_event_loop( - url_path=self.path, event_loop=event_loop + url_path=self.path, + event_loop=event_loop, + storage_options=self._storage_options, ) self._metadata = self._read_snapshot_metadata( storage=storage, event_loop=event_loop @@ -561,7 +573,9 @@ def read_object( event_loop = asyncio.new_event_loop() pg_wrapper = PGWrapper(self.pg) storage = url_to_storage_plugin_in_event_loop( - url_path=self.path, event_loop=event_loop + url_path=self.path, + event_loop=event_loop, + storage_options=self._storage_options, ) entry = manifest[unranked_path] if isinstance(entry, PrimitiveEntry): @@ -859,12 +873,14 @@ def __init__( metadata: SnapshotMetadata, storage: StoragePlugin, event_loop: asyncio.AbstractEventLoop, + storage_options: Optional[Dict[str, Any]] = None, ) -> None: self.path = path self.pg: Optional[dist.ProcessGroup] = pg_wrapper.pg # pyre-ignore self.exc_info: Optional[Any] = None self._done = False + self._storage_options = storage_options self.thread = Thread( target=self._complete_snapshot, @@ -932,7 +948,9 @@ def wait(self) -> Snapshot: raise RuntimeError( f"Encountered exception while taking snapshot asynchronously:\n{formatted}" ) - return Snapshot(path=self.path, pg=self.pg) + return Snapshot( + path=self.path, pg=self.pg, storage_options=self._storage_options + ) def done(self) -> bool: return self._done diff --git a/torchsnapshot/storage_plugin.py b/torchsnapshot/storage_plugin.py index a9666aa..c17ba72 100644 --- a/torchsnapshot/storage_plugin.py +++ b/torchsnapshot/storage_plugin.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import asyncio +from typing import Any, Dict, Optional from importlib_metadata import entry_points @@ -14,13 +15,17 @@ from .storage_plugins.s3 import S3StoragePlugin -def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: +def url_to_storage_plugin( + url_path: str, storage_options: Optional[Dict[str, Any]] = None +) -> StoragePlugin: """ Initialize storage plugin from url path. Args: url_path: The url path following the pattern [protocol]://[path]. The protocol defaults to `fs` if unspecified. + storage_options: Additional keyword options for the StoragePlugin to use. + See each StoragePlugin's documentation for customizations. Returns: The initialized storage plugin. @@ -32,15 +37,18 @@ def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: else: protocol, path = "fs", url_path + if storage_options is None: + storage_options = dict() + # Built-in storage plugins if protocol == "fs": - return FSStoragePlugin(root=path, **kwargs) + return FSStoragePlugin(root=path, storage_options=storage_options) elif protocol == "s3": - return S3StoragePlugin(root=path, **kwargs) + return S3StoragePlugin(root=path, storage_options=storage_options) elif protocol == "gs": from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin - return GCSStoragePlugin(root=path, **kwargs) + return GCSStoragePlugin(root=path, storage_options=storage_options) # Registered storage plugins eps = entry_points(group="storage_plugins") @@ -60,9 +68,11 @@ def url_to_storage_plugin(url_path: str, **kwargs) -> StoragePlugin: def url_to_storage_plugin_in_event_loop( - url_path: str, event_loop: asyncio.AbstractEventLoop, **kwargs + url_path: str, + event_loop: asyncio.AbstractEventLoop, + storage_options: Optional[Dict[str, Any]] = None, ) -> StoragePlugin: - async def _url_to_storage_plugin(url_path: str, **_kwargs) -> StoragePlugin: - return url_to_storage_plugin(url_path=url_path, **_kwargs) + async def _url_to_storage_plugin() -> StoragePlugin: + return url_to_storage_plugin(url_path=url_path, storage_options=storage_options) - return event_loop.run_until_complete(_url_to_storage_plugin(url_path=url_path, **kwargs)) + return event_loop.run_until_complete(_url_to_storage_plugin()) diff --git a/torchsnapshot/storage_plugins/fs.py b/torchsnapshot/storage_plugins/fs.py index 500c8f1..f65af0f 100644 --- a/torchsnapshot/storage_plugins/fs.py +++ b/torchsnapshot/storage_plugins/fs.py @@ -8,7 +8,7 @@ import io import os import pathlib -from typing import Set +from typing import Any, Dict, Set import aiofiles import aiofiles.os @@ -17,7 +17,7 @@ class FSStoragePlugin(StoragePlugin): - def __init__(self, root: str) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: self.root = root self._dir_cache: Set[pathlib.Path] = set() diff --git a/torchsnapshot/storage_plugins/gcs.py b/torchsnapshot/storage_plugins/gcs.py index 9778fad..4166b54 100644 --- a/torchsnapshot/storage_plugins/gcs.py +++ b/torchsnapshot/storage_plugins/gcs.py @@ -16,7 +16,7 @@ import random import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Awaitable, Callable, Optional, TypeVar +from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar from urllib.parse import quote import google.auth.exceptions # @manual @@ -59,7 +59,7 @@ class GCSStoragePlugin(StoragePlugin): "{bucket}/o/{blob_name}?alt=media" ) - def __init__(self, root: str) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: components = root.split("/") if len(components) < 2: raise RuntimeError( diff --git a/torchsnapshot/storage_plugins/s3.py b/torchsnapshot/storage_plugins/s3.py index 17472f1..e41a7fa 100644 --- a/torchsnapshot/storage_plugins/s3.py +++ b/torchsnapshot/storage_plugins/s3.py @@ -7,13 +7,14 @@ import io import os +from typing import Any, Dict from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO from torchsnapshot.memoryview_stream import MemoryviewStream class S3StoragePlugin(StoragePlugin): - def __init__(self, root: str, **kwargs) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: try: from aiobotocore.session import get_session # @manual except ImportError: @@ -30,7 +31,7 @@ def __init__(self, root: str, **kwargs) -> None: self.bucket: str = components[0] self.root: str = "/".join(components[1:]) # pyre-ignore - # TODO: read AWS tokens from **kwargs? + # TODO: read AWS tokens from storage_options? self.session = get_session() async def write(self, write_io: WriteIO) -> None: From d1d4df932a5aa467ef5d97b861581e9bb56c2609 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Tue, 25 Oct 2022 11:18:41 +0800 Subject: [PATCH 16/26] Make storage_options optional --- torchsnapshot/storage_plugins/fs.py | 6 ++++-- torchsnapshot/storage_plugins/gcs.py | 4 +++- torchsnapshot/storage_plugins/s3.py | 6 ++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/torchsnapshot/storage_plugins/fs.py b/torchsnapshot/storage_plugins/fs.py index f65af0f..4fefdd5 100644 --- a/torchsnapshot/storage_plugins/fs.py +++ b/torchsnapshot/storage_plugins/fs.py @@ -8,7 +8,7 @@ import io import os import pathlib -from typing import Any, Dict, Set +from typing import Any, Dict, Optional, Set import aiofiles import aiofiles.os @@ -17,7 +17,9 @@ class FSStoragePlugin(StoragePlugin): - def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: + def __init__( + self, root: str, storage_options: Optional[Dict[str, Any]] = None + ) -> None: self.root = root self._dir_cache: Set[pathlib.Path] = set() diff --git a/torchsnapshot/storage_plugins/gcs.py b/torchsnapshot/storage_plugins/gcs.py index 4166b54..8fe03d5 100644 --- a/torchsnapshot/storage_plugins/gcs.py +++ b/torchsnapshot/storage_plugins/gcs.py @@ -59,7 +59,9 @@ class GCSStoragePlugin(StoragePlugin): "{bucket}/o/{blob_name}?alt=media" ) - def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: + def __init__( + self, root: str, storage_options: Optional[Dict[str, Any]] = None + ) -> None: components = root.split("/") if len(components) < 2: raise RuntimeError( diff --git a/torchsnapshot/storage_plugins/s3.py b/torchsnapshot/storage_plugins/s3.py index e41a7fa..4c825d8 100644 --- a/torchsnapshot/storage_plugins/s3.py +++ b/torchsnapshot/storage_plugins/s3.py @@ -7,14 +7,16 @@ import io import os -from typing import Any, Dict +from typing import Any, Dict, Optional from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO from torchsnapshot.memoryview_stream import MemoryviewStream class S3StoragePlugin(StoragePlugin): - def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: + def __init__( + self, root: str, storage_options: Optional[Dict[str, Any]] = None + ) -> None: try: from aiobotocore.session import get_session # @manual except ImportError: From 9524bae2e603d7cb0ab080f51212b7fc8b49edb0 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Tue, 25 Oct 2022 19:53:29 +0800 Subject: [PATCH 17/26] update fsspec storage plugin init --- torchsnapshot/storage_plugins/fsspec.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 70704a0..659514e 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -7,6 +7,7 @@ import asyncio import io import os +from typing import Dict, Any import fsspec @@ -16,7 +17,7 @@ class FSSpecStoragePlugin(StoragePlugin): - def __init__(self, root: str, **storage_options) -> None: + def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: root_items = root.split("://") if len(root_items) != 2: raise ValueError("only protocol://path is supported by fsspec plugin") From 11b82fc18502f3dff64d459daf6998ea730415ba Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 12:11:28 +0800 Subject: [PATCH 18/26] add __init__.py in io_preparers --- torchsnapshot/io_preparers/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 torchsnapshot/io_preparers/__init__.py diff --git a/torchsnapshot/io_preparers/__init__.py b/torchsnapshot/io_preparers/__init__.py new file mode 100644 index 0000000..2e41cd7 --- /dev/null +++ b/torchsnapshot/io_preparers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. From fd6a5b786b29e844991398c2a5d88240705570e3 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 12:36:07 +0800 Subject: [PATCH 19/26] init fs when init session --- torchsnapshot/storage_plugins/fsspec.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 659514e..3783096 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -26,13 +26,16 @@ def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: raise ValueError( f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" ) - self.fs = fsspec.filesystem(protocol=protocol.removeprefix("fsspec-"), **storage_options) + self._protocol = protocol.removeprefix("fsspec-") + self.fs = fsspec.filesystem(protocol=self._protocol, **storage_options) self._session = None self._lock = asyncio.Lock() + self._storage_options = storage_options async def _init_session(self) -> None: async with self._lock: if self._session is None: + self.fs = fsspec.filesystem(protocol=self._protocol, **self._storage_options) self._session = await self.fs.set_session() async def write(self, write_io: WriteIO) -> None: From 6bb274a6f3ed33246a4146a08a29ae6963991818 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 14:11:22 +0800 Subject: [PATCH 20/26] init fs system whenever session is none --- torchsnapshot/storage_plugins/fsspec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 3783096..6ff17c8 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -34,8 +34,8 @@ def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: async def _init_session(self) -> None: async with self._lock: + self.fs = fsspec.filesystem(protocol=self._protocol, **self._storage_options) if self._session is None: - self.fs = fsspec.filesystem(protocol=self._protocol, **self._storage_options) self._session = await self.fs.set_session() async def write(self, write_io: WriteIO) -> None: From 1b45d5cf3cdabb99001c3ef4899e2ec52cd84f71 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 14:53:29 +0800 Subject: [PATCH 21/26] add ut of fsspec s3 read write via snapshot --- tests/test_fsspec_storage_plugin.py | 21 ++++++++++++++++++++- torchsnapshot/storage_plugins/fsspec.py | 4 ++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/test_fsspec_storage_plugin.py b/tests/test_fsspec_storage_plugin.py index 100626b..92c31a5 100644 --- a/tests/test_fsspec_storage_plugin.py +++ b/tests/test_fsspec_storage_plugin.py @@ -12,13 +12,32 @@ import pytest import torch +import torchsnapshot from torchsnapshot.io_types import ReadIO, WriteIO from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin logger: logging.Logger = logging.getLogger(__name__) _TEST_BUCKET = "torchsnapshot-test" -_TENSOR_SZ = int(100_000_000 / 4) +_TENSOR_SZ = int(1_000_000 / 4) + + +@pytest.mark.s3_integration_test +@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") +@pytest.mark.usefixtures("s3_health_check") +def test_fsspec_s3_read_write_via_snapshot() -> None: + path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}" + logger.info(path) + + tensor = torch.rand((_TENSOR_SZ,)) + app_state = {"state": torchsnapshot.StateDict(tensor=tensor)} + snapshot = torchsnapshot.Snapshot.take(path=path, app_state=app_state) + + app_state["state"]["tensor"] = torch.rand((_TENSOR_SZ,)) + assert not torch.allclose(tensor, app_state["state"]["tensor"]) + + snapshot.restore(app_state) + assert torch.allclose(tensor, app_state["state"]["tensor"]) @pytest.mark.s3_integration_test diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 6ff17c8..242bfb4 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -7,7 +7,7 @@ import asyncio import io import os -from typing import Dict, Any +from typing import Dict, Any, Optional import fsspec @@ -17,7 +17,7 @@ class FSSpecStoragePlugin(StoragePlugin): - def __init__(self, root: str, storage_options: Dict[str, Any]) -> None: + def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None: root_items = root.split("://") if len(root_items) != 2: raise ValueError("only protocol://path is supported by fsspec plugin") From 7caf1ddf7934e89c3aabea1ee0d35bbbdc0fde79 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 16:12:00 +0800 Subject: [PATCH 22/26] support byte range read --- torchsnapshot/storage_plugins/fsspec.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 242bfb4..52a66f8 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -51,7 +51,12 @@ async def write(self, write_io: WriteIO) -> None: async def read(self, read_io: ReadIO) -> None: await self._init_session() path = os.path.join(self.root, read_io.path) - read_io.buf = io.BytesIO(await self.fs._cat_file(path)) + data = await self.fs._cat_file(path) + if read_io.byte_range is None: + read_io.buf = io.BytesIO(data) + else: + start, end = read_io.byte_range + read_io.buf = io.BytesIO(data[start:end]) async def delete(self, path: str) -> None: await self._init_session() From a82675599914f50b4c4616e08e65ee5aa0cfdaeb Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 16:43:23 +0800 Subject: [PATCH 23/26] remove removeprefix to support python 3.7 and 3.8 --- torchsnapshot/storage_plugins/fsspec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 52a66f8..12267bd 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -26,7 +26,7 @@ def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None raise ValueError( f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" ) - self._protocol = protocol.removeprefix("fsspec-") + self._protocol = protocol[len("fsspec-"):] self.fs = fsspec.filesystem(protocol=self._protocol, **storage_options) self._session = None self._lock = asyncio.Lock() From 32358df8c8fbd2823b612fa08582e935dafd79fa Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Wed, 26 Oct 2022 17:09:23 +0800 Subject: [PATCH 24/26] tiny polish --- torchsnapshot/storage_plugins/fsspec.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 12267bd..36941c9 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -27,16 +27,16 @@ def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" ) self._protocol = protocol[len("fsspec-"):] - self.fs = fsspec.filesystem(protocol=self._protocol, **storage_options) + self._fs = None self._session = None self._lock = asyncio.Lock() self._storage_options = storage_options async def _init_session(self) -> None: async with self._lock: - self.fs = fsspec.filesystem(protocol=self._protocol, **self._storage_options) if self._session is None: - self._session = await self.fs.set_session() + self._fs = fsspec.filesystem(protocol=self._protocol, **self._storage_options) + self._session = await self._fs.set_session(refresh=True) async def write(self, write_io: WriteIO) -> None: await self._init_session() @@ -44,14 +44,14 @@ async def write(self, write_io: WriteIO) -> None: splits = path.split("/") for i in range(len(splits)): dir_path = "/".join(splits[:i]) - if dir_path and not await self.fs._exists(dir_path): - await self.fs._mkdir(dir_path) - await self.fs._pipe_file(path, bytes(write_io.buf)) + if dir_path and not await self._fs._exists(dir_path): + await self._fs._mkdir(dir_path) + await self._fs._pipe_file(path, bytes(write_io.buf)) async def read(self, read_io: ReadIO) -> None: await self._init_session() path = os.path.join(self.root, read_io.path) - data = await self.fs._cat_file(path) + data = await self._fs._cat_file(path) if read_io.byte_range is None: read_io.buf = io.BytesIO(data) else: @@ -61,10 +61,11 @@ async def read(self, read_io: ReadIO) -> None: async def delete(self, path: str) -> None: await self._init_session() path = os.path.join(self.root, path) - await self.fs._rm(path, recursive=True) + await self._fs._rm(path, recursive=True) async def close(self) -> None: async with self._lock: if self._session is not None: await self._session.close() self._session = None + self._fs = None From d1b69d5db82cd18f263360267272617da802cde1 Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Thu, 27 Oct 2022 11:11:36 +0800 Subject: [PATCH 25/26] fix pre-commit --- torchsnapshot/storage_plugins/fsspec.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 36941c9..60d611d 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -7,7 +7,7 @@ import asyncio import io import os -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional import fsspec @@ -26,7 +26,7 @@ def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None raise ValueError( f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported" ) - self._protocol = protocol[len("fsspec-"):] + self._protocol = protocol[len("fsspec-") :] self._fs = None self._session = None self._lock = asyncio.Lock() @@ -35,7 +35,9 @@ def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None async def _init_session(self) -> None: async with self._lock: if self._session is None: - self._fs = fsspec.filesystem(protocol=self._protocol, **self._storage_options) + self._fs = fsspec.filesystem( + protocol=self._protocol, **self._storage_options + ) self._session = await self._fs.set_session(refresh=True) async def write(self, write_io: WriteIO) -> None: From f28236653f7832e5c2dcd6de106872ac6854f0ab Mon Sep 17 00:00:00 2001 From: Cheng <523656402@qq.com> Date: Fri, 28 Oct 2022 18:53:57 +0800 Subject: [PATCH 26/26] jump mkdir when storage plugins don't have this method --- torchsnapshot/storage_plugins/fsspec.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchsnapshot/storage_plugins/fsspec.py b/torchsnapshot/storage_plugins/fsspec.py index 60d611d..e0fbab6 100644 --- a/torchsnapshot/storage_plugins/fsspec.py +++ b/torchsnapshot/storage_plugins/fsspec.py @@ -47,7 +47,10 @@ async def write(self, write_io: WriteIO) -> None: for i in range(len(splits)): dir_path = "/".join(splits[:i]) if dir_path and not await self._fs._exists(dir_path): - await self._fs._mkdir(dir_path) + try: + await self._fs._mkdir(dir_path) + except AttributeError: + break await self._fs._pipe_file(path, bytes(write_io.buf)) async def read(self, read_io: ReadIO) -> None: