Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix writing via XRootD (called from uproot) #76

Merged
merged 11 commits into from
Feb 18, 2025
101 changes: 93 additions & 8 deletions src/fsspec_xrootd/xrootd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

Check warning on line 1 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Too many lines in module (1030/1000)

Check warning on line 1 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Missing module docstring

import asyncio
import io
Expand All @@ -10,23 +10,23 @@
from collections import defaultdict
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Coroutine, Iterable, TypeVar
from typing import Any, Callable, Coroutine, Iterable, TypeVar, cast

from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks, sync, sync_wrapper
from fsspec.exceptions import FSTimeoutError
from fsspec.spec import AbstractBufferedFile
from XRootD import client

Check failure on line 18 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Unable to import 'XRootD'
from XRootD.client.flags import (

Check failure on line 19 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Unable to import 'XRootD.client.flags'
DirListFlags,
MkDirFlags,
OpenFlags,
QueryCode,
StatInfoFlags,
)
from XRootD.client.responses import HostList, XRootDStatus

Check failure on line 26 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Unable to import 'XRootD.client.responses'


class ErrorCodes(IntEnum):

Check warning on line 29 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Missing class docstring
INVALID_PATH = 400


Expand All @@ -51,13 +51,13 @@
asyncio.get_running_loop().create_future()
)

def callback(status: XRootDStatus, content: T, servers: HostList) -> None:

Check warning on line 54 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Unused argument 'servers'
if future.cancelled():
return
loop = future.get_loop()
try:
loop.call_soon_threadsafe(future.set_result, (status, content))
except Exception as exc:

Check warning on line 60 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Catching too general exception Exception
loop.call_soon_threadsafe(future.set_exception, exc)

async def wrapped(*args: Any, **kwargs: Any) -> tuple[XRootDStatus, T]:
Expand Down Expand Up @@ -142,7 +142,7 @@
handle: client.File


class ReadonlyFileHandleCache:

Check warning on line 145 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Missing class docstring
def __init__(self, loop: Any, max_items: int | None, ttl: int):
self.loop = loop
self._max_items = max_items
Expand Down Expand Up @@ -170,7 +170,7 @@
item.handle.close(callback=lambda *args: None)
cache.clear()

def close_all(self) -> None:

Check warning on line 173 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Missing function or method docstring
self._close_all(self.loop, self._cache)

async def _close(self, url: str, timeout: int) -> None:
Expand All @@ -183,7 +183,7 @@
close = sync_wrapper(_close)

async def _start_pruner(self) -> None:
self._prune_task = asyncio.create_task(self._pruner())

Check warning on line 186 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Attribute '_prune_task' defined outside __init__

async def _pruner(self) -> None:
while True:
Expand Down Expand Up @@ -220,7 +220,7 @@
return handle


class XRootDFileSystem(AsyncFileSystem): # type: ignore[misc]

Check warning on line 223 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Missing class docstring

Check warning on line 223 in src/fsspec_xrootd/xrootd.py

View workflow job for this annotation

GitHub Actions / Format

Method '_cp_file' is abstract in class 'AsyncFileSystem' but is not overridden in child class 'XRootDFileSystem'
protocol = "root"
root_marker = "/"
default_timeout = 60
Expand Down Expand Up @@ -373,9 +373,9 @@

async def _touch(self, path: str, truncate: bool = False, **kwargs: Any) -> None:
if truncate or not await self._exists(path):
status, _ = await _async_wrap(self._myclient.truncate)(
path, size=0, timeout=self.timeout
)
f = client.File()
status, _ = await _async_wrap(f.open)(path, OpenFlags.DELETE)
await _async_wrap(f.close)()
if not status.ok:
raise OSError(f"File not touched properly: {status.message}")
else:
Expand Down Expand Up @@ -756,9 +756,9 @@
from fsspec.core import caches

self.timeout = fs.timeout
# by this point, mode will have a "b" in it
# update "+" mode removed for now since seek() is read only
if "x" in mode:
if mode == "r+b":
self.mode = OpenFlags.UPDATE
elif "x" in mode:
self.mode = OpenFlags.NEW
elif "a" in mode:
self.mode = OpenFlags.UPDATE
Expand Down Expand Up @@ -834,7 +834,7 @@

self.kwargs = kwargs

if mode not in {"ab", "rb", "wb"}:
if mode not in {"ab", "rb", "wb", "r+b"}:
raise NotImplementedError("File mode not supported")
if mode == "rb":
if size is not None:
Expand All @@ -849,6 +849,13 @@
self.forced = False
self.location = None
self.offset = 0
self.size = self._myFile.stat()[1].size
if mode == "r+b":
self.cache = caches[cache_type](
self.blocksize, self._fetch_range, self.size, **cache_options
)
if "a" in mode:
self.loc = self.size

def _locate_sources(self, logical_filename: str) -> list[str]:
"""Find hosts that have the desired file.
Expand Down Expand Up @@ -943,3 +950,81 @@
if not status.ok:
raise OSError(f"File did not close properly: {status.message}")
self.closed = True

def seek(self, loc: int, whence: int = 0) -> int:
"""Set current file location

Parameters
----------
loc: int
byte location
whence: {0, 1, 2}
from start of file, current location or end of file, resp.
"""
loc = int(loc)
if whence == 0:
nloc = loc
elif whence == 1:
nloc = self.loc + loc
elif whence == 2:
nloc = self.size + loc
else:
raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
if nloc < 0:
raise ValueError("Seek before start of file")
self.loc = nloc
return self.loc

def writable(self) -> bool:
"""Whether opened for writing"""
return self.mode in {"wb", "ab", "xb", "r+b"} and not self.closed

def write(self, data: bytes) -> int:
"""
Write data to buffer.

Buffer only sent on flush() or if buffer is greater than
or equal to blocksize.

Parameters
----------
data: bytes
Set of bytes to be written.
"""
if not self.writable():
raise ValueError("File not in write mode")
if self.closed:
raise ValueError("I/O operation on closed file.")
if self.forced:
raise ValueError("This file has been force-flushed, can only close")
status, _n = self._myFile.write(data, self.loc, len(data), timeout=self.timeout)
self.loc += len(data)
self.size = max(self.size, self.loc)
if not status.ok:
raise OSError(f"File did not write properly: {status.message}")
return len(data)

def read(self, length: int = -1) -> bytes:
"""
Return data from cache, or fetch pieces as necessary

Parameters
----------
length: int (-1)
Number of bytes to read; if <0, all remaining bytes.
"""
length = int(length)
if self.mode not in {"rb", "r+b"}:
raise ValueError("File not in read mode")
if length < 0:
length = self.size - self.loc
if self.closed:
raise ValueError("I/O operation on closed file.")
if length == 0:
# don't even bother calling fetch
return b""
# for mypy
out = cast(bytes, self.cache._fetch(self.loc, self.loc + length))

self.loc += len(out)
return out
19 changes: 19 additions & 0 deletions tests/test_basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@ def test_write_fsspec(localserver, clear_server):
assert f.read() == TESTDATA1


def test_write_rpb_fsspec(localserver, clear_server):
"""Test writing with r+b as in uproot"""
remoteurl, localpath = localserver
fs, _ = fsspec.core.url_to_fs(remoteurl)
filename = "test.bin"
fs.touch(localpath + "/" + filename)
with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
f.write(b"Hello, this is a test file for r+b mode.")
f.flush()
with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
assert f.read() == b"Hello, this is a test file for r+b mode."
with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
f.seek(len(b"Hello, this is a "))
f.write(b"REPLACED ")
f.flush()
with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
assert f.read() == b"Hello, this is a REPLACED for r+b mode."


@pytest.mark.parametrize("start, end", [(None, None), (None, 10), (1, None), (1, 10)])
def test_read_bytes_fsspec(localserver, clear_server, start, end):
remoteurl, localpath = localserver
Expand Down
Loading