diff --git a/dattri/func/projection.py b/dattri/func/projection.py
index ac67d708..e3ae4923 100644
--- a/dattri/func/projection.py
+++ b/dattri/func/projection.py
@@ -457,7 +457,7 @@ def free_memory(self) -> None:
             del self.ch_input
             self.input_allocated = False
 
-    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
+    def dict_project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
         """Performs the random projection on the feature matrix.
 
         Args:
@@ -536,6 +536,41 @@ def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
 
         return ch_output[:actual_bs]
 
+    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
+        """Performs the random projection on the feature matrix.
+
+        Args:
+            features (Union[dict, Tensor]): A batch of features or a dictionary
+                of batches of features.
+            ensemble_id (int): A unique ID for this ensemble.
+
+        Returns:
+            Tensor: The projected features.
+        """
+        # allocate zero tensor for output
+        ch_output = torch.zeros(
+            size=(self.feature_batch_size, self.proj_dim),
+            device=self.device,
+            dtype=self.dtype,
+        )
+        # force the input to be Tensor for now
+        # TODO: support dict input
+        if isinstance(features, dict):
+            features = vectorize(features, device=self.device)
+
+        pointer = 0
+        for chunk_idx, chunk_dim in enumerate(self.dim_per_chunk):
+            ch_output.add_(
+                self.projector_per_chunk[chunk_idx].project(
+                    features[:, pointer : pointer + chunk_dim].contiguous(),
+                    ensemble_id=ensemble_id,
+                ),
+            )
+
+            pointer += chunk_dim
+
+        return ch_output
+
 
 class ArnoldiProjector(AbstractProjector):
     """Implementation of Arnoldi projector."""
@@ -832,6 +868,8 @@ def make_random_projector(
         proj_type = ProjectionType.rademacher
 
     if using_cuda_projector:
+        # TODO: make this support dict input
+        # currently, only tensor input will be considered
         max_chunk_size, param_chunk_sizes = get_parameter_chunk_sizes(
             param_shape_list,
             proj_max_batch_size,
diff --git a/dattri/func/utils.py b/dattri/func/utils.py
index c5db4859..9a916ec8 100644
--- a/dattri/func/utils.py
+++ b/dattri/func/utils.py
@@ -88,7 +88,7 @@ def _vectorize(
     return arr
 
 
-def get_parameter_chunk_sizes(
+def _get_parameter_chunk_sizes(
     param_shape_list: List,
     batch_size: int,
 ) -> tuple[int, int]:
@@ -128,6 +128,39 @@ def get_parameter_chunk_sizes(
     return max_chunk_size, params_per_chunk
 
 
+def get_parameter_chunk_sizes(
+    param_shape_list: List,
+    batch_size: int,
+) -> tuple[int, list]:
+    """Compute chunk size information for the features to be projected.
+
+    Get a tuple containing the max chunk size and a list of the number of
+    parameters in each chunk.
+
+    Args:
+        param_shape_list (List): A list of numbers indicating the total number of
+            features to be projected. A typical example is a list of the
+            parameter sizes of each module in a torch.nn.Module model.
+        batch_size (int): The batch size. Each term (or module) in the features
+            will have the same batch size.
+
+    Returns:
+        tuple[int, list]: Maximum number of parameters per chunk and a list of
+            the number of parameters in each chunk.
+    """
+    # get the number of total params
+    param_num = param_shape_list[0]
+
+    max_chunk_size = np.iinfo(np.uint32).max // batch_size
+
+    num_chunk = param_num // max_chunk_size
+    remaining = param_num % max_chunk_size
+    # avoid an empty trailing chunk when param_num divides evenly
+    params_per_chunk = [max_chunk_size] * num_chunk + ([remaining] if remaining else [])
+
+    return max_chunk_size, params_per_chunk
+
+
 def flatten_params(tensors: Dict[str, torch.Tensor]) -> torch.Tensor:
     """Flatten a dictionary of tensors into a single tensor.
 
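Note: a minimal sketch of the chunking arithmetic behind the new get_parameter_chunk_sizes, to make the magnitudes concrete. The function name chunk_sizes_sketch and the batch size of 16 are illustrative assumptions, not part of the patch; the uint32 bound mirrors the np.iinfo(np.uint32).max cap used above, which presumably reflects 32-bit indexing limits in the CUDA projector kernels.

import numpy as np

def chunk_sizes_sketch(param_num: int, batch_size: int):
    # each chunk may hold at most uint32-max entries per batch,
    # matching the bound in get_parameter_chunk_sizes above
    max_chunk_size = np.iinfo(np.uint32).max // batch_size
    num_chunk = param_num // max_chunk_size
    remaining = param_num % max_chunk_size
    return max_chunk_size, [max_chunk_size] * num_chunk + ([remaining] if remaining else [])

# with an assumed batch size of 16: 4294967295 // 16 = 268435455 (~0.27B),
# so 300M features split into two chunks: [268435455, 31564545]
print(chunk_sizes_sketch(300_000_000, 16))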
diff --git a/test/dattri/func/test_proj.py b/test/dattri/func/test_proj.py
index 51e70362..86459f2c 100644
--- a/test/dattri/func/test_proj.py
+++ b/test/dattri/func/test_proj.py
@@ -480,7 +480,7 @@ def test_tensor_input_cuda(self):
         test_batch_size = 64
         test_tensor = torch.rand(test_batch_size, 1000)
 
-        # suppose to be BasicProjector
+        # supposed to be CudaProjector
         project = random_project(
             test_tensor,
             test_batch_size,
@@ -494,6 +494,26 @@ def test_tensor_input_cuda(self):
         result = project(test_tensor)
         assert result.shape == (test_batch_size, self.proj_dim)
 
+    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
+    def test_tensor_input_chunked_cuda(self):
+        """Test tensor input with the chunked CUDA projector."""
+        feature_batch_size = 4
+        # 0.3B is slightly larger than max_chunk_size (~0.26B)
+        test_tensor = torch.rand(feature_batch_size, 300000000)
+        # supposed to be ChunkedCudaProjector
+        project = random_project(
+            test_tensor,
+            feature_batch_size,
+            self.proj_dim,
+            self.proj_max_batch_size,
+            device="cuda",
+            proj_seed=0,
+            use_half_precision=True,
+        )
+
+        result = project(test_tensor)
+        assert result.shape == (feature_batch_size, self.proj_dim)
+
     def test_arnoldi_project(self):
         """Test the funcitonality of arnoldi_project."""
         feature_dim = 10
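Note: the accumulation loop in the new ChunkedCudaProjector.project relies on a block-matrix identity: projecting each column chunk with its own matrix and summing the results equals one projection with the row-stacked matrix. Below is a toy torch sketch of that identity, with dense random matrices standing in for the per-chunk CUDA projectors; all names and shapes are illustrative.

import torch

batch, dims, proj_dim = 4, [6, 10], 8
projs = [torch.randn(d, proj_dim) for d in dims]  # one dense stand-in projector per chunk
x = torch.randn(batch, sum(dims))

# chunked path: project each column slice and accumulate, as project() does
out, ptr = torch.zeros(batch, proj_dim), 0
for proj, d in zip(projs, dims):
    out += x[:, ptr:ptr + d] @ proj
    ptr += d

# single-shot path: one projection with the row-stacked matrix
assert torch.allclose(out, x @ torch.cat(projs, dim=0), atol=1e-5)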