Skip to content

Add fsspec storage plugin #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1134492
Add `storage_kwargs` to Snapshot.take
reyoung Oct 20, 2022
88dd7a6
add fsspec plugin
Oct 20, 2022
c28f150
tiny polish
Oct 20, 2022
6bd7e95
tiny polish
Oct 20, 2022
54f749c
Follow comments
reyoung Oct 21, 2022
5edff08
tiny polish
Oct 21, 2022
9803075
add session init and close in fsspec plugin
shicheng0829 Oct 21, 2022
a57c62e
polish fsspec plugin lock scope
shicheng0829 Oct 21, 2022
ac3b083
polish fsspec ut
shicheng0829 Oct 21, 2022
6ac994e
Tiny polish code
reyoung Oct 21, 2022
fd7126c
Fix typo
reyoung Oct 21, 2022
8f49900
Tiny polish
reyoung Oct 21, 2022
7af0c98
Merge pull request #1 from reyoung/feature/fsspec_plugin
shicheng0829 Oct 21, 2022
fb6cf1b
create dir before write and delete recursively
shicheng0829 Oct 24, 2022
c385344
fix conflict
shicheng0829 Oct 24, 2022
19b08c2
Add `storage_kwargs` to Snapshot.take
reyoung Oct 20, 2022
378afcc
Follow comments
reyoung Oct 21, 2022
d1d4df9
Make storage_options optional
reyoung Oct 25, 2022
9524bae
update fsspec storage plugin init
shicheng0829 Oct 25, 2022
62f8319
Merge branch 'main' of github.com:shicheng0829/torchsnapshot into fea…
shicheng0829 Oct 26, 2022
9eef127
fix conflict
shicheng0829 Oct 26, 2022
a86d45a
fix conflict
shicheng0829 Oct 26, 2022
11b82fc
add __init__.py in io_preparers
shicheng0829 Oct 26, 2022
fd6a5b7
init fs when init session
shicheng0829 Oct 26, 2022
6bb274a
init fs system whenever session is none
shicheng0829 Oct 26, 2022
1b45d5c
add ut of fsspec s3 read write via snapshot
shicheng0829 Oct 26, 2022
7caf1dd
support byte range read
shicheng0829 Oct 26, 2022
a826755
remove removeprefix to support python 3.7 and 3.8
shicheng0829 Oct 26, 2022
32358df
tiny polish
shicheng0829 Oct 26, 2022
d1b69d5
fix pre-commit
shicheng0829 Oct 27, 2022
f282366
jump mkdir when storage plugins don't have this method
shicheng0829 Oct 28, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
aiobotocore
boto3
fsspec
google-cloud-storage
google-resumable-media
numpy
Expand Down
65 changes: 65 additions & 0 deletions tests/test_fsspec_storage_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import io
import logging
import os
import uuid

import pytest
import torch

import torchsnapshot
from torchsnapshot.io_types import ReadIO, WriteIO
from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin

logger: logging.Logger = logging.getLogger(__name__)

_TEST_BUCKET = "torchsnapshot-test"
_TENSOR_SZ = int(1_000_000 / 4)


@pytest.mark.s3_integration_test
@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="")
@pytest.mark.usefixtures("s3_health_check")
def test_fsspec_s3_read_write_via_snapshot() -> None:
path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}"
logger.info(path)

tensor = torch.rand((_TENSOR_SZ,))
app_state = {"state": torchsnapshot.StateDict(tensor=tensor)}
snapshot = torchsnapshot.Snapshot.take(path=path, app_state=app_state)

app_state["state"]["tensor"] = torch.rand((_TENSOR_SZ,))
assert not torch.allclose(tensor, app_state["state"]["tensor"])

snapshot.restore(app_state)
assert torch.allclose(tensor, app_state["state"]["tensor"])


@pytest.mark.s3_integration_test
@pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="")
@pytest.mark.usefixtures("s3_health_check")
@pytest.mark.asyncio
async def test_fsspec_s3_write_read_delete() -> None:
path = f"fsspec-s3://{_TEST_BUCKET}/{uuid.uuid4()}"
logger.info(path)
plugin = FSSpecStoragePlugin(root=path)

tensor = torch.rand((_TENSOR_SZ,))
buf = io.BytesIO()
torch.save(tensor, buf)
write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue()))

await plugin.write(write_io=write_io)

read_io = ReadIO(path="tensor")
await plugin.read(read_io=read_io)
loaded = torch.load(read_io.buf)
assert torch.allclose(tensor, loaded)

await plugin.delete(path="tensor")
await plugin.close()
5 changes: 5 additions & 0 deletions torchsnapshot/io_preparers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6 changes: 5 additions & 1 deletion torchsnapshot/storage_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def url_to_storage_plugin(
protocol, path = "fs", url_path

if storage_options is None:
storage_options = {}
storage_options = dict()

# Built-in storage plugins
if protocol == "fs":
Expand All @@ -49,6 +49,10 @@ def url_to_storage_plugin(
from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin

return GCSStoragePlugin(root=path, storage_options=storage_options)
elif protocol.startswith("fsspec-"):
from torchsnapshot.storage_plugins.fsspec import FSSpecStoragePlugin

return FSSpecStoragePlugin(root=url_path, storage_options=storage_options)

# Registered storage plugins
eps = entry_points(group="storage_plugins")
Expand Down
76 changes: 76 additions & 0 deletions torchsnapshot/storage_plugins/fsspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import io
import os
from typing import Any, Dict, Optional

import fsspec

from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO

__all__ = ["FSSpecStoragePlugin"]


class FSSpecStoragePlugin(StoragePlugin):
def __init__(self, root: str, storage_options: Optional[Dict[str, Any]]) -> None:
root_items = root.split("://")
if len(root_items) != 2:
raise ValueError("only protocol://path is supported by fsspec plugin")
protocol, self.root = root_items
if not protocol.startswith("fsspec-"):
raise ValueError(
f"Invalid protocol: {protocol}, Only fsspec-* protocols are supported"
)
self._protocol = protocol[len("fsspec-") :]
self._fs = None
self._session = None
self._lock = asyncio.Lock()
self._storage_options = storage_options

async def _init_session(self) -> None:
async with self._lock:
if self._session is None:
self._fs = fsspec.filesystem(
protocol=self._protocol, **self._storage_options
)
self._session = await self._fs.set_session(refresh=True)

async def write(self, write_io: WriteIO) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StoragePlugin.write is responsible for creating parent directories if there are any. For example, if write_io.path is foo/bar/baz, the method needs to ensure that both foo and foo/bar exist.

There's a caveat: not all fsspec plugins support mkdir. One example I can think of is s3fs. For such plugins, we don't need to create directories, and we can't call mkdir which will raise an exception.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right, the write method should be responsible for creating parent path and not all plugins support mkdir.

I think we can raise an exception if the plugin doesn't have mkdir method.

But it seems that the s3fs also provide the mkdir method.

The implement is here:

https://github.com/fsspec/s3fs/blob/main/s3fs/core.py#L803

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it seems that the s3fs also provide the mkdir method

Thanks for the pointer. Apparently I remembered it wrong.

I think we can raise an exception if the plugin doesn't have mkdir method.

If a plugin doesn't implement mkdir, most likely it's not needed and I think we can just suppress the exception.

await self._init_session()
path = os.path.join(self.root, write_io.path)
splits = path.split("/")
for i in range(len(splits)):
dir_path = "/".join(splits[:i])
if dir_path and not await self._fs._exists(dir_path):
try:
await self._fs._mkdir(dir_path)
except AttributeError:
break
await self._fs._pipe_file(path, bytes(write_io.buf))

async def read(self, read_io: ReadIO) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If byte_range is present in ReadIO, we need to honor it.

IIUC, not all fsspec plugins support seek (or efficient seek). We could provide an inefficient backoff here (we can look into that later).

@dataclass
class ReadIO:
    path: str
    buf: io.BytesIO = field(default_factory=io.BytesIO)
    byte_range: Optional[Tuple[int, int]] = None

await self._init_session()
path = os.path.join(self.root, read_io.path)
data = await self._fs._cat_file(path)
if read_io.byte_range is None:
read_io.buf = io.BytesIO(data)
else:
start, end = read_io.byte_range
read_io.buf = io.BytesIO(data[start:end])

async def delete(self, path: str) -> None:
await self._init_session()
path = os.path.join(self.root, path)
await self._fs._rm(path, recursive=True)

async def close(self) -> None:
async with self._lock:
if self._session is not None:
await self._session.close()
self._session = None
self._fs = None
2 changes: 0 additions & 2 deletions torchsnapshot/storage_plugins/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import urllib3.exceptions
from google.auth import default # @manual
from google.auth.transport.requests import AuthorizedSession # @manual

from google.resumable_media import common # @manual
from google.resumable_media.requests import ChunkedDownload, ResumableUpload # @manual

Expand All @@ -40,7 +39,6 @@
_DEFAULT_DEADLINE_SEC: int = 180
_DEFAULT_CHUNK_SIZE_BYTE: int = 100 * 1024 * 1024


T = TypeVar("T")


Expand Down