Skip to content

Commit

Permalink
add Bucket
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Oct 8, 2023
1 parent c50644e commit b7eb617
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 0 deletions.
71 changes: 71 additions & 0 deletions pipegoose/core/bucket/bucket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch

from pipegoose.core.bucket.exception import BucketClosedError, BucketFullError
from pipegoose.distributed.parallel_context import ParallelContext


class Bucket:
    """Store tensors in a contiguous memory space.

    A bucket owns one pre-allocated 1-D buffer of ``size`` elements. Each
    tensor passed to :meth:`add_tensor` is copied into the next free slice of
    that buffer, and the tensor's ``.data`` is rebound to that slice so the
    tensor and the bucket share memory afterwards.
    """

    def __init__(self, size: int, dtype: torch.dtype, parallel_context: ParallelContext):
        """Allocate a zero-filled buffer of ``size`` elements of ``dtype``.

        Args:
            size: Capacity of the bucket, in number of elements (not bytes).
            dtype: Element type of the buffer; added tensors must match it.
            parallel_context: Stored on the instance; not otherwise used here.

        Raises:
            AssertionError: If ``size`` is not greater than 0.
        """
        assert size > 0, "Bucket size must be greater than 0."
        # TODO: decide whether a missing parallel context should be rejected:
        # assert parallel_context is not None, "Parallel context must not be None."

        self.size = size
        self.dtype = dtype
        self.parallel_context = parallel_context

        self._buffer = torch.zeros(size, dtype=dtype)
        # Index of the first free element in ``_buffer``.
        self._offset = 0
        self._is_closed = False
        self._num_tensors = 0

    @property
    def is_full(self) -> bool:
        """Whether every element of the buffer has been handed out."""
        # ``numel()`` equals the storage size for this 1-D buffer and avoids
        # the deprecated ``Tensor.storage()`` accessor.
        return self._offset >= self._buffer.numel()

    @property
    def available_size(self) -> int:
        """Number of elements that can still be added to the bucket."""
        return self._buffer.numel() - self._offset

    def add_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Copy ``tensor`` into the bucket and make it share the bucket's memory.

        Args:
            tensor: Tensor with the same dtype as the bucket.

        Returns:
            The same ``tensor`` object, whose ``.data`` now views the bucket's
            buffer.

        Raises:
            AssertionError: If ``tensor`` is not a tensor or its dtype differs
                from the bucket's.
            BucketClosedError: If the bucket has been closed.
            BucketFullError: If the bucket is full or lacks room for ``tensor``.
        """
        assert isinstance(tensor, torch.Tensor), "Input must be a tensor."
        assert tensor.dtype == self._buffer.dtype, "Input tensor must have the same dtype as the bucket."

        if self.is_closed:
            raise BucketClosedError("Bucket is closed.")

        if self.is_full:
            raise BucketFullError("Bucket is full.")

        numel = tensor.numel()
        if numel > self.available_size:
            raise BucketFullError("Bucket does not have enough space.")

        slot = self._buffer[self._offset : self._offset + numel]
        # Detach the source so that copying a tensor that requires grad does
        # not record the copy in autograd and tie the buffer to a graph.
        slot.copy_(tensor.detach().flatten())
        # NOTE: set the tensor's storage to its corresponding location in the bucket
        tensor.data = slot.view_as(tensor)
        self._offset += numel
        self._num_tensors += 1

        return tensor

    @property
    def is_closed(self) -> bool:
        """Whether :meth:`close` has been called on this bucket."""
        return self._is_closed

    def storage(self) -> torch.Storage:
        """Return the storage object backing the bucket's buffer."""
        return self._buffer.storage()

    def close(self):
        """Close the bucket, and not allow any more tensors to be added to it.

        Raises:
            AssertionError: If the bucket is already closed.
        """
        assert self.is_closed is False, "Bucket is already closed."
        self._is_closed = True

    def free(self):
        """Delete all data in the bucket.

        Zeroes the buffer and resets the write offset and tensor count so the
        whole capacity is available again. The closed flag is left untouched.
        Tensors added earlier still view the buffer, so they read as zeros
        afterwards.

        Raises:
            AssertionError: If the bucket is empty.
        """
        assert self._offset != 0, "Bucket is empty, so no need to free memory."
        self._buffer.zero_()
        self._offset = 0
        self._num_tensors = 0

    def __len__(self) -> int:
        """Return the number of tensors in the bucket."""
        return self._num_tensors
6 changes: 6 additions & 0 deletions pipegoose/core/bucket/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class BucketFullError(Exception):
    """Signal that an item was added to a bucket that is already full."""


class BucketClosedError(Exception):
    """Signal that an item was added to a bucket that has been closed."""
2 changes: 2 additions & 0 deletions pipegoose/core/bucket/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class BucketManager:
    """Placeholder class; it defines no attributes or behavior yet."""
2 changes: 2 additions & 0 deletions pipegoose/core/bucket/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def get_memory_address_of_tensor_storage():
    """Placeholder, not implemented yet: takes no arguments and returns ``None``."""
    return None
103 changes: 103 additions & 0 deletions tests/core/bucket/test_bucket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest
import torch

from pipegoose.core.bucket.bucket import Bucket
from pipegoose.core.bucket.exception import BucketClosedError, BucketFullError


class FakeParallelContext:
    """Empty stand-in passed to ``Bucket`` where a parallel context is expected."""


def test_bucket():
    """Check the bucket's initial state, adding one tensor, and closing."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32

    tensor = torch.randn(2, 4, dtype=DTYPE)
    # Hold the original storage so its address can be compared after the add.
    TENSOR_STORAGE = tensor.storage()
    # add_tensor rebinds ``tensor.data`` and returns the same object, so the
    # expected values must be snapshotted beforehand; comparing the result
    # with ``tensor`` itself would always pass.
    expected = tensor.clone()

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    assert bucket.size == BUCKET_SIZE
    assert bucket.dtype == DTYPE
    assert bucket.available_size == BUCKET_SIZE
    assert len(bucket) == 0
    assert bucket.is_full is False
    assert bucket.is_closed is False

    new_tensor = bucket.add_tensor(tensor)

    assert isinstance(new_tensor, torch.Tensor)
    assert torch.equal(new_tensor, expected)
    assert bucket.available_size == BUCKET_SIZE - new_tensor.numel()
    assert len(bucket) == 1
    # NOTE: the new tensor should be stored in the same storage as the bucket
    assert new_tensor.storage().data_ptr() == bucket.storage().data_ptr()
    # NOTE: the new tensor should have a different storage from the original tensor
    # since it's stored in the bucket
    assert new_tensor.storage().data_ptr() != TENSOR_STORAGE.data_ptr()

    # TODO: once the bucket can be emptied, also check that clearing it
    # restores available_size to BUCKET_SIZE and len(bucket) to 0.

    bucket.close()

    assert bucket.is_closed is True


def test_add_tensor_that_larger_than_bucket_size():
    """A tensor with more elements than the bucket's capacity is rejected."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(2, BUCKET_SIZE, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    # Expect the specific error so an unrelated failure cannot satisfy the test.
    with pytest.raises(BucketFullError):
        bucket.add_tensor(tensor)

    # The rejected tensor must not have consumed any space.
    assert len(bucket) == 0
    assert bucket.available_size == BUCKET_SIZE


def test_add_tensor_that_larger_than_available_space():
    """A tensor that no longer fits in the remaining space is rejected."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    # Pass the dtype explicitly so the test does not depend on torch's global
    # default dtype matching the bucket's.
    tensor = torch.randn(BUCKET_SIZE - 1, dtype=DTYPE)
    redundant_tensor = torch.randn(BUCKET_SIZE, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    bucket.add_tensor(tensor)
    # One element is left, so the bucket is not full yet.
    assert bucket.available_size == 1

    with pytest.raises(BucketFullError):
        bucket.add_tensor(redundant_tensor)


def test_add_a_tensor_to_a_closed_bucket():
    """Adding to a closed bucket raises ``BucketClosedError``."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    # Pass the dtype explicitly: a dtype mismatch is checked before the closed
    # state, so relying on the global default dtype could mask this check.
    tensor = torch.randn(BUCKET_SIZE - 1, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    bucket.close()

    with pytest.raises(BucketClosedError):
        bucket.add_tensor(tensor)


def test_add_a_tensor_with_different_dtype_to_a_bucket():
    """A tensor whose dtype differs from the bucket's is rejected."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(10, dtype=torch.float16)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    # The bucket is deliberately left open: closing it first would let a
    # BucketClosedError satisfy the check below even without dtype validation.
    with pytest.raises(Exception):
        bucket.add_tensor(tensor)

    # The rejected tensor must not have been stored.
    assert len(bucket) == 0
Empty file.

0 comments on commit b7eb617

Please sign in to comment.