Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BitCheck Process #802

Merged
merged 14 commits into from
Oct 25, 2023
117 changes: 117 additions & 0 deletions src/lava/proc/bit_check/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
import typing as ty

from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyRefPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel

from lava.proc.bit_check.process import BitCheck


class AbstractPyBitCheckModel(PyLoihiProcessModel):
"""Abstract implementation of BitCheckModel

Specific implementations inherit from here.
"""

ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

bits: int = LavaPyType(int, int)
layerid: int = LavaPyType(int, int)
debug: int = LavaPyType(int, int)


class AbstractBitCheckModel(AbstractPyBitCheckModel):
"""Abstract implementation of BitCheck process. This
short and simple ProcessModel can be used for quick
checking of bit-accurate process runs as to whether
bits will overflow when running on hardware.
"""

ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

bits: int = LavaPyType(int, int)
layerid: int = LavaPyType(int, int)
debug: int = LavaPyType(int, int)
_overflowed: int = LavaPyType(int, int)

def post_guard(self):
return True

def run_post_mgmt(self):
value = self.ref.read()

if self.debug == 1:
print("Value is: {} at time step: {}"
.format(value, self.time_step))

# If self.check_bit_overflow(value) is true,
# the value overflowed the allowed bits from self.bits
if self.check_bit_overflow(value):
self._overflowed = 1
if self.debug == 1:
if self.layerid:
print("layer id number: {}".format(self.layerid))
print(
"value.max: overflows {} bits {}".format(
self.bits, value.max()
)
)
print(
"max signed value {}".format(
self.max_signed_int_per_bits(self.bits)
)
)
print(
"value.min: overflows {} bits {}".format(
self.bits, value.min()
)
)
print(
"min signed value {}".format(
self.max_signed_int_per_bits(self.bits)
)
)

def check_bit_overflow(self, value: ty.Type[np.ndarray]):
value = value.astype(np.int32)
shift_amt = 32 - self.bits
# shift value left by shift_amt and
mgkwill marked this conversation as resolved.
Show resolved Hide resolved
# then shift value right by shift_amt
# the result should equal unshifted value
# if the value did not overflow bits in self.bits
return not np.all(
((value << shift_amt) >> shift_amt) == value
)

def max_unsigned_int_per_bits(self, bits: ty.Type[int]):
return (1 << bits) - 1

def min_signed_int_per_bits(self, bits: ty.Type[int]):
return -1 << (bits - 1)

def max_signed_int_per_bits(self, bits: ty.Type[int]):
return (1 << (bits - 1)) - 1


@implements(proc=BitCheck, protocol=LoihiProtocol)
@requires(CPU)
class LoihiBitCheckModel(AbstractBitCheckModel):
"""Implementation of Loihi BitCheck process. This
short and simple ProcessModel can be used for quick
checking of Loihi bit-accurate process run as to
whether bits will overflow when running on Loihi Hardware.
"""

ref: PyRefPort = LavaPyType(PyRefPort.VEC_DENSE, int)

bits: int = LavaPyType(int, int)
layerid: int = LavaPyType(int, int)
debug: int = LavaPyType(int, int)
71 changes: 71 additions & 0 deletions src/lava/proc/bit_check/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import typing as ty

from lava.magma.core.process.process import LogConfig, AbstractProcess
from lava.magma.core.process.ports.ports import RefPort
from lava.magma.core.process.variable import Var


class BitCheck(AbstractProcess):
def __init__(
self,
shape: ty.Tuple[int, ...] = (1,),
layerid: ty.Optional[int] = None,
debug: ty.Optional[int] = 0,
bits: ty.Optional[int] = 24,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
**kwargs,
) -> None:
"""BitCheck process.
This process is used for quick checking of
bit-accurate process run as to whether bits will
overflow when running on bit sensitive hardware.

Parameters
----------
shape: Tuple
shape of the sigma process.
mgkwill marked this conversation as resolved.
Show resolved Hide resolved
Default is (1,).
layerid: int or float
layer number of network.
Default is None.
debug: 0 or 1
Enable (1) or disable (0) debug print.
Default is 0.
bits: int
bits to use when checking overflow, 1-32
Default is 24.
"""
super().__init__(
shape=shape,
name=name,
log_config=log_config,
**kwargs,
)
super().__init__(shape=shape, **kwargs)

self.ref = RefPort(shape=shape)

self.layerid = Var(shape=shape, init=layerid)
self.debug = Var(shape=shape, init=debug)
if bits <= 31 and bits >= 1:
self.bits = Var(shape=shape, init=bits)
else:
raise ValueError("bits value is \
{} but should be 1-31".format(bits))
self._overflowed: ty.Type(Var) = Var(shape=shape, init=0)

@property
def shape(self) -> ty.Tuple[int, ...]:
"""Return shape of the Process."""
return self.proc_params["shape"]

@property
def overflowed(self) -> ty.Type[int]:
"""Return overflow Var of Process.
1 is overflowed, 0 is not overflowed"""
return self._overflowed.get()
Empty file.
174 changes: 174 additions & 0 deletions tests/lava/proc/bit_check/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (C) 2021-22 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import sys
import unittest
import numpy as np
from typing import Tuple

from lava.magma.core.run_configs import Loihi1SimCfg
from lava.magma.core.run_conditions import RunSteps
from lava.proc.bit_check.process import BitCheck
from lava.proc.sdn.process import Sigma, SigmaDelta, ActivationMode
from lava.proc import io

verbose = True if (("-v" in sys.argv) or ("--verbose" in sys.argv)) else False


class TestBitCheckModels(unittest.TestCase):
"""Tests for BitCheck Models"""
num_steps = 100
input_ = np.sin(0.1 * np.arange(num_steps).reshape(1, -1))
input_ *= 1 << 12
input_ = input_.astype(int)
input_[:, 1:] -= input_[:, :-1]

def run_test(
self, num_steps: int, tag: str = "fixed_pt", bits: int = 24
) -> Tuple[np.ndarray, np.ndarray]:
source = io.source.RingBuffer(data=self.input_)
sigma = Sigma(shape=(1,))
sink = io.sink.RingBuffer(shape=sigma.shape, buffer=num_steps)

source.s_out.connect(sigma.a_in)
sigma.s_out.connect(sink.a_in)

debug = 0
if verbose:
debug = 1
bitcheck = BitCheck(
shape=sigma.shape, layerid=1, bits=bits, debug=debug
)
bitcheck.ref.connect_var(sigma.sigma)

run_condition = RunSteps(num_steps=num_steps)
run_config = Loihi1SimCfg(select_tag=tag)

sigma.run(condition=run_condition, run_cfg=run_config)
output = sink.data.get()

bits_used = bitcheck.bits.get()
overflowed = bitcheck.overflowed

sigma.stop()

return self.input_, output, bits_used, overflowed

def test_bitcheck_sigma_decoding_fixed_overflow(self) -> None:
"""Test BitCheck with overflow sigma decode."""
_, _, bitcheck_bits, bitcheck_overflowed = self.run_test(
num_steps=self.num_steps, tag="fixed_pt", bits=12
)

if verbose:
print("bitcheck_overflowed: ", bitcheck_overflowed)
print("bitcheck_bits: ", bitcheck_bits)
self.assertTrue(bitcheck_overflowed == 1)
self.assertTrue(bitcheck_bits == 12)

def test_bitcheck_sigma_decoding_fixed(self) -> None:
"""Test BitCheck no overflow sigma decode."""
_, _, bitcheck_bits, bitcheck_overflowed = self.run_test(
num_steps=self.num_steps, tag="fixed_pt", bits=24
)

if verbose:
print("bitcheck_overflowed: ", bitcheck_overflowed)
print("bitcheck_bits: ", bitcheck_bits)
self.assertTrue(bitcheck_overflowed == 0)
self.assertTrue(bitcheck_bits == 24)


class TestBitcheckSigmaDelta(unittest.TestCase):
"""Test BitCheck with sigma delta neurons."""
num_steps = 100
spike_exp = 6
state_exp = 6
vth = 10 << (spike_exp + state_exp)
input_ = np.sin(0.1 * np.arange(num_steps).reshape(1, -1))
input_ *= 1 << spike_exp + state_exp
input_ = input_.astype(int)
input_[:, 1:] -= input_[:, :-1]

def run_test(
self,
num_steps: int,
vth: int,
act_mode: ActivationMode,
spike_exp: int,
state_exp: int,
cum_error: bool,
tag: str = "fixed_pt",
bits: int = 24,
) -> Tuple[np.ndarray, np.ndarray]:
source = io.source.RingBuffer(data=self.input_ * (1 << 6))
sdn = SigmaDelta(
shape=(1,),
vth=vth,
act_mode=act_mode,
spike_exp=spike_exp,
state_exp=state_exp,
cum_error=cum_error,
)
sink = io.sink.RingBuffer(shape=sdn.shape, buffer=num_steps)

source.s_out.connect(sdn.a_in)
sdn.s_out.connect(sink.a_in)

debug = 0
if verbose:
debug = 1
bitcheck = BitCheck(shape=sdn.shape, layerid=1, bits=bits, debug=debug)
bitcheck.ref.connect_var(sdn.sigma)

run_condition = RunSteps(num_steps=num_steps)
run_config = Loihi1SimCfg(select_tag=tag)

sdn.run(condition=run_condition, run_cfg=run_config)
output = sink.data.get()
bits_used = bitcheck.bits.get()
overflowed = bitcheck.overflowed
sdn.stop()

input_ = np.cumsum(self.input_, axis=1)
output = np.cumsum(output, axis=1)

return input_, output, bits_used, overflowed

def test_bitcheck_reconstruction_fixed(self) -> None:
"""Tests BitCheck with fixed point sigma delta reconstruction"""
_, _, bitcheck_bits, bitcheck_overflowed = self.run_test(
num_steps=self.num_steps,
vth=self.vth,
act_mode=ActivationMode.UNIT,
spike_exp=self.spike_exp,
state_exp=self.state_exp,
cum_error=False,
bits=24,
)

if verbose:
print("bitcheck_overflowed: ", bitcheck_overflowed)
print("bitcheck_bits: ", bitcheck_bits)
self.assertTrue(bitcheck_overflowed == 0)
self.assertTrue(bitcheck_bits == 24)

def test_bitcheck_reconstruction_fixed_overflow(self) -> None:
"""Tests BitCheck overflow with fixed point
sigma delta reconstruction"""
_, _, bitcheck_bits, bitcheck_overflowed = self.run_test(
num_steps=self.num_steps,
vth=self.vth,
act_mode=ActivationMode.UNIT,
spike_exp=self.spike_exp,
state_exp=self.state_exp,
cum_error=False,
bits=12,
)

if verbose:
print("bitcheck_overflowed: ", bitcheck_overflowed)
print("bitcheck_bits: ", bitcheck_bits)
self.assertTrue(bitcheck_overflowed == 1)
self.assertTrue(bitcheck_bits == 12)
Loading