LIF refractory floating point (#655)
* LIF refractory float

* Fix off-by-one bug

* Comment style consistency

---------

Co-authored-by: PhilippPlank <[email protected]>
Co-authored-by: Mathis Richter <[email protected]>
Co-authored-by: weidel-p <[email protected]>
4 people authored Jun 2, 2023
1 parent 3f819ec commit b8e014a
Showing 4 changed files with 171 additions and 10 deletions.
src/lava/proc/lif/models.py (61 additions & 7 deletions)
@@ -1,20 +1,20 @@
-# Copyright (C) 2021-22 Intel Corporation
+# Copyright (C) 2021-23 Intel Corporation
 # SPDX-License-Identifier: BSD-3-Clause
 # See: https://spdx.org/licenses/
 
-from lava.magma.core.model.py.neuron import (
-    LearningNeuronModelFloat,
-    LearningNeuronModelFixed,
-)
 import numpy as np
 from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
 from lava.magma.core.model.py.ports import PyInPort, PyOutPort
 from lava.magma.core.model.py.type import LavaPyType
 from lava.magma.core.resources import CPU
 from lava.magma.core.decorator import implements, requires, tag
 from lava.magma.core.model.py.model import PyLoihiProcessModel
-from lava.proc.lif.process import LIF, LIFReset, TernaryLIF, LearningLIF
 
+from lava.magma.core.model.py.neuron import (
+    LearningNeuronModelFloat,
+    LearningNeuronModelFixed,
+)
+from lava.proc.lif.process import (LIF, LIFReset, TernaryLIF, LearningLIF,
+                                   LIFRefractory)
 
 
 class AbstractPyLifModelFloat(PyLoihiProcessModel):
@@ -449,6 +449,60 @@ def run_spk(self):
         self.s_out.send(s_out)
 
 
+@implements(proc=LIFRefractory, protocol=LoihiProtocol)
+@requires(CPU)
+@tag("floating_pt")
+class PyLifRefractoryModelFloat(AbstractPyLifModelFloat):
+    """Implementation of Leaky-Integrate-and-Fire neural process with
+    refractory period in floating point precision.
+    """
+
+    s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
+    vth: float = LavaPyType(float, float)
+    refractory_period_end: np.ndarray = LavaPyType(np.ndarray, int)
+
+    def __init__(self, proc_params):
+        super(PyLifRefractoryModelFloat, self).__init__(proc_params)
+        self.refractory_period = proc_params["refractory_period"]
+
+    def spiking_activation(self):
+        """Spiking activation function for LIF Refractory."""
+        return self.v > self.vth
+
+    def subthr_dynamics(self, activation_in: np.ndarray):
+        """Sub-threshold dynamics of current and voltage variables for
+        all refractory LIF models. This is where the 'leaky integration'
+        happens.
+        """
+        self.u[:] = self.u * (1 - self.du)
+        self.u[:] += activation_in
+        non_refractory = self.refractory_period_end < self.time_step
+        self.v[non_refractory] = (self.v[non_refractory] * (
+            (1 - self.dv) + self.u[non_refractory])
+            + self.bias_mant[non_refractory])
+
+    def process_spikes(self, spike_vector: np.ndarray):
+        self.refractory_period_end[spike_vector] = (self.time_step
+                                                    + self.refractory_period)
+        super().reset_voltage(spike_vector)
+
+    def run_spk(self):
+        """The run function that performs the actual computation during
+        execution orchestrated by a PyLoihiProcessModel using the
+        LoihiProtocol.
+        """
+        # Receive synaptic input
+        a_in_data = self.a_in.recv()
+
+        self.subthr_dynamics(activation_in=a_in_data)
+
+        s_out = self.spiking_activation()
+
+        # Reset voltage of spiked neurons to 0
+        self.process_spikes(spike_vector=s_out)
+        self.s_out.send(s_out)
+
+
 @implements(proc=LearningLIF, protocol=LoihiProtocol)
 @requires(CPU)
 @tag("bit_accurate_loihi", "fixed_pt")
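For intuition, here is a minimal standalone sketch (plain Python, not part of the diff) that replays the update rule above for the first neuron of the test added further down (u = 0, du = dv = 0, bias_mant = 1, vth = 4, refractory_period = 1); the names mirror the model's methods:

    # Standalone replay of PyLifRefractoryModelFloat for one neuron.
    u, du, dv, bias, vth, ref = 0.0, 0.0, 0.0, 1.0, 4.0, 1
    v, refractory_period_end, trace = 0.0, 0, []
    for time_step in range(1, 9):
        u = u * (1 - du)                       # current decay, no input
        if refractory_period_end < time_step:  # subthr_dynamics() gate
            v = v * ((1 - dv) + u) + bias      # voltage update as committed
        if v > vth:                            # spiking_activation()
            refractory_period_end = time_step + ref  # process_spikes()
            v = 0.0                            # reset voltage on spike
        trace.append(v)
    print(trace)  # [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 1.0, 2.0]

Each spike therefore shows up as two consecutive zeros in the logged trace: the reset at the spike step plus one silenced refractory step.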
src/lava/proc/lif/process.py (73 additions & 1 deletion)
@@ -1,4 +1,4 @@
-# Copyright (C) 2022 Intel Corporation
+# Copyright (C) 2022-23 Intel Corporation
 # SPDX-License-Identifier: BSD-3-Clause
 # See: https://spdx.org/licenses/
 
@@ -344,3 +344,75 @@ def __init__(
 
         self.proc_params["reset_interval"] = reset_interval
         self.proc_params["reset_offset"] = reset_offset
+
+
+class LIFRefractory(LIF):
+
+    """Leaky-Integrate-and-Fire (LIF) process with refractory period.
+
+    Parameters
+    ----------
+    shape : tuple(int)
+        Number and topology of LIF neurons.
+    u : float, list, numpy.ndarray, optional
+        Initial value of the neurons' current.
+    v : float, list, numpy.ndarray, optional
+        Initial value of the neurons' voltage (membrane potential).
+    du : float, optional
+        Inverse of decay time-constant for current decay. Currently, only a
+        single decay can be set for the entire population of neurons.
+    dv : float, optional
+        Inverse of decay time-constant for voltage decay. Currently, only a
+        single decay can be set for the entire population of neurons.
+    bias_mant : float, list, numpy.ndarray, optional
+        Mantissa part of neuron bias.
+    bias_exp : float, list, numpy.ndarray, optional
+        Exponent part of neuron bias, if needed. Mostly for fixed point
+        implementations. Ignored for floating point implementations.
+    vth : float, optional
+        Neuron threshold voltage, exceeding which, the neuron will spike.
+        Currently, only a single threshold can be set for the entire
+        population of neurons.
+    refractory_period : int, optional
+        The interval of the refractory period. 1 timestep by default.
+
+    See Also
+    --------
+    lava.proc.lif.process.LIF: 'Regular' leaky-integrate-and-fire neuron for
+        documentation on rest of the behavior.
+    """
+
+    def __init__(
+        self,
+        *,
+        shape: ty.Tuple[int, ...],
+        u: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
+        v: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
+        du: ty.Optional[float] = 0,
+        dv: ty.Optional[float] = 0,
+        bias_mant: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
+        bias_exp: ty.Optional[ty.Union[float, list, np.ndarray]] = 0,
+        vth: ty.Optional[float] = 10,
+        refractory_period: ty.Optional[int] = 1,
+        name: ty.Optional[str] = None,
+        log_config: ty.Optional[LogConfig] = None,
+    ) -> None:
+        super().__init__(
+            shape=shape,
+            u=u,
+            v=v,
+            du=du,
+            dv=dv,
+            bias_mant=bias_mant,
+            bias_exp=bias_exp,
+            vth=vth,
+            name=name,
+            log_config=log_config,
+        )
+
+        if refractory_period < 1:
+            raise ValueError("Refractory period must be > 0.")
+
+        self.proc_params["refractory_period"] = refractory_period
+        self.refractory_period_end = Var(shape=shape, init=0)
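A minimal usage sketch of the new process (not part of the diff), following the same run pattern as the test added below; the parameter values here are illustrative:

    import numpy as np
    from lava.proc.lif.process import LIFRefractory
    from lava.magma.core.run_conditions import RunSteps
    from lava.magma.core.run_configs import Loihi2SimCfg

    # Two bias-driven neurons; each spike silences its neuron for
    # refractory_period timesteps before integration resumes.
    lif = LIFRefractory(shape=(2,), bias_mant=np.array([1.0, 2.0]),
                        vth=4.0, refractory_period=2)
    lif.run(condition=RunSteps(num_steps=10),
            run_cfg=Loihi2SimCfg(select_tag="floating_pt"))
    lif.stop()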
tests/lava/proc/lif/test_models.py (37 additions & 2 deletions)
@@ -1,9 +1,10 @@
-# Copyright (C) 2021-22 Intel Corporation
+# Copyright (C) 2021-23 Intel Corporation
 # SPDX-License-Identifier: BSD-3-Clause
 # See: https://spdx.org/licenses/
 
 import unittest
 import numpy as np
+from numpy.testing import assert_almost_equal
 
 from lava.magma.core.decorator import implements, requires, tag
 from lava.magma.core.model.py.model import PyLoihiProcessModel
@@ -16,7 +17,7 @@
 from lava.magma.core.run_configs import Loihi2SimCfg, RunConfig
 from lava.magma.core.run_conditions import RunSteps
 from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
-from lava.proc.lif.process import LIF, LIFReset, TernaryLIF
+from lava.proc.lif.process import LIF, LIFReset, TernaryLIF, LIFRefractory
 from lava.proc import io


@@ -824,3 +825,37 @@ def test_fixed_model(self):
         self.assertTrue(np.array_equal(u[:, reset_offset - 1:], u_gt_post))
         self.assertTrue(np.array_equal(v[:, :reset_offset - 1], v_gt_pre))
         self.assertTrue(np.array_equal(v[:, reset_offset - 1:], v_gt_post))
+
+
+class TestLIFRefractory(unittest.TestCase):
+    """Test LIF Refractory process model"""
+
+    def test_float_model(self):
+        """Test float model"""
+        num_neurons = 2
+        num_steps = 8
+        refractory_period = 1
+
+        # Two neurons with different biases
+        lif_refractory = LIFRefractory(shape=(num_neurons,),
+                                       u=np.arange(num_neurons),
+                                       bias_mant=np.arange(num_neurons) + 1,
+                                       bias_exp=np.ones(
+                                           (num_neurons,), dtype=float),
+                                       vth=4.,
+                                       refractory_period=refractory_period)
+
+        v_logger = io.sink.Read(buffer=num_steps)
+        v_logger.connect_var(lif_refractory.v)
+
+        lif_refractory.run(condition=RunSteps(num_steps),
+                           run_cfg=Loihi2SimCfg(select_tag="floating_pt"))
+
+        v = v_logger.data.get()
+        lif_refractory.stop()
+
+        # Voltage is expected to remain at reset level for two time steps
+        v_expected = np.array([[1, 2, 3, 4, 0, 0, 1, 2],
+                               [2, 0, 0, 2, 0, 0, 2, 0]], dtype=float)
+
+        assert_almost_equal(v, v_expected)
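As a sanity check on the second row of v_expected (again a standalone sketch, not part of the diff): with u = 1, dv = 0, and bias_mant = 2, the committed update rule reduces to v <- 2 * v + 2, so the neuron crosses vth on its second integration step and then sits out one refractory step, repeating with period 3:

    # Standalone replay of the second test neuron.
    v, refractory_period_end, trace = 0.0, 0, []
    for time_step in range(1, 9):
        if refractory_period_end < time_step:
            v = 2.0 * v + 2.0          # v * ((1 - dv) + u) + bias_mant
        if v > 4.0:                    # vth = 4
            refractory_period_end = time_step + 1
            v = 0.0
        trace.append(v)
    print(trace)  # [2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 0.0]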
Empty file.
