-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
plateau- fixed point implementation of Plateau neuron model
Signed-off-by: kevin <[email protected]>
- Loading branch information
Showing
4 changed files
with
408 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import numpy as np | ||
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol | ||
from lava.magma.core.model.py.ports import PyInPort, PyOutPort | ||
from lava.magma.core.model.py.type import LavaPyType | ||
from lava.magma.core.resources import CPU | ||
from lava.magma.core.decorator import implements, requires, tag | ||
from lava.magma.core.model.py.model import PyLoihiProcessModel | ||
from lava.proc.plateau.process import Plateau | ||
|
||
|
||
@implements(proc=Plateau, protocol=LoihiProtocol) | ||
@requires(CPU) | ||
@tag("fixed_pt") | ||
class PyPlateauModelFixed(PyLoihiProcessModel): | ||
""" Implementation of Plateau neuron process in fixed point precision. | ||
Precisions of state variables | ||
- du_dend : unsigned 12-bit integer (0 to 4095) | ||
- du_soma : unsigned 12-bit integer (0 to 4095) | ||
- vth_dend : unsigned 17-bit integer (0 to 131071) | ||
- vth_soma : unsigned 17-bit integer (0 to 131071) | ||
- up_dur : unsigned 8-bit integer (0 to 255) | ||
""" | ||
|
||
a_dend_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16) | ||
a_soma_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16) | ||
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24) | ||
v_dend: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24) | ||
v_soma: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24) | ||
dv_dend: int = LavaPyType(int, np.uint16, precision=12) | ||
dv_soma: int = LavaPyType(int, np.uint16, precision=12) | ||
vth_dend: int = LavaPyType(int, np.int32, precision=17) | ||
vth_soma: int = LavaPyType(int, np.int32, precision=17) | ||
up_dur: int = LavaPyType(int, np.uint16, precision=8) | ||
up_state: int = LavaPyType(np.ndarray, np.uint16, precision=8) | ||
|
||
def __init__(self, proc_params): | ||
super(PyPlateauModelFixed, self).__init__(proc_params) | ||
self.uv_bitwidth = 24 | ||
self.max_uv_val = 2 ** (self.uv_bitwidth - 1) | ||
self.decay_shift = 12 | ||
self.decay_unity = 2 ** self.decay_shift | ||
self.vth_shift = 6 | ||
self.act_shift = 6 | ||
self.isthrscaled = False | ||
|
||
def scale_threshold(self): | ||
self.effective_vth_dend = np.left_shift(self.vth_dend, self.vth_shift) | ||
self.effective_vth_soma = np.left_shift(self.vth_soma, self.vth_shift) | ||
self.isthrscaled = True | ||
|
||
def subthr_dynamics( | ||
self, | ||
activation_dend_in: np.ndarray, | ||
activation_soma_in: np.ndarray | ||
): | ||
"""Run the sub-threshold dynamics for both the dendrite and soma of the | ||
neuron. Both use 'leaky integration'. | ||
""" | ||
for v, dv, a_in in [ | ||
(self.v_dend, self.dv_dend, activation_dend_in), | ||
(self.v_soma, self.dv_soma, activation_soma_in), | ||
]: | ||
decayed_volt = np.int64(v) * (self.decay_unity - dv) | ||
decayed_volt = np.sign(decayed_volt) * np.right_shift( | ||
np.abs(decayed_volt), 12 | ||
) | ||
decayed_volt = np.int32(decayed_volt) | ||
updated_volt = decayed_volt + np.left_shift(a_in, self.act_shift) | ||
|
||
neg_voltage_limit = -np.int32(self.max_uv_val) + 1 | ||
pos_voltage_limit = np.int32(self.max_uv_val) - 1 | ||
|
||
v[:] = np.clip( | ||
updated_volt, neg_voltage_limit, pos_voltage_limit | ||
) | ||
|
||
def update_up_state(self): | ||
"""Decrements the up state (if necessary) and checks v_dend to see if | ||
up state needs to be (re)set. If up state is (re)set, then v_dend is | ||
reset to 0. | ||
""" | ||
self.up_state[self.up_state > 0] -= 1 | ||
self.up_state[self.v_dend > self.effective_vth_dend] = self.up_dur | ||
self.v_dend[self.v_dend > self.effective_vth_dend] = 0 | ||
|
||
def soma_spike_and_reset(self): | ||
"""Check the spiking conditions for the plateau soma. Checks if: | ||
v_soma > v_th_soma | ||
up_state > 0 | ||
For any neurons n that satisfy both conditions, sets: | ||
s_out_buff[n] = True | ||
v_soma = 0 | ||
""" | ||
s_out_buff = np.logical_and( | ||
self.v_soma > self.effective_vth_soma, | ||
self.up_state > 0 | ||
) | ||
self.v_soma[s_out_buff] = 0 | ||
|
||
return s_out_buff | ||
|
||
def run_spk(self): | ||
"""The run function that performs the actual computation during | ||
execution orchestrated bgy a PyLoihiProcessModel using the | ||
LoihiProtocol. | ||
""" | ||
|
||
# Receive synaptic input | ||
a_dend_in_data = self.a_dend_in.recv() | ||
a_soma_in_data = self.a_soma_in.recv() | ||
|
||
# Check threshold scaling | ||
if not self.isthrscaled: | ||
self.scale_threshold() | ||
|
||
self.subthr_dynamics(a_dend_in_data, a_soma_in_data) | ||
|
||
self.update_up_state() | ||
|
||
self.s_out_buff = self.soma_spike_and_reset() | ||
|
||
self.s_out.send(self.s_out_buff) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import typing as ty | ||
from lava.magma.core.process.process import AbstractProcess | ||
from lava.magma.core.process.variable import Var | ||
from lava.magma.core.process.ports.ports import InPort, OutPort | ||
|
||
|
||
class Plateau(AbstractProcess): | ||
"""Plateau Neuron Process. | ||
Couples two modified LIF dynamics. The neuron posesses two potentials, | ||
v_dend and v_soma. Both follow sub-threshold LIF dynamics. When v_dend | ||
crosses v_th_dend, it resets and sets the up_state to the value up_dur. | ||
The supra-threshold behavior of v_soma depends on up_state: | ||
if up_state == 0: | ||
v_soma follows sub-threshold dynamics | ||
if up_state > 0: | ||
v_soma resets and the neuron sends out a spike | ||
Parameters | ||
---------- | ||
shape : tuple(int) | ||
Number and topology of Plateau neurons. | ||
dv_dend : float | ||
Inverse of the decay time-constant for the dendrite potential. | ||
dv_soma : float | ||
Inverse of the decay time-constant for the soma potential. | ||
vth_dend : float | ||
Dendrite threshold voltage, exceeding which, the neuron will enter the | ||
UP state. | ||
vth_soma : float | ||
Soma threshold voltage, exceeding which, the neuron will spike if it is | ||
also in the UP state. | ||
up_dur : int | ||
The duration, in timesteps, of the UP state. | ||
""" | ||
def __init__( | ||
self, | ||
shape: ty.Tuple[int, ...], | ||
dv_dend: float, | ||
dv_soma: float, | ||
vth_dend: float, | ||
vth_soma: float, | ||
up_dur: int, | ||
name: ty.Optional[str] = None, | ||
): | ||
super().__init__( | ||
shape=shape, | ||
dv_dend=dv_dend, | ||
dv_soma=dv_soma, | ||
name=name, | ||
up_dur=up_dur, | ||
vth_dend=vth_dend, | ||
vth_soma=vth_soma | ||
) | ||
self.a_dend_in = InPort(shape=shape) | ||
self.a_soma_in = InPort(shape=shape) | ||
self.s_out = OutPort(shape=shape) | ||
self.v_dend = Var(shape=shape, init=0) | ||
self.v_soma = Var(shape=shape, init=0) | ||
self.dv_dend = Var(shape=(1,), init=dv_dend) | ||
self.dv_soma = Var(shape=(1,), init=dv_soma) | ||
self.vth_dend = Var(shape=(1,), init=vth_dend) | ||
self.vth_soma = Var(shape=(1,), init=vth_soma) | ||
self.up_dur = Var(shape=(1,), init=up_dur) | ||
self.up_state = Var(shape=shape, init=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
|
||
import unittest | ||
import numpy as np | ||
from numpy.testing import assert_almost_equal | ||
|
||
from lava.magma.core.decorator import implements, requires, tag | ||
from lava.magma.core.model.py.model import PyLoihiProcessModel | ||
from lava.magma.core.model.py.ports import PyOutPort, PyInPort | ||
from lava.magma.core.model.py.type import LavaPyType | ||
from lava.magma.core.process.ports.ports import OutPort, InPort | ||
from lava.magma.core.process.process import AbstractProcess | ||
from lava.magma.core.process.variable import Var | ||
from lava.magma.core.resources import CPU | ||
from lava.proc.plateau.process import Plateau | ||
from lava.proc.dense.process import Dense | ||
from lava.magma.core.run_configs import Loihi2SimCfg, RunConfig | ||
from lava.magma.core.run_conditions import RunSteps | ||
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol | ||
from lava.tests.lava.proc.lif.test_models import VecSendProcess, VecRecvProcess | ||
|
||
|
||
class SpikeGen(AbstractProcess): | ||
"""Process for sending spikes at user-supplied time steps. | ||
Parameters | ||
---------- | ||
spikes_in: list[list], list of lists containing spike times | ||
runtime: int, number of timesteps for the generator to store spikes | ||
""" | ||
def __init__(self, spikes_in, runtime): | ||
super().__init__() | ||
n = len(spikes_in) | ||
self.shape = (n,) | ||
spike_data = np.zeros(shape=(n, runtime)) | ||
for i in range(n): | ||
for t in range(1, runtime + 1): | ||
if t in spikes_in[i]: | ||
spike_data[i, t - 1] = 1 | ||
self.s_out = OutPort(shape=self.shape) | ||
self.spike_data = Var(shape=(n, runtime), init=spike_data) | ||
|
||
|
||
@implements(proc=SpikeGen, protocol=LoihiProtocol) | ||
@requires(CPU) | ||
@tag('fixed_pt') | ||
class PySpikeGenModel(PyLoihiProcessModel): | ||
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float) | ||
spike_data: np.ndarray = LavaPyType(np.ndarray, float) | ||
|
||
def run_spk(self): | ||
"""Send the appropriate spikes for the given time step | ||
""" | ||
self.s_out.send(self.spike_data[:, self.time_step - 1]) | ||
|
||
|
||
class TestPlateauProcessModelsFixed(unittest.TestCase): | ||
"""Tests for the fixed point Plateau process models.""" | ||
def test_fixed_max_decay(self): | ||
""" | ||
Tests fixed point Plateau with max voltage decays. | ||
""" | ||
shape = (3,) | ||
num_steps = 20 | ||
spikes_in_dend = [ | ||
[5], | ||
[5], | ||
[5], | ||
] | ||
spikes_in_soma = [ | ||
[3], | ||
[10], | ||
[17] | ||
] | ||
sg_dend = SpikeGen(spikes_in=spikes_in_dend, runtime=num_steps) | ||
sg_soma = SpikeGen(spikes_in=spikes_in_soma, runtime=num_steps) | ||
dense_dend = Dense(weights=2 * np.diag(np.ones(shape=shape))) | ||
dense_soma = Dense(weights=2 * np.diag(np.ones(shape=shape))) | ||
plat = Plateau( | ||
shape=shape, | ||
dv_dend=4096, | ||
dv_soma=4096, | ||
vth_soma=1, | ||
vth_dend=1, | ||
up_dur=10 | ||
) | ||
vr = VecRecvProcess(shape=(num_steps, shape[0])) | ||
sg_dend.s_out.connect(dense_dend.s_in) | ||
sg_soma.s_out.connect(dense_soma.s_in) | ||
dense_dend.a_out.connect(plat.a_dend_in) | ||
dense_soma.a_out.connect(plat.a_soma_in) | ||
plat.s_out.connect(vr.s_in) | ||
# run model | ||
plat.run(RunSteps(num_steps), Loihi2SimCfg(select_tag='fixed_pt')) | ||
test_spk_data = vr.spk_data.get() | ||
plat.stop() | ||
# Gold standard for the test | ||
expected_spk_data = np.zeros((num_steps, shape[0])) | ||
# Neuron 2 should spike when receiving soma input | ||
expected_spk_data[10, 1] = 1 | ||
self.assertTrue(np.all(expected_spk_data == test_spk_data)) | ||
|
||
def test_up_dur(self): | ||
""" | ||
Tests that the UP state lasts for the time specified by the model. | ||
Checks that up_state decreases by one each time step after activation. | ||
""" | ||
shape = (1,) | ||
num_steps = 10 | ||
spikes_in_dend = [[3]] | ||
sg_dend = SpikeGen(spikes_in=spikes_in_dend, runtime=num_steps) | ||
dense_dend = Dense(weights=2 * (np.diag(np.ones(shape=shape)))) | ||
plat = Plateau( | ||
shape=shape, | ||
dv_dend=4096, | ||
dv_soma=4096, | ||
vth_soma=1, | ||
vth_dend=1, | ||
up_dur=5 | ||
) | ||
sg_dend.s_out.connect(dense_dend.s_in) | ||
dense_dend.a_out.connect(plat.a_dend_in) | ||
# run model | ||
test_up_state = [] | ||
for t in range(num_steps): | ||
plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt')) | ||
test_up_state.append(plat.up_state.get().astype(int)[0]) | ||
plat.stop() | ||
# Gold standard for the test | ||
# UP state active time steps 4 - 9 (5 timesteps) | ||
# this is delayed by one b.c. of the Dense process | ||
expected_up_state = [0, 0, 0, 5, 4, 3, 2, 1, 0, 0] | ||
self.assertListEqual(expected_up_state, test_up_state) | ||
|
||
def test_fixed_dvs(self): | ||
""" | ||
Tests fixed point Plateau voltage decays. | ||
""" | ||
shape = (1,) | ||
num_steps = 10 | ||
spikes_in = [[1]] | ||
sg_dend = SpikeGen(spikes_in=spikes_in, runtime=num_steps) | ||
sg_soma = SpikeGen(spikes_in=spikes_in, runtime=num_steps) | ||
dense_dend = Dense(weights=100 * np.diag(np.ones(shape=shape))) | ||
dense_soma = Dense(weights=100 * np.diag(np.ones(shape=shape))) | ||
plat = Plateau( | ||
shape=shape, | ||
dv_dend=2048, | ||
dv_soma=1024, | ||
vth_soma=100, | ||
vth_dend=100, | ||
up_dur=10 | ||
) | ||
sg_dend.s_out.connect(dense_dend.s_in) | ||
sg_soma.s_out.connect(dense_soma.s_in) | ||
dense_dend.a_out.connect(plat.a_dend_in) | ||
dense_soma.a_out.connect(plat.a_soma_in) | ||
# run model | ||
test_v_dend = [] | ||
test_v_soma = [] | ||
for t in range(num_steps): | ||
plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt')) | ||
test_v_dend.append(plat.v_dend.get().astype(int)[0]) | ||
test_v_soma.append(plat.v_soma.get().astype(int)[0]) | ||
plat.stop() | ||
# Gold standard for the test | ||
# 100<<6 = 6400 -- initial value at time step 2 | ||
expected_v_dend = [ | ||
0, 6400, 3200, 1600, 800, 400, 200, 100, 50, 25 | ||
] | ||
expected_v_soma = [ | ||
0, 6400, 4800, 3600, 2700, 2025, 1518, 1138, 853, 639 | ||
] | ||
self.assertListEqual(expected_v_dend, test_v_dend) | ||
self.assertListEqual(expected_v_soma, test_v_soma) |
Oops, something went wrong.