Skip to content

Commit

Permalink
fix: always pass bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
rilshok committed Nov 29, 2024
1 parent 3d118c5 commit 4fdb26d
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/iokit/extensions/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ def __init__(self, data: dict[str, str], **kwargs: Any):
super().__init__(data=data_bytes, **kwargs)

def load(self) -> dict[str, str | None]:
stream = StringIO(self.data.getvalue().decode())
return dict(dotenv.dotenv_values(stream=stream))
with StringIO(self.data.getvalue().decode()) as stream:
return dict(dotenv.dotenv_values(stream=stream))
10 changes: 5 additions & 5 deletions src/iokit/extensions/gz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

class Gzip(State, suffix="gz"):
def __init__(self, state: State, *, compression: int = 1, **kwargs: Any):
data = BytesIO()
gzip_file = gzip.GzipFile(fileobj=data, mode="wb", compresslevel=compression, mtime=0)
with gzip_file as gzip_buffer:
gzip_buffer.write(state.data.getvalue())
super().__init__(data=data, name=state.name, **kwargs)
with BytesIO() as buffer:
gzip_file = gzip.GzipFile(fileobj=buffer, mode="wb", compresslevel=compression, mtime=0)
with gzip_file as gzip_buffer:
gzip_buffer.write(state.data.getvalue())
super().__init__(data=buffer.getvalue(), name=state.name, **kwargs)

def load(self) -> State:
gzip_file = gzip.GzipFile(fileobj=self.data, mode="rb")
Expand Down
12 changes: 6 additions & 6 deletions src/iokit/extensions/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def __init__(
allow_nan: bool = False,
**kwargs: Any,
):
buffer = BytesIO()
dumps = json_dumps(compact=compact, ensure_ascii=ensure_ascii, allow_nan=allow_nan)
with Writer(buffer, compact=compact, sort_keys=False, dumps=dumps) as writer:
for item in sequence:
writer.write(item)
super().__init__(data=buffer, **kwargs)
with BytesIO() as buffer:
dumps = json_dumps(compact=compact, ensure_ascii=ensure_ascii, allow_nan=allow_nan)
with Writer(buffer, compact=compact, sort_keys=False, dumps=dumps) as writer:
for item in sequence:
writer.write(item)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> list[Any]:
with Reader(self.data) as reader:
Expand Down
16 changes: 8 additions & 8 deletions src/iokit/extensions/tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

class Tar(State, suffix="tar"):
def __init__(self, states: Iterable[State], **kwargs: Any):
buffer = BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar_buffer:
for state in states:
file_data = tarfile.TarInfo(name=str(state.name))
file_data.size = state.size
file_data.mtime = int(state.time.timestamp())
tar_buffer.addfile(fileobj=state.data, tarinfo=file_data)
with BytesIO() as buffer:
with tarfile.open(fileobj=buffer, mode="w") as tar_buffer:
for state in states:
file_data = tarfile.TarInfo(name=str(state.name))
file_data.size = state.size
file_data.mtime = int(state.time.timestamp())
tar_buffer.addfile(fileobj=state.data, tarinfo=file_data)

super().__init__(data=buffer, **kwargs)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> list[State]:
states: list[State] = []
Expand Down
10 changes: 5 additions & 5 deletions src/iokit/extensions/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

class Zip(State, suffix="zip"):
def __init__(self, states: Iterable[State], **kwargs: Any):
buffer = BytesIO()
with zipfile.ZipFile(buffer, mode="w") as zip_buffer:
for state in states:
zip_buffer.writestr(str(state.name), data=state.data.getvalue())
with BytesIO() as buffer:
with zipfile.ZipFile(buffer, mode="w") as zip_buffer:
for state in states:
zip_buffer.writestr(str(state.name), data=state.data.getvalue())

super().__init__(data=buffer, **kwargs)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> list[State]:
states: list[State] = []
Expand Down
2 changes: 1 addition & 1 deletion tests/test_state_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ def test_state_inheritance_json() -> None:
assert MyJson._suffixes == ("myjson",)
myjson = MyJson({"a": 1}, name="test")
assert myjson.name == "test.myjson"
loaded = State(data=myjson.data, name="test.myjson")
loaded = State(data=myjson.data.getvalue(), name="test.myjson")
assert loaded.load() == myjson.load() == {"a": 1}

0 comments on commit 4fdb26d

Please sign in to comment.