Allow np.array as input weights for Sparse (#772)
* ndarray as input weights for Sparse

* docs

* codacy

* remove implementation details from docstring and from tests

* move tests to corresponding classes

* put weight casting into extra method

* Removed unused import

---------

Co-authored-by: Mathis Richter <[email protected]>
2 people authored and epaxon committed Sep 1, 2023
1 parent 4ffb3cd commit cbe517a
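For context, a minimal usage sketch of what this change enables, assuming the Sparse process API shown in the diff below; the dense weight values and shapes are purely illustrative:

import numpy as np
from scipy.sparse import csr_matrix
from lava.proc.sparse.process import Sparse

# Dense weights can now be passed directly; previously a scipy sparse
# matrix (e.g. csr_matrix) was required.
dense_weights = np.random.random((3, 2))

conn_from_dense = Sparse(weights=dense_weights)
conn_from_sparse = Sparse(weights=csr_matrix(dense_weights))

# Either way, the process stores its weights in CSR form internally.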
Showing 2 changed files with 68 additions and 19 deletions.
37 changes: 22 additions & 15 deletions src/lava/proc/sparse/process.py
@@ -3,7 +3,7 @@
# See: https://spdx.org/licenses/

import numpy as np
-from scipy.sparse import spmatrix
+from scipy.sparse import spmatrix, csr_matrix
import typing as ty

from lava.magma.core.process.process import AbstractProcess, LogConfig
@@ -21,8 +21,8 @@ class Sparse(AbstractProcess):
Parameters
----------
-weights : scipy.sparse.spmatrix
-2D connection weight matrix as sparse matrix of form
+weights : scipy.sparse.spmatrix or np.ndarray
+2D connection weight matrix of form
(num_flat_output_neurons, num_flat_input_neurons).
weight_exp : int, optional
@@ -55,9 +55,10 @@ class Sparse(AbstractProcess):
spikes as binary spikes (num_message_bits = 0) or as graded
spikes (num_message_bits > 0). Default is 0.
"""

def __init__(self,
*,
-weights: spmatrix,
+weights: ty.Union[spmatrix, np.ndarray],
name: ty.Optional[str] = None,
num_message_bits: ty.Optional[int] = 0,
log_config: ty.Optional[LogConfig] = None,
@@ -68,9 +69,7 @@ def __init__(self,
log_config=log_config,
**kwargs)

-# Transform weights to csr matrix
-weights = weights.tocsr()
+weights = self._create_csr_matrix_from_weights(weights)
shape = weights.shape

# Ports
@@ -82,6 +81,15 @@ def __init__(self,
self.a_buff = Var(shape=(shape[0],), init=0)
self.num_message_bits = Var(shape=(1,), init=num_message_bits)

@staticmethod
def _create_csr_matrix_from_weights(weights):
# Transform weights to csr matrix
if isinstance(weights, np.ndarray):
weights = csr_matrix(weights)
else:
weights = weights.tocsr()
return weights


class LearningSparse(LearningConnectionProcess, Sparse):
"""Sparse connections between neurons. Realizes the following abstract
@@ -90,8 +98,8 @@ class LearningSparse(LearningConnectionProcess, Sparse):
Parameters
----------
-weights : scipy.sparse.spmatrix
-2D connection weight matrix as sparse matrix of form
+weights : scipy.sparse.spmatrix or np.ndarray
+2D connection weight matrix of form
(num_flat_output_neurons, num_flat_input_neurons).
weight_exp : int, optional
@@ -148,9 +156,10 @@ class LearningSparse(LearningConnectionProcess, Sparse):
x1 and regular impulse addition to x2 will be considered by the
learning rule Products conditioned on x0.
"""

def __init__(self,
*,
-weights: spmatrix,
+weights: ty.Union[spmatrix, np.ndarray],
name: ty.Optional[str] = None,
num_message_bits: ty.Optional[int] = 0,
log_config: ty.Optional[LogConfig] = None,
@@ -171,9 +180,7 @@ def __init__(self,
graded_spike_cfg=graded_spike_cfg,
**kwargs)

-# Transform weights to csr matrix
-weights = weights.tocsr()
+weights = self._create_csr_matrix_from_weights(weights)
shape = weights.shape

# Ports
@@ -189,7 +196,7 @@ def __init__(self,
class DelaySparse(Sparse):
def __init__(self,
*,
-weights: spmatrix,
+weights: ty.Union[spmatrix, np.ndarray],
delays: ty.Union[spmatrix, int],
max_delay: ty.Optional[int] = 0,
name: ty.Optional[str] = None,
@@ -201,7 +208,7 @@ def __init__(self,
Parameters
----------
-weights : spmatrix
+weights : scipy.sparse.spmatrix or np.ndarray
2D connection weight matrix of form (num_flat_output_neurons,
num_flat_input_neurons) in C-order (row major).
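The casting introduced above boils down to the following standalone sketch; the free function name to_csr is hypothetical, while the behaviour mirrors _create_csr_matrix_from_weights from the diff:

import numpy as np
from scipy.sparse import csr_matrix, spmatrix

def to_csr(weights):
    # Dense NumPy arrays are converted to CSR; any scipy sparse format
    # (COO, CSC, ...) is re-encoded as CSR via tocsr().
    if isinstance(weights, np.ndarray):
        return csr_matrix(weights)
    return weights.tocsr()

# Both call paths yield a scipy sparse matrix.
assert isinstance(to_csr(np.eye(3)), spmatrix)
assert isinstance(to_csr(csr_matrix(np.eye(3))), spmatrix)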
50 changes: 46 additions & 4 deletions tests/lava/proc/sparse/test_process.py
@@ -15,7 +15,6 @@ class TestFunctions(unittest.TestCase):
"""Test helper function for Sparse"""

def test_find_with_explicit_zeros(self):

mat = np.random.randint(-10, 10, (3, 5))
spmat = csr_matrix(mat)
spmat.data[0] = 0
@@ -41,9 +40,19 @@ def test_init(self):

conn = Sparse(weights=weights_sparse)

-self.assertIsInstance(conn.weights.init, spmatrix)
np.testing.assert_array_equal(conn.weights.init.toarray(), weights)

def test_init_of_sparse_with_ndarray(self):
"""Tests instantiation of Sparse with ndarray as
weights"""

shape = (3, 2)
weights = np.random.random(shape)

conn = Sparse(weights=weights)

np.testing.assert_array_equal(conn.weights.get().toarray(), weights)


class TestLearningSparseProcess(unittest.TestCase):
"""Tests for LearningSparse class"""
@@ -72,9 +81,28 @@ def test_init(self):
conn = LearningSparse(weights=weights_sparse,
learning_rule=learning_rule)

-self.assertIsInstance(conn.weights.init, spmatrix)
np.testing.assert_array_equal(conn.weights.init.toarray(), weights)

def test_init_of_learningsparse_with_ndarray(self):
"""Tests instantiation of LearningSparse with
ndarray as weights"""

shape = (3, 2)
weights = np.random.random(shape)

learning_rule = STDPLoihi(
learning_rate=1,
A_plus=1,
A_minus=-2,
tau_plus=10,
tau_minus=10,
t_epoch=2,
)

conn = LearningSparse(weights=weights, learning_rule=learning_rule)

np.testing.assert_array_equal(conn.weights.get().toarray(), weights)


class TestDelaySparseProcess(unittest.TestCase):
"""Tests for Sparse class"""
@@ -95,7 +123,6 @@ def test_init(self):

conn = DelaySparse(weights=weights_sparse, delays=delays_sparse)

-self.assertIsInstance(conn.weights.init, spmatrix)
np.testing.assert_array_equal(conn.weights.init.toarray(), weights)

def test_validate_shapes(self):
Expand Down Expand Up @@ -132,3 +159,18 @@ def test_validate_nonzero_delays(self):
DelaySparse,
weights=weights_sparse,
delays=delays_sparse)

def test_init_of_delaysparse_with_ndarray(self):
"""Tests instantiation of DelaySparse with ndarray as weights"""

shape = (3, 2)
weights = np.random.random(shape)
delays = np.random.randint(0, 3, shape)

conn = DelaySparse(weights=weights, delays=delays)

np.testing.assert_array_equal(conn.weights.get().toarray(), weights)


if __name__ == '__main__':
unittest.main()
