# Reconstructed from a whitespace-collapsed git patch:
#   commit  b7eb617d559ad90b8bcc0ccdd7944052f20b7a4f
#   author  xrsrke
#   date    Sun, 8 Oct 2023 08:01:20 +0700
#   subject [PATCH] add Bucket
#
# The patch created the files listed below; each one is a banner-delimited
# section of this module, in dependency order:
#   pipegoose/core/bucket/exception.py
#   pipegoose/core/bucket/bucket.py
#   pipegoose/core/bucket/manager.py
#   pipegoose/core/bucket/utils.py
#   tests/core/bucket/test_bucket.py
#   tests/core/bucket/test_bucket_manager.py   (empty in the patch)

from typing import TYPE_CHECKING

import pytest
import torch

if TYPE_CHECKING:
    # Only needed for the type hint on Bucket.__init__.
    from pipegoose.distributed.parallel_context import ParallelContext


# ---------------------------------------------------------------------------
# pipegoose/core/bucket/exception.py
# ---------------------------------------------------------------------------
class BucketFullError(Exception):
    """Exception raised when a bucket is full and a new item is added."""


class BucketClosedError(Exception):
    """Exception raised when a bucket is closed and a new item is added."""


# ---------------------------------------------------------------------------
# pipegoose/core/bucket/bucket.py
# ---------------------------------------------------------------------------
class Bucket:
    """Store tensors in a contiguous memory space.

    A bucket owns a flat, zero-initialised buffer of ``size`` elements of
    ``dtype``. Each tensor passed to :meth:`add_tensor` is copied into the
    next free slice of that buffer and then re-pointed (via ``tensor.data``)
    at the slice, so the tensor and the bucket share memory afterwards.

    Args:
        size: Number of elements the bucket can hold; must be positive.
        dtype: Element type of the buffer and of every tensor added to it.
        parallel_context: Stored on the instance as-is; this class does not
            validate or otherwise use it.
    """

    def __init__(self, size: int, dtype: torch.dtype, parallel_context: "ParallelContext"):
        assert size > 0, "Bucket size must be greater than 0."

        self.size = size
        self.dtype = dtype
        self.parallel_context = parallel_context

        self._buffer = torch.zeros(size, dtype=dtype)
        # Index of the first free element in ``_buffer``.
        self._offset = 0
        self._is_closed = False
        self._num_tensors = 0

    @property
    def is_full(self) -> bool:
        """Whether every element of the buffer is already in use."""
        return self.available_size == 0

    @property
    def available_size(self) -> int:
        """Number of buffer elements that are still free."""
        # numel() gives the same element count as the storage size without
        # going through the deprecated Tensor.storage() on every access.
        return self._buffer.numel() - self._offset

    @property
    def is_closed(self) -> bool:
        """Whether :meth:`close` has been called on this bucket."""
        return self._is_closed

    def add_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Copy ``tensor`` into the bucket and make it share the bucket's memory.

        Args:
            tensor: Tensor whose dtype equals the bucket's dtype.

        Returns:
            The same ``tensor`` object, whose ``data`` now views the slice of
            the bucket's buffer that holds its values.

        Raises:
            AssertionError: If ``tensor`` is not a ``torch.Tensor`` or its
                dtype differs from the bucket's dtype.
            BucketClosedError: If the bucket has been closed.
            BucketFullError: If the bucket is full, or does not have enough
                free elements for ``tensor``.
        """
        assert isinstance(tensor, torch.Tensor), "Input must be a tensor."
        assert tensor.dtype == self._buffer.dtype, "Input tensor must have the same dtype as the bucket."

        if self.is_closed:
            raise BucketClosedError("Bucket is closed.")

        if self.is_full:
            raise BucketFullError("Bucket is full.")

        numel = tensor.numel()
        if numel > self.available_size:
            raise BucketFullError("Bucket does not have enough space.")

        slot = self._buffer[self._offset : self._offset + numel]
        # Copy without autograd tracking: otherwise adding a tensor that
        # requires grad would give the buffer a grad_fn that keeps the
        # tensor's original values alive.
        with torch.no_grad():
            slot.copy_(tensor.flatten())
        # NOTE: set the tensor's storage to its corresponding location in the bucket
        tensor.data = slot.view_as(tensor)
        self._offset += numel
        self._num_tensors += 1

        return tensor

    def storage(self) -> "torch.Storage":
        """Return the storage object backing the bucket's buffer."""
        return self._buffer.storage()

    def close(self):
        """Close the bucket, and not allow any more tensors to be added to it."""
        assert self.is_closed is False, "Bucket is already closed."
        self._is_closed = True

    def free(self):
        """Delete all data in the bucket.

        Zeroes the buffer and resets the fill offset and the tensor count, so
        the whole buffer is available again. Tensors added earlier still view
        the buffer: they read zeros afterwards and will be overwritten by
        later additions. The open/closed state is left unchanged.

        Raises:
            AssertionError: If nothing has been written to the bucket.
        """
        assert self._offset != 0, "Bucket is empty, so no need to free memory."
        self._buffer.zero_()
        self._offset = 0
        self._num_tensors = 0

    def __len__(self) -> int:
        """Return the number of tensors in the bucket."""
        return self._num_tensors


# ---------------------------------------------------------------------------
# pipegoose/core/bucket/manager.py
# ---------------------------------------------------------------------------
class BucketManager:
    """Placeholder introduced by the patch; it has no behaviour yet."""


# ---------------------------------------------------------------------------
# pipegoose/core/bucket/utils.py
# ---------------------------------------------------------------------------
def get_memory_address_of_tensor_storage():
    """Placeholder introduced by the patch; not implemented, returns None."""
    pass


# ---------------------------------------------------------------------------
# tests/core/bucket/test_bucket.py
# ---------------------------------------------------------------------------
class FakeParallelContext:
    """Stand-in object for the parallel context, which Bucket only stores."""


def test_bucket():
    BUCKET_SIZE = 1024
    DTYPE = torch.float32

    tensor = torch.randn(2, 4, dtype=DTYPE)
    TENSOR_STORAGE = tensor.storage()

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    assert bucket.size == BUCKET_SIZE
    assert bucket.dtype == DTYPE
    assert bucket.available_size == BUCKET_SIZE
    assert len(bucket) == 0
    assert bucket.is_full is False
    assert bucket.is_closed is False

    new_tensor = bucket.add_tensor(tensor)

    assert isinstance(new_tensor, torch.Tensor)
    assert torch.equal(new_tensor, tensor)
    assert bucket.available_size == BUCKET_SIZE - new_tensor.numel()
    assert len(bucket) == 1
    # NOTE: the new tensor should be stored in the same storage as the bucket
    assert new_tensor.storage().data_ptr() == bucket.storage().data_ptr()
    # NOTE: the new tensor should have a different storage from the original tensor
    # since it's stored in the bucket
    assert new_tensor.storage().data_ptr() != TENSOR_STORAGE.data_ptr()

    bucket.close()

    assert bucket.is_closed is True


def test_free_a_bucket():
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(2, 4, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)
    bucket.add_tensor(tensor)

    bucket.free()

    assert bucket.available_size == BUCKET_SIZE
    assert len(bucket) == 0
    assert bucket.is_full is False


def test_add_tensor_that_larger_than_bucket_size():
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(2, BUCKET_SIZE, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    with pytest.raises(BucketFullError):
        bucket.add_tensor(tensor)


def test_add_tensor_that_larger_than_available_space():
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(BUCKET_SIZE - 1, dtype=DTYPE)
    redundant_tensor = torch.randn(BUCKET_SIZE, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    bucket.add_tensor(tensor)

    with pytest.raises(BucketFullError):
        bucket.add_tensor(redundant_tensor)


def test_add_a_tensor_to_a_closed_bucket():
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(BUCKET_SIZE - 1, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    bucket.close()

    with pytest.raises(BucketClosedError):
        bucket.add_tensor(tensor)


def test_add_a_tensor_with_different_dtype_to_a_bucket():
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(10, dtype=torch.float16)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    # The bucket stays open so that only the dtype check can raise here.
    with pytest.raises(AssertionError):
        bucket.add_tensor(tensor)


# ---------------------------------------------------------------------------
# tests/core/bucket/test_bucket_manager.py
# ---------------------------------------------------------------------------
# (empty in the patch)