Skip to content

Commit

Permalink
Merge branch 'main' into graded_relu
Browse files Browse the repository at this point in the history
  • Loading branch information
mgkwill authored Jul 25, 2024
2 parents 2a345e3 + ae13b7a commit 90b4ef0
Show file tree
Hide file tree
Showing 41 changed files with 4,050 additions and 1,047 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
lfs: true

- name: setup CI
uses: lava-nc/ci-setup-composite-action@v1.2
uses: lava-nc/ci-setup-composite-action@v1.5.10_py3.10
with:
repository: 'Lava'

Expand All @@ -39,7 +39,7 @@ jobs:
lfs: true

- name: setup CI
uses: lava-nc/ci-setup-composite-action@v1.2
uses: lava-nc/ci-setup-composite-action@v1.5.10_py3.10
with:
repository: 'Lava'

Expand All @@ -62,7 +62,7 @@ jobs:
lfs: true

- name: setup CI
uses: lava-nc/ci-setup-composite-action@v1.2
uses: lava-nc/ci-setup-composite-action@v1.5.10_py3.10
with:
repository: 'Lava'

Expand Down
1,876 changes: 899 additions & 977 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ classifiers = [
"Discussions" = "https://github.com/lava-nc/lava/discussions"

[tool.poetry.dependencies]
python = ">=3.8, <3.11"
python = ">=3.10, <3.11"

numpy = "^1.24.4"
scipy = "^1.10.1"
networkx = "<=2.8.7"
asteval = "^0.9.31"
scikit-learn = "^1.3.1"
scikit-learn = "^1.5.0"

[tool.poetry.dev-dependencies]
bandit = "1.7.4"
Expand Down
14 changes: 14 additions & 0 deletions src/lava/frameworks/loihi2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (C) 2022-23 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

from lava.networks.gradedvecnetwork import (InputVec, OutputVec, GradedVec,
GradedDense, GradedSparse,
ProductVec,
LIFVec,
NormalizeNet)

from lava.networks.resfire import ResFireVec

from lava.magma.core.run_conditions import RunSteps, RunContinuous
from lava.magma.core.run_configs import Loihi2SimCfg, Loihi2HwCfg
92 changes: 86 additions & 6 deletions src/lava/magma/compiler/channel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lava.magma.compiler.utils import PortInitializer
from lava.magma.core.process.ports.ports import AbstractPort
from lava.magma.core.process.ports.ports import AbstractSrcPort, AbstractDstPort
from lava.magma.core.process.process import AbstractProcess


@dataclass(eq=True, frozen=True)
Expand All @@ -27,6 +28,10 @@ class Payload:
dst_port_initializer: PortInitializer = None


def lmt_init_id():
    """Default factory for the LMT allocation table.

    A named module-level function (rather than ``lambda: -1``) keeps a
    ``defaultdict`` built on it picklable, which the compiler's cache
    serialization relies on.
    """
    unallocated = -1
    return unallocated


class ChannelMap(dict):
"""The ChannelMap is used by the SubCompilers during compilation to
communicate how they are planning to partition Processes onto their
Expand All @@ -35,7 +40,7 @@ class ChannelMap(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._initializers_lookup = dict()
self._lmt_allocation_dict: ty.Dict[int, int] = defaultdict(lambda: -1)
self._lmt_allocation_dict: ty.Dict[int, int] = defaultdict(lmt_init_id)

def __setitem__(
self, key: PortPair, value: Payload, dict_setitem=dict.__setitem__
Expand Down Expand Up @@ -91,7 +96,8 @@ def from_proc_groups(self,
return channel_map

@classmethod
def _get_port_pairs_from_proc_groups(self, proc_groups: ty.List[ProcGroup]):
def _get_port_pairs_from_proc_groups(cls,
proc_groups: ty.List[ProcGroup]):
"""Loop over processes connectivity and get all connected port pairs."""
processes = list(itertools.chain.from_iterable(proc_groups))
port_pairs = []
Expand All @@ -102,7 +108,7 @@ def _get_port_pairs_from_proc_groups(self, proc_groups: ty.List[ProcGroup]):
for src_port in src_ports:
dst_ports = src_port.get_dst_ports()
for dst_port in dst_ports:
if self._is_leaf_process_port(dst_port, processes):
if cls._is_leaf_process_port(dst_port, processes):
port_pairs.append(PortPair(src=src_port, dst=dst_port))
return port_pairs

Expand All @@ -111,9 +117,9 @@ def _is_leaf_process_port(dst_port, processes):
dst_process = dst_port.process
return True if dst_process in processes else False

def set_port_initializer(
self, port: AbstractPort, port_initializer: PortInitializer
):
def set_port_initializer(self,
port: AbstractPort,
port_initializer: PortInitializer):
if port in self._initializers_lookup.keys():
raise AssertionError(
"An initializer for this port has already " "been assigned."
Expand All @@ -125,3 +131,77 @@ def get_port_initializer(self, port):

def has_port_initializer(self, port) -> bool:
    """Return True when a PortInitializer was already registered for *port*."""
    known_ports = self._initializers_lookup
    return port in known_ports

def write_to_cache(self,
                   cache_object: ty.Dict[ty.Any, ty.Any],
                   proc_to_procname_map: ty.Dict[AbstractProcess, str]):
    """Serialize this ChannelMap's state into ``cache_object``.

    Three entries are written:

    * ``"lmt_allocation"``   -- the raw LMT allocation dict.
    * ``"initializers"``     -- list of ``(process name, port name,
      PortInitializer)`` triples.
    * ``"channelmap_dict"``  -- list of ``((src proc name, src port name),
      (dst proc name, dst port name), Payload)`` triples.

    Ports are identified by ``(process name, port name)`` string pairs so
    the cached data can later be re-attached to a freshly built process
    graph by ``read_from_cache``.

    Parameters
    ----------
    cache_object : dict
        Mutable mapping the serialized entries are written into.
    proc_to_procname_map : dict
        Maps each AbstractProcess to its user-assigned name.

    Raises
    ------
    Exception
        If any involved process still carries a default ``"Process_..."``
        name; caching requires user-assigned, unique process names so
        ports can be resolved by name on reload.
    """
    cache_object["lmt_allocation"] = self._lmt_allocation_dict

    # Flatten the port->initializer lookup into name-keyed triples.
    initializers_serializable: ty.List[ty.Tuple[str, str,
                                                PortInitializer]] = []
    port: AbstractPort
    pi: PortInitializer
    for port, pi in self._initializers_lookup.items():
        procname = proc_to_procname_map[port.process]
        if procname.startswith("Process_"):
            # Auto-generated names are not usable as stable cache keys.
            msg = f"Unable to Cache. " \
                  f"Please give unique names to every process. " \
                  f"Violation Name: {procname=}"
            raise Exception(msg)

        initializers_serializable.append((procname, port.name, pi))
    cache_object["initializers"] = initializers_serializable

    # Flatten the (PortPair -> Payload) mapping (self is a dict) into
    # name-keyed triples.
    cm_serializable: ty.List[ty.Tuple[ty.Tuple[str, str],
                                      ty.Tuple[str, str],
                                      Payload]] = []
    port_pair: PortPair
    payload: Payload
    for port_pair, payload in self.items():
        src_port: AbstractPort = ty.cast(AbstractPort, port_pair.src)
        dst_port: AbstractPort = ty.cast(AbstractPort, port_pair.dst)
        src_proc_name: str = proc_to_procname_map[src_port.process]
        src_port_info = (src_proc_name, src_port.name)
        dst_proc_name: str = proc_to_procname_map[dst_port.process]
        dst_port_info = (dst_proc_name, dst_port.name)
        if src_proc_name.startswith("Process_") or \
                dst_proc_name.startswith("Process_"):
            # Same uniqueness requirement on both endpoints of a channel.
            msg = f"Unable to Cache. " \
                  f"Please give unique names to every process. " \
                  f"Violation Name: {src_proc_name=} {dst_proc_name=}"
            raise Exception(msg)

        cm_serializable.append((src_port_info, dst_port_info, payload))
    cache_object["channelmap_dict"] = cm_serializable

def read_from_cache(self,
                    cache_object: ty.Dict[ty.Any, ty.Any],
                    procname_to_proc_map: ty.Dict[str, AbstractProcess]):
    """Restore state previously written by ``write_to_cache``.

    Re-attaches the cached PortInitializers and the cached Payload's
    src/dst port initializers onto this ChannelMap's live ports and
    port pairs. Ports are resolved by looking up the process by name in
    ``procname_to_proc_map`` and the port as an attribute of that
    process (``getattr(process, port_name)``).

    Parameters
    ----------
    cache_object : dict
        Mapping produced by ``write_to_cache``.
    procname_to_proc_map : dict
        Maps process names back to the live AbstractProcess objects of
        the current (re-built) process graph.
    """
    self._lmt_allocation_dict = cache_object["lmt_allocation"]
    initializers_serializable = cache_object["initializers"]
    cm_serializable = cache_object["channelmap_dict"]

    # Rebuild the port -> PortInitializer lookup against live ports.
    for procname, port_name, pi in initializers_serializable:
        process: AbstractProcess = procname_to_proc_map[procname]
        port: AbstractPort = getattr(process, port_name)
        self._initializers_lookup[port] = pi

    src_port_info: ty.Tuple[str, str]
    dst_port_info: ty.Tuple[str, str]
    payload: Payload
    for src_port_info, dst_port_info, payload in cm_serializable:
        src_port_process: AbstractProcess = procname_to_proc_map[
            src_port_info[0]]
        src: AbstractPort = getattr(src_port_process,
                                    src_port_info[1])
        dst_port_process: AbstractProcess = procname_to_proc_map[
            dst_port_info[0]]
        dst: AbstractPort = getattr(dst_port_process,
                                    dst_port_info[1])
        # Linear scan over all current port pairs for each cached entry
        # (O(cache * len(self))), matching on both port names and both
        # process names; only the matching pair's initializers are
        # overwritten from the cached payload.
        for port_pair, pld in self.items():
            s, d = port_pair.src, port_pair.dst
            if s.name == src.name and d.name == dst.name and \
                    s.process.name == src_port_process.name and \
                    d.process.name == dst_port_process.name:
                pld.src_port_initializer = payload.src_port_initializer
                pld.dst_port_initializer = payload.dst_port_initializer
54 changes: 52 additions & 2 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import itertools
import logging
import os
import pickle # noqa: S403 # nosec
import typing as ty
from collections import OrderedDict, defaultdict

Expand Down Expand Up @@ -219,6 +221,32 @@ def _compile_proc_groups(
The global dict-like ChannelMap given as input but with values
updated according to partitioning done by subcompilers.
"""
procname_to_proc_map: ty.Dict[str, AbstractProcess] = {}
proc_to_procname_map: ty.Dict[AbstractProcess, str] = {}
for proc_group in proc_groups:
for p in proc_group:
procname_to_proc_map[p.name] = p
proc_to_procname_map[p] = p.name

if self._compile_config.get("cache", False):
cache_dir = self._compile_config["cache_dir"]
if os.path.exists(os.path.join(cache_dir, "cache")):
with open(os.path.join(cache_dir, "cache"), "rb") \
as cache_file:
cache_object = pickle.load(cache_file) # noqa: S301 # nosec

proc_builders_values = cache_object["procname_to_proc_builder"]
proc_builders = {}
for proc_name, pb in proc_builders_values.items():
proc = procname_to_proc_map[proc_name]
proc_builders[proc] = pb
pb.proc_params = proc.proc_params

channel_map.read_from_cache(cache_object, procname_to_proc_map)
print(f"\nBuilders and Channel Map loaded from "
f"Cache {cache_dir}\n")
return proc_builders, channel_map

# Create the global ChannelMap that is passed between
# SubCompilers to communicate about Channels between Processes.

Expand Down Expand Up @@ -248,6 +276,28 @@ def _compile_proc_groups(
subcompilers, channel_map
)

if self._compile_config.get("cache", False):
cache_dir = self._compile_config["cache_dir"]
os.makedirs(cache_dir, exist_ok=True)
cache_object = {}
# Validate All Processes are Named
procname_to_proc_builder = {}
for p, pb in proc_builders.items():
if p.name in procname_to_proc_builder or \
"Process_" in p.name:
msg = f"Unable to Cache. " \
f"Please give unique names to every process. " \
f"Violation Name: {p.name=}"
raise Exception(msg)
procname_to_proc_builder[p.name] = pb
pb.proc_params = None
cache_object["procname_to_proc_builder"] = procname_to_proc_builder
channel_map.write_to_cache(cache_object, proc_to_procname_map)
with open(os.path.join(cache_dir, "cache"), "wb") as cache_file:
pickle.dump(cache_object, cache_file)
for p, pb in proc_builders.items():
pb.proc_params = p.proc_params
print(f"\nBuilders and Channel Map stored to Cache {cache_dir}\n")
return proc_builders, channel_map

@staticmethod
Expand Down Expand Up @@ -713,8 +763,8 @@ def _create_runtime_service_builder(
model_ids: ty.List[int] = [p.id for p in sync_domain.processes]

rs_kwargs = {
"c_builders" : list(c_builders.values()),
"nc_builders" : list(nc_builders.values())
"c_builders": list(c_builders.values()),
"nc_builders": list(nc_builders.values())
}
if isinstance(run_cfg, AbstractLoihiHWRunCfg):
rs_kwargs["callback_fxs"] = run_cfg.callback_fxs
Expand Down
4 changes: 4 additions & 0 deletions src/lava/magma/compiler/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def map_cores(self, executable: Executable,
address.update(chips)
break
if len(address) > 1 and hasattr(var_model, "address"):
print('=' * 50)
print('Note to JOYESH from the future:')
print('Add logic to make multichip conv input work for YOLO.')
print('=' * 50)
raise ValueError("Lava Compiler doesn't support port"
"splitting currently. MultiChip "
"Not Supported ")
Expand Down
14 changes: 12 additions & 2 deletions src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ImplicitVarPort,
VarPort,
)
from lava.magma.core.process.ports.connection_config import ConnectionConfig
from lava.magma.core.process.process import AbstractProcess
from lava.magma.compiler.subcompilers.constants import SPIKE_BLOCK_CORE

Expand Down Expand Up @@ -189,7 +190,11 @@ def _create_inport_initializers(
pi.embedded_counters = \
np.arange(counter_start_idx,
counter_start_idx + num_counters, dtype=np.int32)
pi.connection_config = list(port.connection_configs.values())[0]
if port.connection_configs.values():
conn_config = list(port.connection_configs.values())[0]
else:
conn_config = ConnectionConfig()
pi.connection_config = conn_config
port_initializers.append(pi)
self._tmp_channel_map.set_port_initializer(port, pi)
else:
Expand All @@ -209,7 +214,7 @@ def _create_outport_initializers(
self, process: AbstractProcess
) -> ty.List[PortInitializer]:
port_initializers = []
for port in list(process.out_ports):
for k, port in enumerate(list(process.out_ports)):
pi = PortInitializer(
port.name,
port.shape,
Expand All @@ -218,6 +223,11 @@ def _create_outport_initializers(
self._compile_config["pypy_channel_size"],
port.get_incoming_transform_funcs(),
)
if port.connection_configs.values():
conn_config = list(port.connection_configs.values())[k]
else:
conn_config = ConnectionConfig()
pi.connection_config = conn_config
port_initializers.append(pi)
self._tmp_channel_map.set_port_initializer(port, pi)
return port_initializers
Expand Down
8 changes: 8 additions & 0 deletions src/lava/magma/compiler/var_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,14 @@ class NcSpikeIOVarModel(NcVarModel):
interface: SpikeIOInterface = SpikeIOInterface.ETHERNET
spike_io_port: SpikeIOPort = SpikeIOPort.ETHERNET
spike_io_mode: SpikeIOMode = SpikeIOMode.TIME_COMPARE
ethernet_chip_id: ty.Optional[ty.Tuple[int, int, int]] = None
ethernet_chip_idx: ty.Optional[int] = None
decode_config: ty.Optional[DecodeConfig] = None
time_compare: ty.Optional[TimeCompare] = None
spike_encoder: ty.Optional[SpikeEncoder] = None


@dataclass
class NcConvSpikeInVarModel(NcSpikeIOVarModel):
    """Var model for convolution spike input; extends NcSpikeIOVarModel
    with a per-region address map."""
    # Each inner tuple is ordered as (atom_payload, atom_axon, addr_idx).
    # Optional: defaults to None until populated by the compiler/mapper.
    region_map: ty.Optional[
        ty.List[ty.List[ty.Tuple[int, int, int]]]] = None
4 changes: 4 additions & 0 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,10 @@ def run_async(self) -> None:
if py_loihi_model.post_guard(self):
py_loihi_model.run_post_mgmt(self)
self.time_step += 1
# self.advance_to_time_step(self.time_step)
for port in self.py_ports:
if isinstance(port, PyOutPort):
port.advance_to_time_step(self.time_step)

py_async_model = type(
name,
Expand Down
4 changes: 4 additions & 0 deletions src/lava/magma/core/process/ports/connection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# expressly stated in the License.
from dataclasses import dataclass
from enum import IntEnum, Enum
import typing as ty


class SpikeIOInterface(IntEnum):
Expand Down Expand Up @@ -54,3 +55,6 @@ class ConnectionConfig:
spike_io_mode: SpikeIOMode = SpikeIOMode.TIME_COMPARE
num_time_buckets: int = 1 << 16
ethernet_mac_address: str = "0x90e2ba01214c"
loihi_mac_address: str = "0x0015edbeefed"
ethernet_chip_id: ty.Optional[ty.Tuple[int, int, int]] = None
ethernet_chip_idx: ty.Optional[int] = None
Loading

0 comments on commit 90b4ef0

Please sign in to comment.