-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import torch | ||
|
||
from pipegoose.core.bucket.exception import BucketClosedError, BucketFullError | ||
from pipegoose.distributed.parallel_context import ParallelContext | ||
|
||
|
||
class Bucket:
    """Store tensors in a contiguous memory space.

    The bucket owns a flat, zero-initialised 1-D buffer of ``size`` elements.
    Each tensor passed to :meth:`add_tensor` is copied into the next free
    slice of that buffer and its ``.data`` is rebound to that slice, so the
    tensor afterwards shares memory with the bucket.
    """

    def __init__(self, size: int, dtype: torch.dtype, parallel_context: "ParallelContext"):
        """Allocate the backing buffer.

        Args:
            size: Capacity of the bucket in number of elements; must be > 0.
            dtype: Element type of the buffer; added tensors must match it.
            parallel_context: Stored on the instance as-is. It is not
                validated and not otherwise used by this class.

        Raises:
            AssertionError: If ``size`` is not greater than 0.
        """
        assert size > 0, "Bucket size must be greater than 0."

        self.size = size
        self.dtype = dtype
        self.parallel_context = parallel_context

        self._buffer = torch.zeros(size, dtype=dtype)
        # Index of the first free element in the buffer.
        self._offset = 0
        self._is_closed = False
        self._num_tensors = 0

    @property
    def is_full(self) -> bool:
        """Whether every element of the buffer has been handed out."""
        return self._offset == self._buffer.numel()

    @property
    def available_size(self) -> int:
        """Number of elements that can still be added."""
        return self._buffer.numel() - self._offset

    def add_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Copy ``tensor`` into the bucket and make it alias the bucket memory.

        Args:
            tensor: Tensor with the same dtype as the bucket.

        Returns:
            The same ``tensor`` object, whose ``.data`` now points at its
            slice of the bucket buffer.

        Raises:
            AssertionError: If ``tensor`` is not a tensor or its dtype
                differs from the bucket's.
            BucketClosedError: If the bucket has been closed.
            BucketFullError: If the bucket is full or lacks enough free
                elements for ``tensor``.
        """
        assert isinstance(tensor, torch.Tensor), "Input must be a tensor."
        assert tensor.dtype == self._buffer.dtype, "Input tensor must have the same dtype as the bucket."

        if self.is_closed:
            raise BucketClosedError("Bucket is closed.")

        if self.is_full:
            raise BucketFullError("Bucket is full.")

        numel = tensor.numel()
        if numel > self.available_size:
            raise BucketFullError("Bucket does not have enough space.")

        slot = self._buffer[self._offset : self._offset + numel]
        # NOTE: copy from a detached view; an in-place copy from a tensor that
        # requires grad would otherwise pull the whole buffer into the
        # autograd graph.
        slot.copy_(tensor.detach().flatten())
        # NOTE: set the tensor's storage to its corresponding location in the bucket
        tensor.data = slot.view_as(tensor)
        self._offset += numel
        self._num_tensors += 1

        return tensor

    @property
    def is_closed(self) -> bool:
        """Whether :meth:`close` has been called."""
        return self._is_closed

    def storage(self) -> torch.Storage:
        """Return the storage object backing the bucket buffer."""
        return self._buffer.storage()

    def close(self):
        """Close the bucket, and not allow any more tensors to be added to it.

        Raises:
            AssertionError: If the bucket is already closed.
        """
        assert not self.is_closed, "Bucket is already closed."
        self._is_closed = True

    def free(self):
        """Delete all data in the bucket.

        Zeroes the buffer and resets the write offset and tensor count so the
        full capacity is available again. The buffer itself stays allocated,
        so tensors previously added still alias it and will read zeros. The
        closed/open state is left unchanged.

        Raises:
            AssertionError: If the bucket is empty.
        """
        assert self._offset != 0, "Bucket is empty, so no need to free memory."

        self._buffer.zero_()
        self._offset = 0
        self._num_tensors = 0

    def __len__(self) -> int:
        """Return the number of tensors in the bucket."""
        return self._num_tensors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
class BucketFullError(Exception):
    """Signal that an item was added to a bucket that has no room for it."""
|
||
|
||
class BucketClosedError(Exception):
    """Signal that an item was added to a bucket that is already closed."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class BucketManager:
    """Empty placeholder class; it defines no attributes or methods yet."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def get_memory_address_of_tensor_storage():
    """Unimplemented stub: takes no arguments and returns ``None``."""
    return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import pytest | ||
import torch | ||
|
||
from pipegoose.core.bucket.bucket import Bucket | ||
from pipegoose.core.bucket.exception import BucketClosedError, BucketFullError | ||
|
||
|
||
class FakeParallelContext:
    """Attribute-less stand-in passed where a parallel context is expected."""
|
||
|
||
def test_bucket():
    """Check the bucket's initial state, adding one tensor, and closing."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32

    tensor = torch.randn(2, 4, dtype=DTYPE)
    # Snapshot the values and the original storage before the bucket rebinds
    # ``tensor.data``; add_tensor returns the same object, so comparing the
    # result against ``tensor`` itself would always succeed.
    EXPECTED_VALUES = tensor.clone()
    TENSOR_STORAGE = tensor.storage()

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    assert bucket.size == BUCKET_SIZE
    assert bucket.dtype == DTYPE
    assert bucket.available_size == BUCKET_SIZE
    assert len(bucket) == 0
    assert bucket.is_full is False
    assert bucket.is_closed is False

    new_tensor = bucket.add_tensor(tensor)

    assert isinstance(new_tensor, torch.Tensor)
    assert torch.equal(new_tensor, EXPECTED_VALUES)
    assert bucket.available_size == BUCKET_SIZE - new_tensor.numel()
    assert len(bucket) == 1
    # NOTE: the new tensor should be stored in the same storage as the bucket
    assert new_tensor.storage().data_ptr() == bucket.storage().data_ptr()
    # NOTE: the new tensor should have a different storage from the original tensor
    # since it's stored in the bucket
    assert new_tensor.storage().data_ptr() != TENSOR_STORAGE.data_ptr()

    # TODO: once the bucket can be emptied, also check that afterwards
    # available_size == BUCKET_SIZE and len(bucket) == 0.

    bucket.close()

    assert bucket.is_closed is True
|
||
|
||
def test_add_tensor_that_larger_than_bucket_size():
    """A tensor with more elements than the bucket's capacity is rejected."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(2, BUCKET_SIZE, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    # Expect the specific error so an unrelated failure cannot pass the test.
    with pytest.raises(BucketFullError):
        bucket.add_tensor(tensor)
|
||
|
||
def test_add_tensor_that_larger_than_available_space():
    """A tensor that no longer fits in the remaining space is rejected."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    # Pin the dtype explicitly rather than relying on the global default.
    tensor = torch.randn(BUCKET_SIZE - 1, dtype=DTYPE)
    redundant_tensor = torch.randn(BUCKET_SIZE, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    # Leaves exactly one free element in the bucket.
    bucket.add_tensor(tensor)

    with pytest.raises(BucketFullError):
        bucket.add_tensor(redundant_tensor)
|
||
|
||
def test_add_a_tensor_to_a_closed_bucket():
    """Adding to a bucket after ``close()`` raises ``BucketClosedError``."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    # Pin the dtype explicitly so the dtype check cannot fire before the
    # closed-bucket check if the global default dtype is changed.
    tensor = torch.randn(BUCKET_SIZE - 1, dtype=DTYPE)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    bucket.close()

    with pytest.raises(BucketClosedError):
        bucket.add_tensor(tensor)
|
||
|
||
def test_add_a_tensor_with_different_dtype_to_a_bucket():
    """A tensor whose dtype differs from the bucket's is rejected."""
    BUCKET_SIZE = 1024
    DTYPE = torch.float32
    tensor = torch.randn(10, dtype=torch.float16)

    parallel_context = FakeParallelContext()
    bucket = Bucket(BUCKET_SIZE, DTYPE, parallel_context)

    # The bucket is deliberately left open: closing it first would let a
    # BucketClosedError satisfy the expectation even without a dtype check.
    with pytest.raises(AssertionError):
        bucket.add_tensor(tensor)
Empty file.