diff --git a/src/lava/lib/dnf/inputs/rate_code_spike_gen/models.py b/src/lava/lib/dnf/inputs/rate_code_spike_gen/models.py index 4dc8de7..026912c 100644 --- a/src/lava/lib/dnf/inputs/rate_code_spike_gen/models.py +++ b/src/lava/lib/dnf/inputs/rate_code_spike_gen/models.py @@ -97,6 +97,10 @@ def _compute_spike_distances(self, pattern: np.ndarray) -> np.ndarray: np.rint(TIME_STEPS_PER_MINUTE / pattern[idx_non_negligible])\ .astype(int) + idx_saturated = np.all([idx_non_negligible, distances == 0.], axis=0) + + distances[idx_saturated] = 1 + return distances def _compute_spike_onsets(self, distances: np.ndarray) -> np.ndarray: diff --git a/tests/lava/lib/dnf/inputs/rate_code_spike_gen/test_models.py b/tests/lava/lib/dnf/inputs/rate_code_spike_gen/test_models.py index 4e6c8ce..9471820 100644 --- a/tests/lava/lib/dnf/inputs/rate_code_spike_gen/test_models.py +++ b/tests/lava/lib/dnf/inputs/rate_code_spike_gen/test_models.py @@ -227,6 +227,41 @@ def test_generate_spikes(self) -> None: finally: source.stop() + def test_spike_rate_saturation(self) -> None: + """Tests whether the spike rate is saturated when high pattern + amplitude is given (i.e that neuron spikes every time step when its + corresponding incoming pattern is high)""" + num_steps = 10 + shape = (5,) + + pattern = np.zeros(shape) + pattern[2:3] = 20000. + + expected_spikes = np.array( + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ) + + source = SourceProcess(shape=shape, data=pattern) + spike_gen = RateCodeSpikeGen(shape=shape, seed=42) + sink = SinkProcess(shape=(shape[0], num_steps)) + + source.out_ports.a_out.connect(spike_gen.in_ports.a_in) + spike_gen.out_ports.s_out.connect(sink.in_ports.s_in) + + try: + source.run(condition=RunSteps(num_steps=num_steps), + run_cfg=Loihi1SimCfg()) + + received_spikes = sink.data.get() + + np.testing.assert_array_equal(received_spikes, expected_spikes) + finally: + source.stop() + if __name__ == '__main__': unittest.main()