Skip to content

Commit

Permalink
sdk/python: ObjectFile memory optimizations
Browse files Browse the repository at this point in the history
Reduce extra copy in memory by accumulating memoryviews and inducing one copy with the final cast.

Signed-off-by: Ryan Koo <[email protected]>
  • Loading branch information
rkoo19 committed Nov 5, 2024
1 parent b44ec63 commit e261aaa
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
33 changes: 16 additions & 17 deletions python/aistore/sdk/obj/obj_file/object_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
#

import requests

from io import BufferedIOBase
from overrides import override
from typing import Optional
from overrides import override
from aistore.sdk.obj.content_iterator import ContentIterator
from aistore.sdk.obj.obj_file.utils import handle_chunked_encoding_error
from aistore.sdk.utils import get_logger
Expand Down Expand Up @@ -36,15 +35,15 @@ def __init__(self, content_iterator: ContentIterator, max_resume: int):
self._content_iterator = content_iterator
self._iterable = self._content_iterator.iter_from_position(0)
self._max_resume = max_resume # Maximum number of resume attempts allowed
self._remainder = bytearray() # Holds leftover data from the last chunk
self._remainder = None # Remainder from the last chunk as a memoryview
self._resume_position = 0 # Tracks the current position in the stream
self._resume_total = 0 # Tracks the number of resume attempts
self._closed = False

@override
def __enter__(self):
self._iterable = self._content_iterator.iter_from_position(0)
self._remainder = bytearray()
self._remainder = None
self._resume_position = 0
self._resume_total = 0
self._closed = False
Expand Down Expand Up @@ -79,25 +78,24 @@ def read(self, size: Optional[int] = -1) -> bytes:

# If size is -1, set it to infinity to read until the end of the stream
size = float("inf") if size == -1 else size

result = bytearray()
result = []

try:
# Consume any remaining data from a previous chunk before fetching new data
if self._remainder:
if size < len(self._remainder):
result += self._remainder[:size]
del self._remainder[:size]
result.append(self._remainder[:size])
self._remainder = self._remainder[size:]
size = 0
else:
result += self._remainder
result.append(self._remainder)
size -= len(self._remainder)
self._remainder.clear()
self._remainder = None

# Fetch new chunks from the stream as needed
while size:
try:
chunk = next(self._iterable)
chunk = memoryview(next(self._iterable))
except StopIteration:
# End of stream, exit loop
break
Expand All @@ -112,18 +110,18 @@ def read(self, size: Optional[int] = -1) -> bytes:
)
continue

# Track the position of the stream by adding the length
# of each fetched chunk
# Track the position of the stream by adding the length of each fetched chunk
self._resume_position += len(chunk)

# Add the part of the chunk that fits within the requested size and
# store any leftover data for the next read
if size < len(chunk):
result += chunk[:size]
self._remainder += chunk[size:]
result.append(chunk[:size])
self._remainder = chunk[size:]
size = 0
else:
result += chunk
result.append(chunk)
self._remainder = None
size -= len(chunk)

except Exception as err:
Expand All @@ -132,7 +130,8 @@ def read(self, size: Optional[int] = -1) -> bytes:
self.close()
raise err

return bytes(result)
# Assemble the final bytes object with a single data copy
return b"".join(result)

@override
def close(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions python/tests/unit/sdk/obj/test_object_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_init(self):
self.assertEqual(self.object_file._max_resume, 3)
self.assertEqual(self.object_file._resume_position, 0)
self.assertEqual(self.object_file._resume_total, 0)
self.assertEqual(self.object_file._remainder, bytearray())
self.assertIsNone(self.object_file._remainder)
self.assertFalse(self.object_file._closed)

# Verify that iter_from_position(0) is called
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_context_manager(self):
# State should be reset inside context
self.assertFalse(obj_file._closed)
self.assertEqual(self.object_file._resume_position, 0)
self.assertEqual(self.object_file._remainder, bytearray())
self.assertIsNone(self.object_file._remainder)

# After context, file should be closed
self.assertTrue(self.object_file._closed)
Expand Down

0 comments on commit e261aaa

Please sign in to comment.