Skip to content

Commit

Permalink
Enable manual partitioning (#876)
Browse files Browse the repository at this point in the history
* enable manual partitioning

* fix codacy complaint

* fix unit tests

* fixed unit tests

* fixed linting
  • Loading branch information
PhilippPlank authored Aug 1, 2024
1 parent ae13b7a commit a82abc1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
13 changes: 10 additions & 3 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def _compile_proc_groups(
f"Cache {cache_dir}\n")
return proc_builders, channel_map

# Get manual partitioning, if available
partitioning = self._compile_config.get("partitioning", None)

# Create the global ChannelMap that is passed between
# SubCompilers to communicate about Channels between Processes.

Expand All @@ -266,7 +269,8 @@ def _compile_proc_groups(
subcompilers.append(pg_subcompilers)

# Compile this ProcGroup.
self._compile_proc_group(pg_subcompilers, channel_map)
self._compile_proc_group(pg_subcompilers, channel_map,
partitioning)

# Flatten the list of all SubCompilers.
subcompilers = list(itertools.chain.from_iterable(subcompilers))
Expand Down Expand Up @@ -403,7 +407,8 @@ def _create_subcompilers(

@staticmethod
def _compile_proc_group(
subcompilers: ty.List[AbstractSubCompiler], channel_map: ChannelMap
subcompilers: ty.List[AbstractSubCompiler], channel_map: ChannelMap,
partitioning: ty.Dict[str, ty.Dict]
) -> None:
"""For a given list of SubCompilers that have been initialized with
the Processes of a single ProcGroup, iterate through the compilation
Expand All @@ -419,6 +424,8 @@ def _compile_proc_group(
channel_map : ChannelMap
The global ChannelMap that contains information about Channels
between Processes.
partitioning: ty.Dict
Optional manual mapping dictionary used by ncproc compiler.
"""
channel_map_prev = None

Expand All @@ -431,7 +438,7 @@ def _compile_proc_group(
for subcompiler in subcompilers:
# Compile the Processes registered with each SubCompiler and
# update the ChannelMap.
channel_map = subcompiler.compile(channel_map)
channel_map = subcompiler.compile(channel_map, partitioning)

@staticmethod
def _extract_proc_builders(
Expand Down
3 changes: 2 additions & 1 deletion src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def __init__(
super().__init__(proc_group, compile_config)
self._spike_io_counter_offset: Offset = Offset()

def compile(self, channel_map: ChannelMap) -> ChannelMap:
def compile(self, channel_map: ChannelMap,
partitioning: ty.Dict = None) -> ChannelMap:
return self._update_channel_map(channel_map)

def __del__(self):
Expand Down
13 changes: 8 additions & 5 deletions tests/lava/magma/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def create_patches(
and the compile() method returns the given ChannelMap unchanged.
."""

def compile_return(channel_map: ChannelMap) -> ChannelMap:
def compile_return(channel_map: ChannelMap,
partitioning=None) -> ChannelMap:
return channel_map

py_patch = patch(
Expand Down Expand Up @@ -391,13 +392,13 @@ def test_compile_proc_group_single_loop(self) -> None:
subcompilers = [py_proc_compiler]

# Call the method to be tested.
self.compiler._compile_proc_group(subcompilers, channel_map)
self.compiler._compile_proc_group(subcompilers, channel_map, None)

# Check that it called compile() on every SubCompiler instance
# exactly once. After that, the while loop should exit because the
# ChannelMap instance has not changed.
for sc in subcompilers:
sc.compile.assert_called_once_with({})
sc.compile.assert_called_once_with({}, None)

def test_compile_proc_group_multiple_loops(self) -> None:
"""Test whether the correct methods are called on all objects when
Expand All @@ -424,13 +425,15 @@ def test_compile_proc_group_multiple_loops(self) -> None:
subcompilers = [py_proc_compiler]

# Call the method to be tested.
self.compiler._compile_proc_group(subcompilers, channel_map)
self.compiler._compile_proc_group(subcompilers, channel_map,
None)

# Check that it called compile() on every SubCompiler instance
# exactly once. After that, the while loop should exit because the
# ChannelMap instance has not changed.
for sc in subcompilers:
sc.compile.assert_called_with({**channel_map1, **channel_map2})
sc.compile.assert_called_with({**channel_map1, **channel_map2},
None)
self.assertEqual(sc.compile.call_count, 3)

def test_extract_proc_builders(self) -> None:
Expand Down

0 comments on commit a82abc1

Please sign in to comment.