
Commit

plateau: fixed point implementation of Plateau neuron model
Signed-off-by: kevin <[email protected]>
kds300 committed Sep 1, 2023
1 parent 0263677 commit d70b6e7
Showing 4 changed files with 408 additions and 0 deletions.
130 changes: 130 additions & 0 deletions src/lava/proc/plateau/models.py
@@ -0,0 +1,130 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import numpy as np
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.plateau.process import Plateau


@implements(proc=Plateau, protocol=LoihiProtocol)
@requires(CPU)
@tag("fixed_pt")
class PyPlateauModelFixed(PyLoihiProcessModel):
    """Implementation of the Plateau neuron process in fixed point precision.

    Precisions of state variables:
    - dv_dend : unsigned 12-bit integer (0 to 4095)
    - dv_soma : unsigned 12-bit integer (0 to 4095)
    - vth_dend : unsigned 17-bit integer (0 to 131071)
    - vth_soma : unsigned 17-bit integer (0 to 131071)
    - up_dur : unsigned 8-bit integer (0 to 255)
    """

    a_dend_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16,
                                     precision=16)
    a_soma_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16,
                                     precision=16)
    s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24)
    v_dend: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
    v_soma: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
    dv_dend: int = LavaPyType(int, np.uint16, precision=12)
    dv_soma: int = LavaPyType(int, np.uint16, precision=12)
    vth_dend: int = LavaPyType(int, np.int32, precision=17)
    vth_soma: int = LavaPyType(int, np.int32, precision=17)
    up_dur: int = LavaPyType(int, np.uint16, precision=8)
    up_state: np.ndarray = LavaPyType(np.ndarray, np.uint16, precision=8)

    def __init__(self, proc_params):
        super(PyPlateauModelFixed, self).__init__(proc_params)
        self.uv_bitwidth = 24
        self.max_uv_val = 2 ** (self.uv_bitwidth - 1)
        self.decay_shift = 12
        self.decay_unity = 2 ** self.decay_shift
        self.vth_shift = 6
        self.act_shift = 6
        self.isthrscaled = False
        # Declared here so that no attribute is created outside __init__;
        # the actual values are computed in scale_threshold() and run_spk().
        self.effective_vth_dend = None
        self.effective_vth_soma = None
        self.s_out_buff = None

    def scale_threshold(self):
        self.effective_vth_dend = np.left_shift(self.vth_dend, self.vth_shift)
        self.effective_vth_soma = np.left_shift(self.vth_soma, self.vth_shift)
        self.isthrscaled = True

    def subthr_dynamics(
        self,
        activation_dend_in: np.ndarray,
        activation_soma_in: np.ndarray
    ):
        """Run the sub-threshold dynamics for both the dendrite and soma of
        the neuron. Both use 'leaky integration'.
        """
        for v, dv, a_in in [
            (self.v_dend, self.dv_dend, activation_dend_in),
            (self.v_soma, self.dv_soma, activation_soma_in),
        ]:
            decayed_volt = np.int64(v) * (self.decay_unity - dv)
            decayed_volt = np.sign(decayed_volt) * np.right_shift(
                np.abs(decayed_volt), self.decay_shift
            )
            decayed_volt = np.int32(decayed_volt)
            updated_volt = decayed_volt + np.left_shift(a_in, self.act_shift)

            neg_voltage_limit = -np.int32(self.max_uv_val) + 1
            pos_voltage_limit = np.int32(self.max_uv_val) - 1

            v[:] = np.clip(
                updated_volt, neg_voltage_limit, pos_voltage_limit
            )

    def update_up_state(self):
        """Decrements the up state (if necessary) and checks v_dend to see if
        the up state needs to be (re)set. If the up state is (re)set, then
        v_dend is reset to 0.
        """
        self.up_state[self.up_state > 0] -= 1
        self.up_state[self.v_dend > self.effective_vth_dend] = self.up_dur
        self.v_dend[self.v_dend > self.effective_vth_dend] = 0
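        # Example (up_dur = 5): when v_dend crosses its threshold at step t,
        # up_state is set to 5 at step t and then counts down 4, 3, 2, 1, 0
        # over the following steps; see test_up_dur in test_models.py below.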

    def soma_spike_and_reset(self):
        """Check the spiking conditions for the Plateau soma, i.e. whether
        v_soma > effective_vth_soma and up_state > 0. For any neuron n that
        satisfies both conditions, sets s_out_buff[n] = True and resets
        v_soma[n] to 0.
        """
        s_out_buff = np.logical_and(
            self.v_soma > self.effective_vth_soma,
            self.up_state > 0
        )
        self.v_soma[s_out_buff] = 0

        return s_out_buff

    def run_spk(self):
        """The run function that performs the actual computation during
        execution, orchestrated by a PyLoihiProcessModel using the
        LoihiProtocol.
        """
        # Receive synaptic input
        a_dend_in_data = self.a_dend_in.recv()
        a_soma_in_data = self.a_soma_in.recv()

        # Scale the thresholds once, on the first time step
        if not self.isthrscaled:
            self.scale_threshold()

        self.subthr_dynamics(a_dend_in_data, a_soma_in_data)

        self.update_up_state()

        self.s_out_buff = self.soma_spike_and_reset()

        self.s_out.send(self.s_out_buff)
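
As a quick sanity check of the fixed-point arithmetic above, the decay step can be reproduced outside Lava with plain NumPy. The helper below is an illustrative sketch only (fixed_point_step is not part of this commit); it uses decay_unity = 4096, decay_shift = 12, and act_shift = 6 as in the model, and omits the 24-bit clipping for brevity:

import numpy as np

def fixed_point_step(v, dv, a_in, decay_unity=4096, decay_shift=12,
                     act_shift=6):
    # One leaky-integration step in the model's fixed-point scheme
    # (clipping to the signed 24-bit range omitted for brevity).
    decayed = np.int64(v) * (decay_unity - dv)
    decayed = np.sign(decayed) * np.right_shift(np.abs(decayed), decay_shift)
    return np.int32(decayed + np.left_shift(a_in, act_shift))

# With dv = 2048 the voltage halves each step: a single input of weight 100
# scaled by act_shift gives 100 << 6 = 6400, then 3200, 1600, ...
v, trace = np.int32(0), []
for step in range(10):
    v = fixed_point_step(v, dv=2048, a_in=100 if step == 0 else 0)
    trace.append(int(v))
print(trace)  # [6400, 3200, 1600, 800, 400, 200, 100, 50, 25, 12]

This reproduces the v_dend sequence checked by test_fixed_dvs below (there the trace is shifted by one step, because the input passes through a Dense process first).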
70 changes: 70 additions & 0 deletions src/lava/proc/plateau/process.py
@@ -0,0 +1,70 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import typing as ty
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort


class Plateau(AbstractProcess):
    """Plateau Neuron Process.

    Couples two modified LIF dynamics. The neuron possesses two potentials,
    v_dend and v_soma. Both follow sub-threshold LIF dynamics. When v_dend
    crosses vth_dend, it resets and sets up_state to the value up_dur.
    The supra-threshold behavior of v_soma depends on up_state:
        if up_state == 0:
            v_soma follows sub-threshold dynamics
        if up_state > 0:
            v_soma resets and the neuron sends out a spike

    Parameters
    ----------
    shape : tuple(int)
        Number and topology of Plateau neurons.
    dv_dend : float
        Inverse of the decay time-constant for the dendrite potential.
    dv_soma : float
        Inverse of the decay time-constant for the soma potential.
    vth_dend : float
        Dendrite threshold voltage, exceeding which the neuron enters the
        UP state.
    vth_soma : float
        Soma threshold voltage, exceeding which the neuron spikes if it is
        also in the UP state.
    up_dur : int
        The duration, in time steps, of the UP state.
    """
    def __init__(
        self,
        shape: ty.Tuple[int, ...],
        dv_dend: float,
        dv_soma: float,
        vth_dend: float,
        vth_soma: float,
        up_dur: int,
        name: ty.Optional[str] = None,
    ):
        super().__init__(
            shape=shape,
            dv_dend=dv_dend,
            dv_soma=dv_soma,
            name=name,
            up_dur=up_dur,
            vth_dend=vth_dend,
            vth_soma=vth_soma
        )
        self.a_dend_in = InPort(shape=shape)
        self.a_soma_in = InPort(shape=shape)
        self.s_out = OutPort(shape=shape)
        self.v_dend = Var(shape=shape, init=0)
        self.v_soma = Var(shape=shape, init=0)
        self.dv_dend = Var(shape=(1,), init=dv_dend)
        self.dv_soma = Var(shape=(1,), init=dv_soma)
        self.vth_dend = Var(shape=(1,), init=vth_dend)
        self.vth_soma = Var(shape=(1,), init=vth_soma)
        self.up_dur = Var(shape=(1,), init=up_dur)
        self.up_state = Var(shape=shape, init=0)
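
A minimal wiring sketch for the Plateau process (illustrative only, not part of the commit; the parameter values mirror the tests below). Dendrite and soma inputs are driven through separate Dense connections:

import numpy as np
from lava.proc.dense.process import Dense
from lava.proc.plateau.process import Plateau

plat = Plateau(
    shape=(3,),
    dv_dend=2048,   # dendrite voltage halves each step (fixed-pt model)
    dv_soma=1024,   # soma voltage retains 3/4 each step
    vth_dend=100,
    vth_soma=100,
    up_dur=10,
)
dense_dend = Dense(weights=2 * np.eye(3))
dense_soma = Dense(weights=2 * np.eye(3))
dense_dend.a_out.connect(plat.a_dend_in)
dense_soma.a_out.connect(plat.a_soma_in)
# plat.s_out would then be connected to a downstream process and the network
# run with Loihi2SimCfg(select_tag='fixed_pt'); see test_models.py below.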
178 changes: 178 additions & 0 deletions tests/lava/proc/plateau/test_models.py
@@ -0,0 +1,178 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import unittest
import numpy as np

from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.ports.ports import OutPort
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.resources import CPU
from lava.proc.plateau.process import Plateau
from lava.proc.dense.process import Dense
from lava.magma.core.run_configs import Loihi2SimCfg
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.tests.lava.proc.lif.test_models import VecRecvProcess


class SpikeGen(AbstractProcess):
    """Process for sending spikes at user-supplied time steps.

    Parameters
    ----------
    spikes_in : list[list]
        List of lists containing the spike times of each generator neuron.
    runtime : int
        Number of time steps for the generator to store spikes.
    """
    def __init__(self, spikes_in, runtime):
        super().__init__()
        n = len(spikes_in)
        self.shape = (n,)
        spike_data = np.zeros(shape=(n, runtime))
        for i in range(n):
            for t in range(1, runtime + 1):
                if t in spikes_in[i]:
                    spike_data[i, t - 1] = 1
        self.s_out = OutPort(shape=self.shape)
        self.spike_data = Var(shape=(n, runtime), init=spike_data)
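
    # Example: SpikeGen(spikes_in=[[2, 4], [3]], runtime=5) stores
    # spike_data = [[0., 1., 0., 1., 0.],
    #               [0., 0., 1., 0., 0.]]
    # (spike times are 1-indexed time steps).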


@implements(proc=SpikeGen, protocol=LoihiProtocol)
@requires(CPU)
@tag('fixed_pt')
class PySpikeGenModel(PyLoihiProcessModel):
    s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
    spike_data: np.ndarray = LavaPyType(np.ndarray, float)

    def run_spk(self):
        """Send the appropriate spikes for the given time step."""
        self.s_out.send(self.spike_data[:, self.time_step - 1])


class TestPlateauProcessModelsFixed(unittest.TestCase):
    """Tests for the fixed point Plateau process models."""

    def test_fixed_max_decay(self):
        """Tests fixed point Plateau with max voltage decays."""
        shape = (3,)
        num_steps = 20
        spikes_in_dend = [
            [5],
            [5],
            [5],
        ]
        spikes_in_soma = [
            [3],
            [10],
            [17],
        ]
        sg_dend = SpikeGen(spikes_in=spikes_in_dend, runtime=num_steps)
        sg_soma = SpikeGen(spikes_in=spikes_in_soma, runtime=num_steps)
        dense_dend = Dense(weights=2 * np.diag(np.ones(shape=shape)))
        dense_soma = Dense(weights=2 * np.diag(np.ones(shape=shape)))
        plat = Plateau(
            shape=shape,
            dv_dend=4096,
            dv_soma=4096,
            vth_soma=1,
            vth_dend=1,
            up_dur=10
        )
        vr = VecRecvProcess(shape=(num_steps, shape[0]))
        sg_dend.s_out.connect(dense_dend.s_in)
        sg_soma.s_out.connect(dense_soma.s_in)
        dense_dend.a_out.connect(plat.a_dend_in)
        dense_soma.a_out.connect(plat.a_soma_in)
        plat.s_out.connect(vr.s_in)
        # run model
        plat.run(RunSteps(num_steps), Loihi2SimCfg(select_tag='fixed_pt'))
        test_spk_data = vr.spk_data.get()
        plat.stop()
        # Gold standard for the test
        expected_spk_data = np.zeros((num_steps, shape[0]))
        # Only neuron 1 receives soma input (at step 11, after the Dense
        # delay) while its UP state is active, so it alone spikes
        expected_spk_data[10, 1] = 1
        self.assertTrue(np.all(expected_spk_data == test_spk_data))

    def test_up_dur(self):
        """Tests that the UP state lasts for the duration specified by the
        model and that up_state decreases by one each time step after
        activation.
        """
        shape = (1,)
        num_steps = 10
        spikes_in_dend = [[3]]
        sg_dend = SpikeGen(spikes_in=spikes_in_dend, runtime=num_steps)
        dense_dend = Dense(weights=2 * np.diag(np.ones(shape=shape)))
        plat = Plateau(
            shape=shape,
            dv_dend=4096,
            dv_soma=4096,
            vth_soma=1,
            vth_dend=1,
            up_dur=5
        )
        sg_dend.s_out.connect(dense_dend.s_in)
        dense_dend.a_out.connect(plat.a_dend_in)
        # run model one step at a time, recording up_state after each step
        test_up_state = []
        for _ in range(num_steps):
            plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt'))
            test_up_state.append(plat.up_state.get().astype(int)[0])
        plat.stop()
        # Gold standard for the test
        # UP state is active during time steps 4-8 (5 time steps); the onset
        # is delayed by one step because the input passes through Dense
        expected_up_state = [0, 0, 0, 5, 4, 3, 2, 1, 0, 0]
        self.assertListEqual(expected_up_state, test_up_state)

    def test_fixed_dvs(self):
        """Tests fixed point Plateau voltage decays."""
        shape = (1,)
        num_steps = 10
        spikes_in = [[1]]
        sg_dend = SpikeGen(spikes_in=spikes_in, runtime=num_steps)
        sg_soma = SpikeGen(spikes_in=spikes_in, runtime=num_steps)
        dense_dend = Dense(weights=100 * np.diag(np.ones(shape=shape)))
        dense_soma = Dense(weights=100 * np.diag(np.ones(shape=shape)))
        plat = Plateau(
            shape=shape,
            dv_dend=2048,
            dv_soma=1024,
            vth_soma=100,
            vth_dend=100,
            up_dur=10
        )
        sg_dend.s_out.connect(dense_dend.s_in)
        sg_soma.s_out.connect(dense_soma.s_in)
        dense_dend.a_out.connect(plat.a_dend_in)
        dense_soma.a_out.connect(plat.a_soma_in)
        # run model one step at a time, recording both voltages
        test_v_dend = []
        test_v_soma = []
        for _ in range(num_steps):
            plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt'))
            test_v_dend.append(plat.v_dend.get().astype(int)[0])
            test_v_soma.append(plat.v_soma.get().astype(int)[0])
        plat.stop()
        # Gold standard for the test
        # 100 << 6 = 6400 -- initial value at time step 2
        expected_v_dend = [
            0, 6400, 3200, 1600, 800, 400, 200, 100, 50, 25
        ]
        expected_v_soma = [
            0, 6400, 4800, 3600, 2700, 2025, 1518, 1138, 853, 639
        ]
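        # Dendrite: dv_dend = 2048 scales v by (4096 - 2048) / 4096 = 1/2
        # each step; soma: dv_soma = 1024 scales v by 3/4, with truncation
        # from the right shift (e.g. 2025 * 3072 >> 12 = 1518, not 1518.75)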
        self.assertListEqual(expected_v_dend, test_v_dend)
        self.assertListEqual(expected_v_soma, test_v_soma)