[dattri.func] Fix ChunkedCudaProjector bug (#113)
* Fix ChunkedCudaProjector bug; add a ChunkedCudaProjector test case for tensor input
* Fix redundancy
* Fix darglint format
* Fix chunked CUDA bug
* Modify test file

---------

Co-authored-by: tingwl0122 <[email protected]>
tingwl0122 authored Jul 23, 2024
1 parent 46be065 commit c88dd11
Showing 3 changed files with 93 additions and 3 deletions.
40 changes: 39 additions & 1 deletion dattri/func/projection.py
@@ -457,7 +457,7 @@ def free_memory(self) -> None:
        del self.ch_input
        self.input_allocated = False

-    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
+    def dict_project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
"""Performs the random projection on the feature matrix.
Args:
Expand Down Expand Up @@ -536,6 +536,42 @@ def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:

        return ch_output[:actual_bs]

    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """Performs the random projection on the feature matrix.

        Args:
            features (Union[dict, Tensor]): A batch of features or a dictionary
                of batch of features.
            ensemble_id (int): A unique ID for this ensemble.

        Returns:
            Tensor: The projected features.
        """
        # allocate zero tensor for output
        ch_output = torch.zeros(
            size=(self.feature_batch_size, self.proj_dim),
            device=self.device,
            dtype=self.dtype,
        )
        # force the input to be Tensor for now
        # TODO: support dict input
        if isinstance(features, dict):
            features = vectorize(features, device=self.device)

        pointer = 0
        for chunk_idx, chunk_dim in enumerate(self.dim_per_chunk):
            ch_output.add_(
                self.projector_per_chunk[chunk_idx].project(
                    features[:, pointer : pointer + chunk_dim].contiguous(),
                    ensemble_id=ensemble_id,
                ),
            )

            pointer += chunk_dim

        return ch_output


class ArnoldiProjector(AbstractProjector):
"""Implementation of Arnoldi projector."""
@@ -832,6 +868,8 @@ def make_random_projector(
    proj_type = ProjectionType.rademacher

    if using_cuda_projector:
        # TODO: make this support dict input
        # currently, only tensor input will be considered
        max_chunk_size, param_chunk_sizes = get_parameter_chunk_sizes(
            param_shape_list,
            proj_max_batch_size,
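For intuition, the accumulation loop in the new project method above is the block decomposition of a matrix product: projecting each feature slice with its own per-chunk matrix and summing the results equals projecting the full feature row against those matrices stacked along the feature dimension. A minimal sketch of that identity with plain torch matrices (the random matrices below are illustrative stand-ins, not dattri's per-chunk CudaProjector kernels):

import torch

batch_size, proj_dim = 4, 8
dim_per_chunk = [5, 3]  # feature dims handled by each chunk
features = torch.randn(batch_size, sum(dim_per_chunk))

# one stand-in projection matrix per chunk (mimicking projector_per_chunk)
proj_mats = [torch.randn(d, proj_dim) for d in dim_per_chunk]

# accumulate chunk by chunk, mirroring the pointer loop in project()
output = torch.zeros(batch_size, proj_dim)
pointer = 0
for mat, chunk_dim in zip(proj_mats, dim_per_chunk):
    output += features[:, pointer : pointer + chunk_dim] @ mat
    pointer += chunk_dim

# the same projection done in one shot with the chunk matrices stacked vertically
reference = features @ torch.cat(proj_mats, dim=0)
assert torch.allclose(output, reference, atol=1e-5)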
34 changes: 33 additions & 1 deletion dattri/func/utils.py
@@ -88,7 +88,7 @@ def _vectorize(
    return arr


-def get_parameter_chunk_sizes(
+def _get_parameter_chunk_sizes(
    param_shape_list: List,
    batch_size: int,
) -> tuple[int, int]:
@@ -128,6 +128,38 @@ def get_parameter_chunk_sizes(
    return max_chunk_size, params_per_chunk


def get_parameter_chunk_sizes(
    param_shape_list: List,
    batch_size: int,
) -> tuple[int, int]:
    """Compute chunk size information for the feature to be projected.

    Get a tuple containing max chunk size and a list of the number of
    parameters in each chunk.

    Args:
        param_shape_list (List): A list of numbers indicating the total number of
            features to be projected. A typical example is a list of parameter
            size of each module in a torch.nn.Module model.
        batch_size (int): The batch size. Each term (or module) in feature
            will have the same batch size.

    Returns:
        tuple[int, int]: Maximum number of parameters per chunk and a list of
            the number of parameters in each chunk.
    """
    # get the number of total params
    param_num = param_shape_list[0]

    max_chunk_size = np.iinfo(np.uint32).max // batch_size

    num_chunk = param_num // max_chunk_size
    remaining = param_num % max_chunk_size
    params_per_chunk = [max_chunk_size] * num_chunk + [remaining]

    return max_chunk_size, params_per_chunk


def flatten_params(tensors: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Flatten a dictionary of tensors into a single tensor.
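For reference, a quick sketch of what the new get_parameter_chunk_sizes returns for a large flattened feature. The 400M feature count and batch size of 32 below are made-up illustration values, not dattri defaults:

from dattri.func.utils import get_parameter_chunk_sizes

# hypothetical inputs: 400M flattened features, projector batch size 32
max_chunk_size, params_per_chunk = get_parameter_chunk_sizes([400_000_000], 32)

# max_chunk_size == np.iinfo(np.uint32).max // 32 == 134_217_727
# params_per_chunk == [134_217_727, 134_217_727, 131_564_546]
# more than one chunk, so make_random_projector ends up with a
# ChunkedCudaProjector (cf. the new test below) rather than a single CudaProjector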
22 changes: 21 additions & 1 deletion test/dattri/func/test_proj.py
@@ -480,7 +480,7 @@ def test_tensor_input_cuda(self):
        test_batch_size = 64

        test_tensor = torch.rand(test_batch_size, 1000)
-        # supposed to be BasicProjector
+        # supposed to be CudaProjector
        project = random_project(
            test_tensor,
            test_batch_size,
@@ -494,6 +494,26 @@ def test_tensor_input_cuda(self):
        result = project(test_tensor)
        assert result.shape == (test_batch_size, self.proj_dim)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_tensor_input_chunked_cuda(self):
        """Test the usage of tensor input."""
        feature_batch_size = 4
        # 0.3B is slightly larger than max_chunk_size (~0.26B)
        test_tensor = torch.rand(feature_batch_size, 300000000)
        # supposed to be ChunkedCudaProjector
        project = random_project(
            test_tensor,
            feature_batch_size,
            self.proj_dim,
            self.proj_max_batch_size,
            device="cuda",
            proj_seed=0,
            use_half_precision=True,
        )

        result = project(test_tensor)
        assert result.shape == (feature_batch_size, self.proj_dim)

    def test_arnoldi_project(self):
        """Test the functionality of arnoldi_project."""
        feature_dim = 10
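The comments in the two CUDA tests ("supposed to be CudaProjector" vs "supposed to be ChunkedCudaProjector") follow from the chunk-size bound shown above. The batch size of 16 here is an assumption for illustration; the actual proj_max_batch_size is set in the test class fixture:

import numpy as np

assumed_proj_max_batch_size = 16  # assumption; the real value comes from the test class
max_chunk_size = np.iinfo(np.uint32).max // assumed_proj_max_batch_size
print(max_chunk_size)  # 268_435_455, the ~0.26B mentioned in the test comment

# 1_000 features fit in a single chunk    -> plain CudaProjector
# 300_000_000 features span two chunks    -> ChunkedCudaProjector,
#                                            exercising the fixed tensor-input path
print(1_000 <= max_chunk_size, 300_000_000 <= max_chunk_size)  # True False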
