From 0308b27104d8e96e9cc310f64d4d05d132c2a47f Mon Sep 17 00:00:00 2001 From: n-shevko Date: Thu, 23 Jan 2025 14:21:43 -0500 Subject: [PATCH 01/16] SparseConnection support --- bindsnet/learning/learning.py | 37 +++++++++--- bindsnet/network/topology.py | 105 +++------------------------------- 2 files changed, 36 insertions(+), 106 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index e2c171cd..b7ad0927 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -98,7 +98,10 @@ def update(self) -> None: (self.connection.wmin != -np.inf).any() or (self.connection.wmax != np.inf).any() ) and not isinstance(self, NoOp): - self.connection.w.clamp_(self.connection.wmin, self.connection.wmax) + if self.connection.w.is_sparse: + raise Exception("SparseConnection isn't supported for wmin\\wmax") + else: + self.connection.w.clamp_(self.connection.wmin, self.connection.wmax) class NoOp(LearningRule): @@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None: if self.nu[0].any(): source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float() target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0] - self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0) + update = self.reduction(torch.bmm(source_s, target_x), dim=0) + if self.connection.w.is_sparse: + update = update.to_sparse() + self.connection.w -= update del source_s, target_x # Post-synaptic update. @@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None: self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1] ) source_x = self.source.x.view(batch_size, -1).unsqueeze(2) - self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0) + update = self.reduction(torch.bmm(source_x, target_s), dim=0) + if self.connection.w.is_sparse: + update = update.to_sparse() + self.connection.w += update del source_x, target_s super().update() @@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None: # Pre-synaptic update. update = self.reduction(torch.bmm(source_s, target_x), dim=0) + if self.connection.w.is_sparse: + update = update.to_sparse() self.connection.w += self.nu[0] * update # Post-synaptic update. update = self.reduction(torch.bmm(source_x, target_s), dim=0) + if self.connection.w.is_sparse: + update = update.to_sparse() self.connection.w += self.nu[1] * update super().update() @@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None: a_minus = torch.tensor(a_minus, device=self.connection.w.device) # Compute weight update based on the eligibility value of the past timestep. - update = reward * self.eligibility - self.connection.w += self.nu[0] * self.reduction(update, dim=0) + update = self.reduction(reward * self.eligibility, dim=0) + if self.connection.w.is_sparse: + update = update.to_sparse() + self.connection.w += self.nu[0] * update # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) @@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None: self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) self.eligibility_trace += self.eligibility / self.tc_e_trace + update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace + if self.connection.w.is_sparse: + update = update.to_sparse() # Compute weight update. - self.connection.w += ( - self.nu[0] * self.connection.dt * reward * self.eligibility_trace - ) + self.connection.w += update # Update P^+ and P^- values. 
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) @@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None: ) * source_x[:, None] # Compute weight update. - self.connection.w += self.nu[0] * reward * self.eligibility_trace + update = self.nu[0] * reward * self.eligibility_trace + if self.connection.w.is_sparse: + update = update.to_sparse() + self.connection.w += update super().update() diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index cb5fafa1..e2564cb6 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -126,9 +126,13 @@ def update(self, **kwargs) -> None: mask = kwargs.get("mask", None) if mask is not None: + if self.w.is_sparse: + raise Exception("Mask isn't supported for SparseConnection") self.w.masked_fill_(mask, 0) if self.Dales_rule is not None: + if self.w.is_sparse: + raise Exception("Dales_rule isn't supported for SparseConnection") # weight that are negative and should be positive are set to 0 self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0 # weight that are positive and should be negative are set to 0 @@ -1947,105 +1951,12 @@ def reset_state_variables(self) -> None: super().reset_state_variables() -class SparseConnection(AbstractConnection): +class SparseConnection(Connection): # language=rst """ Specifies sparse synapses between one or two populations of neurons. """ - def __init__( - self, - source: Nodes, - target: Nodes, - nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None, - reduction: Optional[callable] = None, - weight_decay: float = None, - **kwargs, - ) -> None: - # language=rst - """ - Instantiates a :code:`Connection` object with sparse weights. - - :param source: A layer of nodes from which the connection originates. - :param target: A layer of nodes to which the connection connects. - :param nu: Learning rate for both pre- and post-synaptic events. It also - accepts a pair of tensors to individualize learning rates of each neuron. - In this case, their shape should be the same size as the connection weights. - :param reduction: Method for reducing parameter updates along the minibatch - dimension. - :param weight_decay: Constant multiple to decay weights by on each iteration. - - Keyword arguments: - - :param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format - :param float sparsity: Fraction of sparse connections to use. - :param LearningRule update_rule: Modifies connection parameters according to - some rule. - :param float wmin: Minimum allowed value on the connection weights. - :param float wmax: Maximum allowed value on the connection weights. - :param float norm: Total weight per target neuron normalization constant. 
- """ - super().__init__(source, target, nu, reduction, weight_decay, **kwargs) - - w = kwargs.get("w", None) - self.sparsity = kwargs.get("sparsity", None) - - assert ( - w is not None - and self.sparsity is None - or w is None - and self.sparsity is not None - ), 'Only one of "weights" or "sparsity" must be specified' - - if w is None and self.sparsity is not None: - i = torch.bernoulli( - 1 - self.sparsity * torch.ones(*source.shape, *target.shape) - ) - if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): - v = torch.clamp( - torch.rand(*source.shape, *target.shape), self.wmin, self.wmax - )[i.bool()] - else: - v = ( - self.wmin - + torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin) - )[i.bool()] - w = torch.sparse.FloatTensor(i.nonzero().t(), v) - elif w is not None and self.sparsity is None: - assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)" - if self.wmin != -np.inf or self.wmax != np.inf: - w = torch.clamp(w, self.wmin, self.wmax) - - self.w = Parameter(w, requires_grad=False) - - def compute(self, s: torch.Tensor) -> torch.Tensor: - # language=rst - """ - Compute convolutional pre-activations given spikes using layer weights. - - :param s: Incoming spikes. - :return: Incoming spikes multiplied by synaptic weights (with or without - decaying spike activation). - """ - return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1) - # return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1) - - def update(self, **kwargs) -> None: - # language=rst - """ - Compute connection's update rule. - """ - - def normalize(self) -> None: - # language=rst - """ - Normalize weights along the first axis according to total weight per target - neuron. - """ - - def reset_state_variables(self) -> None: - # language=rst - """ - Contains resetting logic for the connection. 
- """ - super().reset_state_variables() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.w = Parameter(self.w.to_sparse(), requires_grad=False) From d1d3e42960719fedb8a7367d853bede20a001c2c Mon Sep 17 00:00:00 2001 From: n-shevko Date: Sat, 25 Jan 2025 17:17:47 -0500 Subject: [PATCH 02/16] Test for SparseConnection --- test/network/test_connections.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/network/test_connections.py b/test/network/test_connections.py index db715e7c..961f248c 100644 --- a/test/network/test_connections.py +++ b/test/network/test_connections.py @@ -109,6 +109,12 @@ def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args, **kwargs): ): return + # SparseConnection isn't supported for wmin\\wmax + elif (conn_type == SparseConnection) and not ( + (torch.tensor(wmin, dtype=torch.float32) == -np.inf).all() + and (torch.tensor(wmax, dtype=torch.float32) == np.inf).all()): + continue + print( f"- w: {type(w).__name__}, " f"wmin: {type(wmax).__name__}, wmax: {type(wmax).__name__}" @@ -163,8 +169,9 @@ def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args, **kwargs): # tester.test_transfer() # Connections with learning ability - conn_types = [Connection, Conv2dConnection, LocalConnection] + conn_types = [Connection, SparseConnection, Conv2dConnection, LocalConnection] args = [ + [[100], [50], (100, 50)], [[100], [50], (100, 50)], [[1, 28, 28], [1, 26, 26], (1, 1, 3, 3), 3], [[1, 28, 28], [1, 26, 26], (784, 676), 3, 1, 1], From 471d4550542c09241855d73003a4de9102ab29c3 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Sat, 15 Feb 2025 12:00:40 -0500 Subject: [PATCH 03/16] Sparsity for MulticompartmentConnection --- bindsnet/network/topology.py | 11 ++- bindsnet/network/topology_features.py | 103 ++++++++++++++++++++------ 2 files changed, 91 insertions(+), 23 deletions(-) diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index e2564cb6..0f9f047c 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -446,12 +446,19 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: # Sum signals for each of the output/terminal neurons # |out_signal| = [batch_size, target.n] - out_signal = conn_spikes.view(s.size(0), self.source.n, self.target.n).sum(1) + if conn_spikes.size() != torch.Size([s.size(0), self.source.n, self.target.n]): + if conn_spikes.is_sparse: + conn_spikes = conn_spikes.to_dense() + conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n) + out_signal = conn_spikes.sum(1) if self.traces: self.activity = out_signal - return out_signal.view(s.size(0), *self.target.shape) + if out_signal.size() != torch.Size([s.size(0)] + self.target.shape): + return out_signal.view(s.size(0), *self.target.shape) + else: + return out_signal def compute_window(self, s: torch.Tensor) -> torch.Tensor: # language=rst diff --git a/bindsnet/network/topology_features.py b/bindsnet/network/topology_features.py index f99cf39f..3ba4b9c8 100644 --- a/bindsnet/network/topology_features.py +++ b/bindsnet/network/topology_features.py @@ -31,6 +31,7 @@ def __init__( enforce_polarity: Optional[bool] = False, decay: float = 0.0, parent_feature=None, + sparse: Optional[bool] = False, **kwargs, ) -> None: # language=rst @@ -47,6 +48,7 @@ def __init__( dimension :param decay: Constant multiple to decay weights by on each iteration :param parent_feature: Parent feature to inherit :code:`value` from + :param sparse: Should :code:`value` parameter be sparse tensor or not 
""" #### Initialize class variables #### @@ -61,6 +63,7 @@ def __init__( self.reduction = reduction self.decay = decay self.parent_feature = parent_feature + self.sparse = sparse self.kwargs = kwargs ## Backend ## @@ -119,6 +122,10 @@ def __init__( self.assert_valid_range() if value is not None: self.assert_feature_in_range() + if self.sparse: + self.value = self.value.to_sparse() + assert not getattr(self, 'enforce_polarity', False), \ + "enforce_polarity isn't supported for sparse tensors" @abstractmethod def reset_state_variables(self) -> None: @@ -161,7 +168,10 @@ def prime_feature(self, connection, device, **kwargs) -> None: # Check if values/norms are the correct shape if isinstance(self.value, torch.Tensor): - assert tuple(self.value.shape) == (connection.source.n, connection.target.n) + if self.sparse: + assert tuple(self.value.shape[1:]) == (connection.source.n, connection.target.n) + else: + assert tuple(self.value.shape) == (connection.source.n, connection.target.n) if self.norm is not None and isinstance(self.norm, torch.Tensor): assert self.norm.shape[0] == connection.target.n @@ -214,9 +224,15 @@ def normalize(self) -> None: """ if self.norm is not None: - abs_sum = self.value.sum(0).unsqueeze(0) - abs_sum[abs_sum == 0] = 1.0 - self.value *= self.norm / abs_sum + if self.sparse: + abs_sum = self.value.sum(1).to_dense() + abs_sum[abs_sum == 0] = 1.0 + abs_sum = abs_sum.unsqueeze(1).expand(-1, *self.value.shape[1:]) + self.value = self.value * (self.norm / abs_sum) + else: + abs_sum = self.value.sum(0).unsqueeze(0) + abs_sum[abs_sum == 0] = 1.0 + self.value *= self.norm / abs_sum def degrade(self) -> None: # language=rst @@ -299,11 +315,17 @@ def assert_feature_in_range(self): def assert_valid_shape(self, source_shape, target_shape, f): # Multidimensional feat - if len(f.shape) > 1: - assert f.shape == ( + if (not self.sparse and len(f.shape) > 1) or (self.sparse and len(f.shape[1:]) > 1): + if self.sparse: + f_shape = f.shape[1:] + expected = ('batch_size', source_shape, target_shape) + else: + f_shape = f.shape + expected = (source_shape, target_shape) + assert f_shape == ( source_shape, target_shape, - ), f"Feature {self.name} has an incorrect shape of {f.shape}. Should be of shape {(source_shape, target_shape)}" + ), f"Feature {self.name} has an incorrect shape of {f.shape}. 
Should be of shape {expected}" # Else assume scalar, which is a valid shape @@ -319,6 +341,7 @@ def __init__( reduction: Optional[callable] = None, decay: float = 0.0, parent_feature=None, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -336,6 +359,7 @@ def __init__( dimension :param decay: Constant multiple to decay weights by on each iteration :param parent_feature: Parent feature to inherit :code:`value` from + :param sparse: Should :code:`value` parameter be sparse tensor or not """ ### Assertions ### @@ -349,10 +373,25 @@ def __init__( reduction=reduction, decay=decay, parent_feature=parent_feature, + sparse=sparse + ) + + def sparse_bernoulli(self): + values = torch.bernoulli(self.value.values()) + mask = values != 0 + indices = self.value.indices()[:, mask] + non_zero = values[mask] + return torch.sparse_coo_tensor( + indices, + non_zero, + self.value.size() ) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: - return conn_spikes * torch.bernoulli(self.value) + if self.sparse: + return conn_spikes * self.sparse_bernoulli() + else: + return conn_spikes * torch.bernoulli(self.value) def reset_state_variables(self) -> None: pass @@ -395,12 +434,14 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, + sparse: Optional[bool] = False ) -> None: # language=rst """ Boolean mask which determines whether or not signals are allowed to traverse certain synapses. :param name: Name of the feature :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable + :param sparse: Should :code:`value` parameter be sparse tensor or not """ ### Assertions ### @@ -419,11 +460,9 @@ def __init__( super().__init__( name=name, value=value, + sparse=sparse ) - self.name = name - self.value = value - def compute(self, conn_spikes) -> torch.Tensor: return conn_spikes * self.value @@ -505,6 +544,7 @@ def __init__( reduction: Optional[callable] = None, enforce_polarity: Optional[bool] = False, decay: float = 0.0, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -523,6 +563,7 @@ def __init__( dimension :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` :param decay: Constant multiple to decay weights by on each iteration + :param sparse: Should :code:`value` parameter be sparse tensor or not """ self.norm_frequency = norm_frequency @@ -536,6 +577,7 @@ def __init__( nu=nu, reduction=reduction, decay=decay, + sparse=sparse ) def reset_state_variables(self) -> None: @@ -589,6 +631,7 @@ def __init__( value: Union[torch.Tensor, float, int] = None, range: Optional[Sequence[float]] = None, norm: Optional[Union[torch.Tensor, float, int]] = None, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -598,6 +641,7 @@ def __init__( :param range: Range of acceptable values for the :code:`value` parameter :param norm: Value which all values in :code:`value` will sum to. 
Normalization of values occurs after each sample and after the value has been updated by the learning rule (if there is one) + :param sparse: Should :code:`value` parameter be sparse tensor or not """ super().__init__( @@ -605,6 +649,7 @@ def __init__( value=value, range=[-torch.inf, +torch.inf] if range is None else range, norm=norm, + sparse=sparse ) def reset_state_variables(self) -> None: @@ -629,15 +674,17 @@ def __init__( name: str, value: Union[torch.Tensor, float, int] = None, range: Optional[Sequence[float]] = None, + sparse: Optional[bool] = False ) -> None: # language=rst """ Adds scalars to signals :param name: Name of the feature :param value: Values to scale signals by + :param sparse: Should :code:`value` parameter be sparse tensor or not """ - super().__init__(name=name, value=value, range=range) + super().__init__(name=name, value=value, range=range, sparse=sparse) def reset_state_variables(self) -> None: pass @@ -666,6 +713,7 @@ def __init__( value: Union[torch.Tensor, float, int] = None, degrade_function: callable = None, parent_feature: Optional[AbstractFeature] = None, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -676,10 +724,11 @@ def __init__( :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or constant to be *subtracted* from the propagating spikes. :param parent_feature: Parent feature with desired :code:`value` to inherit + :param sparse: Should :code:`value` parameter be sparse tensor or not """ # Note: parent_feature will override value. See abstract constructor - super().__init__(name=name, value=value, parent_feature=parent_feature) + super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse) self.degrade_function = degrade_function @@ -698,6 +747,7 @@ def __init__( ann_values: Union[list, tuple] = None, const_update_rate: float = 0.1, const_decay: float = 0.001, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -708,6 +758,7 @@ def __init__( :param value: Values to be use to build an initial mask for the synapses. :param const_update_rate: The mask upatate rate of the ANN decision. :param const_decay: The spontaneous activation of the synapses. 
+ :param sparse: Should :code:`value` parameter be sparse tensor or not """ # Define the ANN @@ -743,16 +794,18 @@ def forward(self, x): self.const_update_rate = const_update_rate self.const_decay = const_decay - super().__init__(name=name, value=value) + super().__init__(name=name, value=value, sparse=sparse) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: # Update the spike buffer if self.start_counter == False or conn_spikes.sum() > 0: self.start_counter = True - self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = ( - conn_spikes.flatten() - ) + if self.sparse: + flat_conn_spikes = conn_spikes.to_dense().flatten() + else: + flat_conn_spikes = conn_spikes.flatten() + self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes self.counter += 1 # Update the masks @@ -767,6 +820,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: # self.mask = torch.clamp(self.mask, -1, 1) self.value = (self.mask > 0).float() + if self.sparse: + self.value = self.value.to_sparse() return conn_spikes * self.value @@ -788,6 +843,7 @@ def __init__( ann_values: Union[list, tuple] = None, const_update_rate: float = 0.1, const_decay: float = 0.01, + sparse: Optional[bool] = False ) -> None: # language=rst """ @@ -798,6 +854,7 @@ def __init__( :param value: Values to be use to build an initial mask for the synapses. :param const_update_rate: The mask upatate rate of the ANN decision. :param const_decay: The spontaneous activation of the synapses. + :param sparse: Should :code:`value` parameter be sparse tensor or not """ # Define the ANN @@ -833,16 +890,18 @@ def forward(self, x): self.const_update_rate = const_update_rate self.const_decay = const_decay - super().__init__(name=name, value=value) + super().__init__(name=name, value=value, sparse=sparse) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: # Update the spike buffer if self.start_counter == False or conn_spikes.sum() > 0: self.start_counter = True - self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = ( - conn_spikes.flatten() - ) + if self.sparse: + flat_conn_spikes = conn_spikes.to_dense().flatten() + else: + flat_conn_spikes = conn_spikes.flatten() + self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes self.counter += 1 # Update the masks @@ -857,6 +916,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: # self.mask = torch.clamp(self.mask, -1, 1) self.value = (self.mask > 0).float() + if self.sparse: + self.value = self.value.to_sparse() return conn_spikes * self.value From 8b1cb4e40977f8f3ccc5ce14384594c2cfb05fb8 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Tue, 25 Feb 2025 19:25:22 -0500 Subject: [PATCH 04/16] Sparse batch_eth_mnist --- bindsnet/learning/MCC_learning.py | 33 ++++++++++----- bindsnet/models/models.py | 67 ++++++++++++++++++++++++------- bindsnet/network/topology.py | 6 ++- examples/mnist/batch_eth_mnist.py | 7 +++- 4 files changed, 88 insertions(+), 25 deletions(-) diff --git a/bindsnet/learning/MCC_learning.py b/bindsnet/learning/MCC_learning.py index 14565a80..66760724 100644 --- a/bindsnet/learning/MCC_learning.py +++ b/bindsnet/learning/MCC_learning.py @@ -102,7 +102,10 @@ def update(self, **kwargs) -> None: if ((self.min is not None) or (self.max is not None)) and not isinstance( self, NoOp ): - self.feature_value.clamp_(self.min, self.max) + if self.feature_value.is_sparse: + self.feature_value = self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse() + 
else: + self.feature_value.clamp_(self.min, self.max) @abstractmethod def reset_state_variables(self) -> None: @@ -247,10 +250,16 @@ def _connection_update(self, **kwargs) -> None: torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt ) else: - self.feature_value -= ( - self.reduction(torch.bmm(source_s, target_x), dim=0) - * self.connection.dt - ) + if self.feature_value.is_sparse: + self.feature_value -= ( + torch.bmm(source_s, target_x) + * self.connection.dt + ).to_sparse() + else: + self.feature_value -= ( + self.reduction(torch.bmm(source_s, target_x), dim=0) + * self.connection.dt + ) del source_s, target_x # Post-synaptic update. @@ -278,10 +287,16 @@ def _connection_update(self, **kwargs) -> None: torch.mean(self.average_buffer_post, dim=0) * self.connection.dt ) else: - self.feature_value += ( - self.reduction(torch.bmm(source_x, target_s), dim=0) - * self.connection.dt - ) + if self.feature_value.is_sparse: + self.feature_value += ( + torch.bmm(source_x, target_s) + * self.connection.dt + ).to_sparse() + else: + self.feature_value += ( + self.reduction(torch.bmm(source_x, target_s), dim=0) + * self.connection.dt + ) del source_x, target_s super().update() diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py index 8ae3f136..f0463410 100644 --- a/bindsnet/models/models.py +++ b/bindsnet/models/models.py @@ -4,11 +4,14 @@ import torch from scipy.spatial.distance import euclidean from torch.nn.modules.utils import _pair +from torch import device from bindsnet.learning import PostPre +from bindsnet.learning.MCC_learning import PostPre as MMCPostPre from bindsnet.network import Network from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes -from bindsnet.network.topology import Connection, LocalConnection +from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection +from bindsnet.network.topology_features import Weight class TwoLayerNetwork(Network): @@ -94,6 +97,9 @@ class DiehlAndCook2015(Network): def __init__( self, n_inpt: int, + device: device, + batch_size: int, + sparse: bool = False, n_neurons: int = 100, exc: float = 22.5, inh: float = 17.5, @@ -169,28 +175,61 @@ def __init__( ) # Connections - w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) - input_exc_conn = Connection( + if sparse: + w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons) + else: + w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) + input_exc_conn = MulticompartmentConnection( source=input_layer, target=exc_layer, - w=w, - update_rule=PostPre, - nu=nu, - reduction=reduction, - wmin=wmin, - wmax=wmax, - norm=norm, + device=device, + pipeline=[ + Weight( + 'weight', + w, + range=[wmin, wmax], + norm=norm, + reduction=reduction, + nu=nu, + learning_rule=MMCPostPre, + sparse=sparse + ) + ] ) w = self.exc * torch.diag(torch.ones(self.n_neurons)) - exc_inh_conn = Connection( - source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc + if sparse: + w = w.unsqueeze(0).expand(batch_size, -1, -1) + exc_inh_conn = MulticompartmentConnection( + source=exc_layer, + target=inh_layer, + device=device, + pipeline=[ + Weight( + 'weight', + w, + range=[0, self.exc], + sparse=sparse + ) + ] ) w = -self.inh * ( torch.ones(self.n_neurons, self.n_neurons) - torch.diag(torch.ones(self.n_neurons)) ) - inh_exc_conn = Connection( - source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0 + if sparse: + w = w.unsqueeze(0).expand(batch_size, -1, -1) + inh_exc_conn = MulticompartmentConnection( + source=inh_layer, + target=exc_layer, + 
device=device, + pipeline=[ + Weight( + 'weight', + w, + range=[-self.inh, 0], + sparse=sparse + ) + ] ) # Add to network diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 0f9f047c..442e9a15 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -450,7 +450,11 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: if conn_spikes.is_sparse: conn_spikes = conn_spikes.to_dense() conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n) - out_signal = conn_spikes.sum(1) + + if conn_spikes.is_sparse: + out_signal = conn_spikes.to_dense().sum(1) + else: + out_signal = conn_spikes.sum(1) if self.traces: self.activity = out_signal diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 8338af19..272f1614 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -44,7 +44,8 @@ parser.add_argument("--test", dest="train", action="store_false") parser.add_argument("--plot", dest="plot", action="store_true") parser.add_argument("--gpu", dest="gpu", action="store_true") -parser.set_defaults(plot=True, gpu=True) +parser.add_argument("--sparse", dest="sparse", action="store_true") +parser.set_defaults(gpu=True) args = parser.parse_args() @@ -66,6 +67,7 @@ train = args.train plot = args.plot gpu = args.gpu +sparse = args.sparse update_steps = int(n_train / batch_size / n_updates) update_interval = update_steps * batch_size @@ -93,6 +95,9 @@ # Build network. network = DiehlAndCook2015( + device=device, + sparse=sparse, + batch_size=batch_size, n_inpt=784, n_neurons=n_neurons, exc=exc, From 12d8e03af8d8ab5b61d337f406429dd561d7ae82 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Thu, 27 Feb 2025 12:48:08 -0500 Subject: [PATCH 05/16] Sparsity for s monitor --- bindsnet/evaluation/evaluation.py | 7 ++++++- bindsnet/network/monitors.py | 11 +++++++---- examples/mnist/batch_eth_mnist.py | 21 +++++++++++---------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index 5271d762..77ba1b21 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -44,8 +44,9 @@ def assign_labels( indices = torch.nonzero(labels == i).view(-1) # Compute average firing rates for this label. + selected_spikes = torch.index_select(spikes, dim=0, index=torch.tensor(indices)) rates[:, i] = alpha * rates[:, i] + ( - torch.sum(spikes[indices], 0) / n_labeled + torch.sum(selected_spikes, 0) / n_labeled ) # Compute proportions of spike activity per class. @@ -111,6 +112,8 @@ def all_activity( # Sum over time dimension (spike ordering doesn't matter). spikes = spikes.sum(1) + if spikes.is_sparse: + spikes = spikes.to_dense() rates = torch.zeros((n_samples, n_labels), device=spikes.device) for i in range(n_labels): @@ -152,6 +155,8 @@ def proportion_weighting( # Sum over time dimension (spike ordering doesn't matter). 
spikes = spikes.sum(1) + if spikes.is_sparse: + spikes = spikes.to_dense() rates = torch.zeros((n_samples, n_labels), device=spikes.device) for i in range(n_labels): diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index f11a2339..dc8e9d94 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -45,6 +45,7 @@ def __init__( time: Optional[int] = None, batch_size: int = 1, device: str = "cpu", + sparse: Optional[bool] = False ): # language=rst """ @@ -62,6 +63,7 @@ def __init__( self.time = time self.batch_size = batch_size self.device = device + self.sparse = sparse # if time is not specified the monitor variable accumulate the logs if self.time is None: @@ -98,11 +100,12 @@ def record(self) -> None: for v in self.state_vars: data = getattr(self.obj, v).unsqueeze(0) # self.recording[v].append(data.detach().clone().to(self.device)) - self.recording[v].append( - torch.empty_like(data, device=self.device, requires_grad=False).copy_( - data, non_blocking=True - ) + record = torch.empty_like(data, device=self.device, requires_grad=False).copy_( + data, non_blocking=True ) + if self.sparse: + record = record.to_sparse() + self.recording[v].append(record) # remove the oldest element (first in the list) if self.time is not None: self.recording[v].pop(0) diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 272f1614..26271b31 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -147,7 +147,7 @@ spikes = {} for layer in set(network.layers): spikes[layer] = Monitor( - network.layers[layer], state_vars=["s"], time=int(time / dt), device=device + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True ) network.add_monitor(spikes[layer], name="%s_spikes" % layer) @@ -165,7 +165,8 @@ perf_ax = None voltage_axes, voltage_ims = None, None -spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device) +spike_record = [torch.zeros((batch_size, int(time / dt), n_neurons), device=device).to_sparse() for _ in range(update_interval // batch_size)] +spike_record_idx = 0 # Train the network. print("\nBegin training...") @@ -197,12 +198,13 @@ # Convert the array of labels into a tensor label_tensor = torch.tensor(labels, device=device) + spike_record_tensor = torch.cat(spike_record, dim=0) # Get network predictions. all_activity_pred = all_activity( - spikes=spike_record, assignments=assignments, n_labels=n_classes + spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes ) proportion_pred = proportion_weighting( - spikes=spike_record, + spikes=spike_record_tensor, assignments=assignments, proportions=proportions, n_labels=n_classes, @@ -240,7 +242,7 @@ # Assign labels to excitatory layer neurons. assignments, proportions, rates = assign_labels( - spikes=spike_record, + spikes=spike_record_tensor, labels=label_tensor, n_labels=n_classes, rates=rates, @@ -261,11 +263,10 @@ # Add to spikes recording. s = spikes["Ae"].get("s").permute((1, 0, 2)) - spike_record[ - (step * batch_size) - % update_interval : (step * batch_size % update_interval) - + s.size(0) - ] = s + spike_record[spike_record_idx] = s + spike_record_idx += 1 + if spike_record_idx == len(spike_record): + spike_record_idx = 0 # Get voltage recording. 
exc_voltages = exc_voltage_monitor.get("v") From 2c0760d53183deb17379908584f85a7578561181 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Thu, 6 Mar 2025 13:00:02 -0500 Subject: [PATCH 06/16] Add batch dimension in case if sparse == True --- bindsnet/models/models.py | 8 ++- bindsnet/network/topology_features.py | 74 +++++++++++++++++++-------- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py index f0463410..20eb57c2 100644 --- a/bindsnet/models/models.py +++ b/bindsnet/models/models.py @@ -175,10 +175,7 @@ def __init__( ) # Connections - if sparse: - w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons) - else: - w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) + w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) input_exc_conn = MulticompartmentConnection( source=input_layer, target=exc_layer, @@ -192,7 +189,8 @@ def __init__( reduction=reduction, nu=nu, learning_rule=MMCPostPre, - sparse=sparse + sparse=sparse, + batch_size=batch_size ) ] ) diff --git a/bindsnet/network/topology_features.py b/bindsnet/network/topology_features.py index 3ba4b9c8..81145451 100644 --- a/bindsnet/network/topology_features.py +++ b/bindsnet/network/topology_features.py @@ -32,6 +32,7 @@ def __init__( decay: float = 0.0, parent_feature=None, sparse: Optional[bool] = False, + batch_size: int = 1, **kwargs, ) -> None: # language=rst @@ -49,6 +50,7 @@ def __init__( :param decay: Constant multiple to decay weights by on each iteration :param parent_feature: Parent feature to inherit :code:`value` from :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ #### Initialize class variables #### @@ -64,6 +66,7 @@ def __init__( self.decay = decay self.parent_feature = parent_feature self.sparse = sparse + self.batch_size = batch_size self.kwargs = kwargs ## Backend ## @@ -120,12 +123,19 @@ def __init__( ) self.assert_valid_range() - if value is not None: - self.assert_feature_in_range() - if self.sparse: - self.value = self.value.to_sparse() - assert not getattr(self, 'enforce_polarity', False), \ - "enforce_polarity isn't supported for sparse tensors" + if value is None: + return + + self.assert_feature_in_range() + if not self.sparse: + return + + if len(self.value.shape) == 2: + self.value = self.value.unsqueeze(0).repeat(self.batch_size, 1, 1) + + self.value = self.value.to_sparse() + assert not getattr(self, 'enforce_polarity', False), \ + "enforce_polarity isn't supported for sparse tensors" @abstractmethod def reset_state_variables(self) -> None: @@ -341,7 +351,8 @@ def __init__( reduction: Optional[callable] = None, decay: float = 0.0, parent_feature=None, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -360,6 +371,7 @@ def __init__( :param decay: Constant multiple to decay weights by on each iteration :param parent_feature: Parent feature to inherit :code:`value` from :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. 
""" ### Assertions ### @@ -373,7 +385,8 @@ def __init__( reduction=reduction, decay=decay, parent_feature=parent_feature, - sparse=sparse + sparse=sparse, + batch_size=batch_size ) def sparse_bernoulli(self): @@ -434,7 +447,8 @@ def __init__( self, name: str, value: Union[torch.Tensor, float, int] = None, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -442,6 +456,7 @@ def __init__( :param name: Name of the feature :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ ### Assertions ### @@ -460,7 +475,8 @@ def __init__( super().__init__( name=name, value=value, - sparse=sparse + sparse=sparse, + batch_size=batch_size ) def compute(self, conn_spikes) -> torch.Tensor: @@ -544,7 +560,8 @@ def __init__( reduction: Optional[callable] = None, enforce_polarity: Optional[bool] = False, decay: float = 0.0, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -564,6 +581,7 @@ def __init__( :param enforce_polarity: Will prevent synapses from changing signs if :code:`True` :param decay: Constant multiple to decay weights by on each iteration :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ self.norm_frequency = norm_frequency @@ -577,7 +595,8 @@ def __init__( nu=nu, reduction=reduction, decay=decay, - sparse=sparse + sparse=sparse, + batch_size=batch_size ) def reset_state_variables(self) -> None: @@ -631,7 +650,8 @@ def __init__( value: Union[torch.Tensor, float, int] = None, range: Optional[Sequence[float]] = None, norm: Optional[Union[torch.Tensor, float, int]] = None, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -642,6 +662,7 @@ def __init__( :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample and after the value has been updated by the learning rule (if there is one) :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ super().__init__( @@ -649,7 +670,8 @@ def __init__( value=value, range=[-torch.inf, +torch.inf] if range is None else range, norm=norm, - sparse=sparse + sparse=sparse, + batch_size=batch_size ) def reset_state_variables(self) -> None: @@ -674,7 +696,8 @@ def __init__( name: str, value: Union[torch.Tensor, float, int] = None, range: Optional[Sequence[float]] = None, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -682,9 +705,10 @@ def __init__( :param name: Name of the feature :param value: Values to scale signals by :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. 
""" - super().__init__(name=name, value=value, range=range, sparse=sparse) + super().__init__(name=name, value=value, range=range, sparse=sparse, batch_size=batch_size) def reset_state_variables(self) -> None: pass @@ -713,7 +737,8 @@ def __init__( value: Union[torch.Tensor, float, int] = None, degrade_function: callable = None, parent_feature: Optional[AbstractFeature] = None, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -725,10 +750,11 @@ def __init__( constant to be *subtracted* from the propagating spikes. :param parent_feature: Parent feature with desired :code:`value` to inherit :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ # Note: parent_feature will override value. See abstract constructor - super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse) + super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse, batch_size=batch_size) self.degrade_function = degrade_function @@ -747,7 +773,8 @@ def __init__( ann_values: Union[list, tuple] = None, const_update_rate: float = 0.1, const_decay: float = 0.001, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -759,6 +786,7 @@ def __init__( :param const_update_rate: The mask upatate rate of the ANN decision. :param const_decay: The spontaneous activation of the synapses. :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ # Define the ANN @@ -794,7 +822,7 @@ def forward(self, x): self.const_update_rate = const_update_rate self.const_decay = const_decay - super().__init__(name=name, value=value, sparse=sparse) + super().__init__(name=name, value=value, sparse=sparse, batch_size=batch_size) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: @@ -843,7 +871,8 @@ def __init__( ann_values: Union[list, tuple] = None, const_update_rate: float = 0.1, const_decay: float = 0.01, - sparse: Optional[bool] = False + sparse: Optional[bool] = False, + batch_size: int = 1 ) -> None: # language=rst """ @@ -855,6 +884,7 @@ def __init__( :param const_update_rate: The mask upatate rate of the ANN decision. :param const_decay: The spontaneous activation of the synapses. :param sparse: Should :code:`value` parameter be sparse tensor or not + :param batch_size: Mini-batch size. """ # Define the ANN @@ -890,7 +920,7 @@ def forward(self, x): self.const_update_rate = const_update_rate self.const_decay = const_decay - super().__init__(name=name, value=value, sparse=sparse) + super().__init__(name=name, value=value, sparse=sparse, batch_size=batch_size) def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]: From 9678254ed233dfc2d416e0a0f54f7e8d58373c70 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Thu, 6 Mar 2025 14:23:57 -0500 Subject: [PATCH 07/16] Add doc for sparse=True --- docs/source/guide/guide_part_i.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/source/guide/guide_part_i.rst b/docs/source/guide/guide_part_i.rst index d8bc35d7..a66d18fb 100644 --- a/docs/source/guide/guide_part_i.rst +++ b/docs/source/guide/guide_part_i.rst @@ -172,6 +172,27 @@ To create a simple all-to-all connection with a weight and bias: pipeline=[weight, bias] ) +Feature values (e.g., weights, biases) can be represented as sparse tensors for memory efficiency. +To enable sparse tensor support: + +1. 
Set the sparse=True parameter when initializing the feature. +2. If the value tensor does not include a batch dimension, explicitly specify the batch_size parameter. + +Example 1: Batch dimension included in the value tensor (first axis): + +.. code-block:: python + + weights = Weight(name='weight_feature', value=torch.rand(2, 100, 1000), sparse=True) # Batch size = 2 + bias = Bias(name='bias_feature', value=torch.rand(2, 100, 1000), sparse=True) + + +Example 2: Batch dimension specified via batch_size parameter (no batch axis in value): + +.. code-block:: python + + weights = Weight(name='weight_feature', value=torch.rand(100, 1000), sparse=True, batch_size=2) + bias = Bias(name='bias_feature', value=torch.rand(100, 1000), sparse=True, batch_size=2) + Specifying monitors ******************* From fd0ef9b9392e1ccfc5433e9033fb137bc3aef5f5 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Thu, 6 Mar 2025 14:34:47 -0500 Subject: [PATCH 08/16] Make --sparse influence Monitors --- examples/mnist/batch_eth_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 26271b31..68e3d347 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -147,7 +147,7 @@ spikes = {} for layer in set(network.layers): spikes[layer] = Monitor( - network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=sparse ) network.add_monitor(spikes[layer], name="%s_spikes" % layer) From ec6356605a506092c788ca8b92b5b12c25142b1b Mon Sep 17 00:00:00 2001 From: n-shevko Date: Thu, 6 Mar 2025 15:25:33 -0500 Subject: [PATCH 09/16] Note about performance --- docs/source/guide/guide_part_i.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/guide/guide_part_i.rst b/docs/source/guide/guide_part_i.rst index a66d18fb..a10800ca 100644 --- a/docs/source/guide/guide_part_i.rst +++ b/docs/source/guide/guide_part_i.rst @@ -193,6 +193,9 @@ Example 2: Batch dimension specified via batch_size parameter (no batch axis in weights = Weight(name='weight_feature', value=torch.rand(100, 1000), sparse=True, batch_size=2) bias = Bias(name='bias_feature', value=torch.rand(100, 1000), sparse=True, batch_size=2) + +Note that subtraction and addition operations for sparse tensors are 90 times slower than those for dense tensors (tested on RTX 3060) + Specifying monitors ******************* From bed0623195172ce251213d0c586721ed02e6a9f4 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Sat, 8 Mar 2025 15:33:27 -0500 Subject: [PATCH 10/16] Make device and batch_size optional --- bindsnet/models/models.py | 4 ++-- examples/mnist/batch_eth_mnist.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py index 20eb57c2..eea1a8c8 100644 --- a/bindsnet/models/models.py +++ b/bindsnet/models/models.py @@ -97,8 +97,8 @@ class DiehlAndCook2015(Network): def __init__( self, n_inpt: int, - device: device, - batch_size: int, + device: str = "cpu", + batch_size: int = None, sparse: bool = False, n_neurons: int = 100, exc: float = 22.5, diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 68e3d347..f1e0efb0 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -277,7 +277,7 @@ image = batch["image"][:, 0].view(28, 28) inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28) lable = 
batch["label"][0] - input_exc_weights = network.connections[("X", "Ae")].w + input_exc_weights = network.connections[("X", "Ae")].feature_index['weight'].value square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28 ) From 326d270b8f473310225ff0f6e31666bcbbe26f88 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Sat, 8 Mar 2025 15:54:58 -0500 Subject: [PATCH 11/16] Add sparse support for other learning rules --- bindsnet/learning/MCC_learning.py | 32 +++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/bindsnet/learning/MCC_learning.py b/bindsnet/learning/MCC_learning.py index 66760724..d524e4b8 100644 --- a/bindsnet/learning/MCC_learning.py +++ b/bindsnet/learning/MCC_learning.py @@ -523,16 +523,18 @@ def _connection_update(self, **kwargs) -> None: self.average_buffer_index + 1 ) % self.average_update - if self.continues_update: - self.feature_value += self.nu[0] * torch.mean( - self.average_buffer, dim=0 - ) - elif self.average_buffer_index == 0: - self.feature_value += self.nu[0] * torch.mean( + if self.continues_update or self.average_buffer_index == 0: + update = self.nu[0] * torch.mean( self.average_buffer, dim=0 ) + if self.feature_value.is_sparse: + update = update.to_sparse() + self.feature_value += update else: - self.feature_value += self.nu[0] * self.reduction(update, dim=0) + update = self.nu[0] * self.reduction(update, dim=0) + if self.feature_value.is_sparse: + update = update.to_sparse() + self.feature_value += update # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) @@ -701,14 +703,16 @@ def _connection_update(self, **kwargs) -> None: self.average_buffer_index + 1 ) % self.average_update - if self.continues_update: - self.feature_value += torch.mean(self.average_buffer, dim=0) - elif self.average_buffer_index == 0: - self.feature_value += torch.mean(self.average_buffer, dim=0) + if self.continues_update or self.average_buffer_index == 0: + update = torch.mean(self.average_buffer, dim=0) + if self.feature_value.is_sparse: + update = update.to_sparse() + self.feature_value += update else: - self.feature_value += ( - self.nu[0] * self.connection.dt * reward * self.eligibility_trace - ) + update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace + if self.feature_value.is_sparse: + update = update.to_sparse() + self.feature_value += update # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) # Decay From 95626a90af95821064eab680173c14de439ac431 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Wed, 19 Mar 2025 13:43:44 -0400 Subject: [PATCH 12/16] Sparse tensors for monitors explanation --- docs/source/guide/guide_part_i.rst | 41 +++++++++++++++++++ examples/benchmark/sparse_vs_dense_tensors.py | 38 +++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 examples/benchmark/sparse_vs_dense_tensors.py diff --git a/docs/source/guide/guide_part_i.rst b/docs/source/guide/guide_part_i.rst index a10800ca..7ae4f9a8 100644 --- a/docs/source/guide/guide_part_i.rst +++ b/docs/source/guide/guide_part_i.rst @@ -278,6 +278,47 @@ Similarly, one can get the contents of a network monitor by calling :code:`netwo function takes no arguments; it returns a dictionary mapping network components to a sub-dictionary mapping state variables to their tensor-valued recording. + +:py:class:`bindsnet.network.monitors.AbstractMonitor` objects can also store sparse tensor-valued variables. 
+For example, spikes can be stored efficiently using a sparse monitor: + +.. code-block:: python + + Monitor( + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True + ) + +Note that using sparse tensors is advantageous only when the percentage of non-zero values is less than 4% of the total values. +The table below compares memory consumption between sparse and dense tensors: + +======================= ====================== ====================== ==================== +Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values +======================= ====================== ====================== ==================== +15 119 13 0.5 +30 119 25 1.0 +45 119 38 1.5 +60 119 50 2.0 +75 119 63 2.5 +89 119 75 3.0 +104 119 87 3.5 +119 119 100 4.0 +134 119 113 4.5 +149 119 125 5.0 +164 119 138 5.5 +179 119 150 6.0 +194 119 163 6.5 +209 119 176 7.0 +224 119 188 7.5 +239 119 201 8.0 +253 119 213 8.5 +268 119 225 9.0 +283 119 238 9.5 +======================= ====================== ====================== ==================== + +The tensor size does not affect the values in the third column. +This table was generated by :code:`examples/benchmark/sparse_vs_dense_tensors.py` + + Running Simulations ------------------- diff --git a/examples/benchmark/sparse_vs_dense_tensors.py b/examples/benchmark/sparse_vs_dense_tensors.py new file mode 100644 index 00000000..1b52b168 --- /dev/null +++ b/examples/benchmark/sparse_vs_dense_tensors.py @@ -0,0 +1,38 @@ +import torch + + +assert torch.cuda.is_available(), 'Benchmark works only on cuda' +device = torch.device("cuda") + + +def create_spikes_tensor(percent_of_true_values, sparse): + spikes_tensor = torch.bernoulli( + torch.full((500, 500, 500), percent_of_true_values, device=device) + ).bool() + if sparse: + spikes_tensor = spikes_tensor.to_sparse() + + torch.cuda.reset_peak_memory_stats(device=device) + return round(torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)) + + +print('======================= ====================== ====================== ====================') +print('Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values') +print('======================= ====================== ====================== ====================') +percent_of_true_values = 0.005 +while percent_of_true_values < 0.1: + result = {} + for sparse in [True, False]: + result[sparse] = create_spikes_tensor(percent_of_true_values, sparse) + percent = round((result[True] / result[False]) * 100) + + row = [ + str(result[True]).ljust(23), + str(result[False]).ljust(22), + str(percent).ljust(22), + str(round(percent_of_true_values * 100, 1)).ljust(20), + ] + print(' '.join(row)) + percent_of_true_values += 0.005 + +print('======================= ====================== ====================== ====================') From 9ff6f0157cec9b9898fb0a781a85631469428b56 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Mon, 14 Apr 2025 15:38:56 -0400 Subject: [PATCH 13/16] Add runtime experiment --- examples/benchmark/sparse_vs_dense_tensors.py | 124 ++++++++++++++---- 1 file changed, 97 insertions(+), 27 deletions(-) diff --git a/examples/benchmark/sparse_vs_dense_tensors.py b/examples/benchmark/sparse_vs_dense_tensors.py index 1b52b168..3b7f83ac 100644 --- a/examples/benchmark/sparse_vs_dense_tensors.py +++ b/examples/benchmark/sparse_vs_dense_tensors.py @@ -1,38 +1,108 @@ import torch +import time +import argparse + +from bindsnet.evaluation import all_activity, assign_labels, 
proportion_weighting + + +parser = argparse.ArgumentParser() +parser.add_argument("--benchmark_type", choices=['memory', 'runtime'], default='memory') +args = parser.parse_args() assert torch.cuda.is_available(), 'Benchmark works only on cuda' -device = torch.device("cuda") +device = torch.device("cpu") +shape = (500, 500, 500) -def create_spikes_tensor(percent_of_true_values, sparse): +def create_spikes_tensor(percent_of_true_values, sparse, return_memory_usage=True): spikes_tensor = torch.bernoulli( - torch.full((500, 500, 500), percent_of_true_values, device=device) + torch.full(shape, percent_of_true_values, device=device) ).bool() if sparse: spikes_tensor = spikes_tensor.to_sparse() - torch.cuda.reset_peak_memory_stats(device=device) - return round(torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)) - - -print('======================= ====================== ====================== ====================') -print('Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values') -print('======================= ====================== ====================== ====================') -percent_of_true_values = 0.005 -while percent_of_true_values < 0.1: - result = {} - for sparse in [True, False]: - result[sparse] = create_spikes_tensor(percent_of_true_values, sparse) - percent = round((result[True] / result[False]) * 100) - - row = [ - str(result[True]).ljust(23), - str(result[False]).ljust(22), - str(percent).ljust(22), - str(round(percent_of_true_values * 100, 1)).ljust(20), - ] - print(' '.join(row)) - percent_of_true_values += 0.005 - -print('======================= ====================== ====================== ====================') + if return_memory_usage: + torch.cuda.reset_peak_memory_stats(device=device) + return round(torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)) + else: + return spikes_tensor + + +def memory_benchmark(): + print('======================= ====================== ====================== ====================') + print('Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values') + print('======================= ====================== ====================== ====================') + percent_of_true_values = 0.005 + while percent_of_true_values < 0.1: + result = {} + for sparse in [True, False]: + result[sparse] = create_spikes_tensor(percent_of_true_values, sparse) + percent = round((result[True] / result[False]) * 100) + + row = [ + str(result[True]).ljust(23), + str(result[False]).ljust(22), + str(percent).ljust(22), + str(round(percent_of_true_values * 100, 1)).ljust(20), + ] + print(' '.join(row)) + percent_of_true_values += 0.005 + + print('======================= ====================== ====================== ====================') + + +def run(sparse): + n_classes = 10 + proportions = torch.zeros((500, n_classes), device=device) + rates = torch.zeros((500, n_classes), device=device) + assignments = -torch.ones(500, device=device) + spike_record = [] + for _ in range(5): + tmp = torch.zeros(shape, device=device) + spike_record.append(tmp.to_sparse() if sparse else tmp) + + spike_record_idx = 0 + + delta = 0 + for _ in range(10): + start = time.perf_counter() + label_tensor = torch.randint(0, n_classes, (n_classes,), device=device) + spike_record_tensor = torch.cat(spike_record, dim=0) + all_activity( + spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes + ) + proportion_weighting( + spikes=spike_record_tensor, + assignments=assignments, + proportions=proportions, + 
n_labels=n_classes, + ) + + assignments, proportions, rates = assign_labels( + spikes=spike_record_tensor, + labels=label_tensor, + n_labels=n_classes, + rates=rates, + ) + delta += time.perf_counter() - start + spike_record[spike_record_idx] = create_spikes_tensor( + 0.03, + sparse, + return_memory_usage=False + ) + spike_record_idx += 1 + if spike_record_idx == len(spike_record): + spike_record_idx = 0 + return round(delta, 1) + + +def runtime_benchmark(): + print(f"Sparse runtime: {run(True)} seconds") + print(f"Dense runtime: {run(False)} seconds") + + +if args.benchmark_type == 'memory': + memory_benchmark() +else: + runtime_benchmark() From 3c9528a710c9cce09ec37dfda50efc3618402bf7 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Tue, 15 Apr 2025 10:27:50 -0400 Subject: [PATCH 14/16] Use cuda --- examples/benchmark/sparse_vs_dense_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark/sparse_vs_dense_tensors.py b/examples/benchmark/sparse_vs_dense_tensors.py index 3b7f83ac..5c1855ae 100644 --- a/examples/benchmark/sparse_vs_dense_tensors.py +++ b/examples/benchmark/sparse_vs_dense_tensors.py @@ -11,7 +11,7 @@ assert torch.cuda.is_available(), 'Benchmark works only on cuda' -device = torch.device("cpu") +device = torch.device("cuda") shape = (500, 500, 500) From 717ee17796fbb30af62c01029a09b5582732fab2 Mon Sep 17 00:00:00 2001 From: n-shevko Date: Tue, 15 Apr 2025 10:34:15 -0400 Subject: [PATCH 15/16] Make tensor smaller to fit in 7gb of gpu memory --- examples/benchmark/sparse_vs_dense_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark/sparse_vs_dense_tensors.py b/examples/benchmark/sparse_vs_dense_tensors.py index 5c1855ae..ec862867 100644 --- a/examples/benchmark/sparse_vs_dense_tensors.py +++ b/examples/benchmark/sparse_vs_dense_tensors.py @@ -12,7 +12,7 @@ assert torch.cuda.is_available(), 'Benchmark works only on cuda' device = torch.device("cuda") -shape = (500, 500, 500) +shape = (300, 500, 500) def create_spikes_tensor(percent_of_true_values, sparse, return_memory_usage=True): From 988cb91824f84ed5b5003de3db584b487f52321a Mon Sep 17 00:00:00 2001 From: n-shevko Date: Tue, 15 Apr 2025 11:47:11 -0400 Subject: [PATCH 16/16] Documentation update --- docs/source/guide/guide_part_i.rst | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/source/guide/guide_part_i.rst b/docs/source/guide/guide_part_i.rst index 7ae4f9a8..dd067204 100644 --- a/docs/source/guide/guide_part_i.rst +++ b/docs/source/guide/guide_part_i.rst @@ -288,8 +288,17 @@ For example, spikes can be stored efficiently using a sparse monitor: network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True ) -Note that using sparse tensors is advantageous only when the percentage of non-zero values is less than 4% of the total values. -The table below compares memory consumption between sparse and dense tensors: + +Performance Considerations: + + +While sparse tensors reduce memory usage when the percentage of non-zero values is below 4% (see table below), +there is a trade-off in computational speed. Benchmarks on an RTX 3070 GPU show: + +* Sparse runtime: 1.2 seconds +* Dense runtime: 0.5 seconds + +The dense implementation achieves 2x faster execution compared to sparse tensors in this configuration. 
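+
+Either measurement can be reproduced with the benchmark script's
+``--benchmark_type`` flag: the default ``memory`` mode regenerates the table
+below, while ``runtime`` reports the timings above, e.g.
+``python examples/benchmark/sparse_vs_dense_tensors.py --benchmark_type runtime``.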
======================= ====================== ====================== ==================== Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non zero values @@ -315,9 +324,7 @@ Sparse (megabytes used) Dense (megabytes used) Ratio (Sparse/Dense) % % of non z 283 119 238 9.5 ======================= ====================== ====================== ==================== -The tensor size does not affect the values in the third column. -This table was generated by :code:`examples/benchmark/sparse_vs_dense_tensors.py` - +This table and performance metrics were generated by :code:`examples/benchmark/sparse_vs_dense_tensors.py` Running Simulations -------------------
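
Taken together, the series gives one end-to-end sparse path: sparse connection
weights via ``sparse=True`` on features and models, and sparse spike recording
via the ``Monitor`` flag. A minimal usage sketch against this branch — the
layer name ``Ae`` and all hyperparameters are illustrative, and the import
paths are the ones used by the bundled examples:

.. code-block:: python

    import torch

    from bindsnet.models import DiehlAndCook2015
    from bindsnet.network.monitors import Monitor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # With sparse=True the batch size must be known up front, because sparse
    # feature values carry an explicit leading batch dimension (PATCH 06/16).
    network = DiehlAndCook2015(
        n_inpt=784,
        n_neurons=100,
        device=device,
        batch_size=16,
        sparse=True,
    )

    # Monitors accept the same flag (PATCH 05/16); spike recordings are then
    # stored as sparse tensors, which pays off below ~4% non-zero activity.
    monitor = Monitor(
        network.layers["Ae"], state_vars=["s"], time=250, sparse=True
    )
    network.add_monitor(monitor, name="Ae_spikes")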