4 changes: 1 addition & 3 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -458,9 +458,7 @@ def get_deps(
         a bool indicating if the deps are valid and a list of all the
         dep nodes. This handles the src partition for
         """
-        if self.src_partitions is None:
-            # Cache src partitions so we don't have to recompute them every time
-            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
+        self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)
 
         # src_partition is None if node is not in source partition,
         # otherwise gives us the linear source partition it belongs to
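Note: a plausible reading of this change, given the multi-method test added below — a single partitioner config instance is now applied to more than one exported program, so source partitions cached on the instance for the first method's graph would be silently reused for the second. The sketch below illustrates that failure mode with hypothetical, simplified names; it is not the real config class.

def get_source_partitions_stub(graph):
    # Stand-in for torch.fx.passes.utils.source_matcher_utils.get_source_partitions.
    return {"linear": [f"partitions computed from {graph!r}"]}

class CachedConfig:
    # Old behavior: cache source partitions on the config instance.
    src_partitions = None

    def get_deps(self, graph):
        if self.src_partitions is None:
            self.src_partitions = get_source_partitions_stub(graph)
        return self.src_partitions

class RecomputingConfig:
    # New behavior (matching the diff above): recompute per exported program.
    def get_deps(self, graph):
        return get_source_partitions_stub(graph)

cfg = CachedConfig()
cfg.get_deps("forward.graph")
print(cfg.get_deps("forward_2.graph"))   # stale: still forward's partitions

cfg2 = RecomputingConfig()
cfg2.get_deps("forward.graph")
print(cfg2.get_deps("forward_2.graph"))  # recomputed for forward_2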
11 changes: 11 additions & 0 deletions backends/xnnpack/test/TARGETS
@@ -113,3 +113,14 @@ runtime.python_test(
         "//executorch/examples/xnnpack:models", # @manual
     ],
 )
+
+runtime.python_test(
+    name = "test_xnnpack_partitioner",
+    srcs = ["test_xnnpack_partitioner.py"],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
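With this target in place, the new test can presumably be run from an ExecuTorch checkout with a configured Buck2 setup, e.g.:

    buck2 test //executorch/backends/xnnpack/test:test_xnnpack_partitioner

(The command shape is an assumption; only the target path comes from the TARGETS entry above.)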
79 changes: 79 additions & 0 deletions backends/xnnpack/test/test_xnnpack_partitioner.py
@@ -9,8 +9,13 @@
 import unittest
 
 import torch
+import torch.nn.functional as F
 
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import to_edge, to_edge_transform_and_lower
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 from torch.export import export

@@ -82,3 +87,77 @@ def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
 
         log_contents = log_capture_string.getvalue()
         self.assertNotIn("DEPRECATION WARNING", log_contents)
+
+    def test_multi_method_partitioning_with_shared_weights(self):
+        """
+        Test that multi-method models with shared weights are correctly partitioned.
+        Verify that:
+        1. Both methods are fully lowered to XNNPACK.
+        2. Constants are not duplicated between named data and constant buffers.
+        3. Program executes correctly.
+        """
+
+        class MultiMethodModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(8, 16)
+                self.linear2 = torch.nn.Linear(16, 8)
+
+            def forward(self, x):
+                return self.linear2(F.sigmoid(self.linear(x)))
+
+            def forward_2(self, x):
+                return self.linear2(F.relu(self.linear(x)))
+
+            def example_inputs(self):
+                return (torch.randn(1, 8),)
+
+        model = MultiMethodModel()
+
+        # Get eager reference outputs.
+        example_inputs = model.example_inputs()
+        with torch.no_grad():
+            fwd1_eager = model.forward(*example_inputs)
+            fwd2_eager = model.forward_2(*example_inputs)
+
+        # Export both methods.
+        ep_fwd = export(model, model.example_inputs(), strict=True)
+        # Patch the forward, as export only traces the 'forward' method.
+        model.forward = model.forward_2
+        ep_fwd_2 = export(model, model.example_inputs(), strict=True)
+
+        # Convert to edge and lower to executorch.
+        edge = to_edge({"forward": ep_fwd, "forward_2": ep_fwd_2})
+        lowered = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True))
+        executorch = lowered.to_executorch()
+
+        # Check that both graphs are fully delegated.
+        nodes_1 = list(lowered._edge_programs["forward"].graph.nodes)
+        nodes_2 = list(lowered._edge_programs["forward_2"].graph.nodes)
+        self.assertEqual(len(nodes_1), 5)
+        self.assertEqual(len(nodes_2), 5)
+        expected_node_names = [
+            "x",
+            "lowered_module_0",
+            "executorch_call_delegate",
+            "getitem",
+            "output_1",
+        ]
+        for n in expected_node_names:
+            self.assertTrue(any(node.name == n for node in nodes_1))
+            self.assertTrue(any(node.name == n for node in nodes_2))
+
+        # Check that weights are not duplicated.
+        self.assertEqual(len(executorch._named_data.pte_data), 4)
+        self.assertEqual(len(executorch._named_data.buffers), 4)
+        self.assertEqual(len(executorch._named_data.external_data), 0)
+
+        # Check that there are no constant buffers (besides the placeholder).
+        self.assertEqual(len(executorch._emitter_output.program.constant_buffer), 1)
+
+        # Check for model correctness.
+        executorch_module = _load_for_executorch_from_buffer(executorch.buffer)
+        fwd1_et = executorch_module.run_method("forward", example_inputs)
+        fwd2_et = executorch_module.run_method("forward_2", example_inputs)
+        self.assertTrue(torch.allclose(fwd1_eager, fwd1_et[0], 1e-3))
+        self.assertTrue(torch.allclose(fwd2_eager, fwd2_et[0], 1e-3))
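For readers unfamiliar with the patch-the-forward trick used in the test, here is a minimal, self-contained sketch of just that export pattern (illustrative module, no ExecuTorch required). torch.export traces only the forward method, so the second entry point is exported by rebinding it first; both exported programs then reference the same parameter FQNs, which is the property the shared-weights assertions rely on. The printed FQN lists are an expectation, not verified output.

import torch
from torch.export import export

class TwoHeads(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 16)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

    def forward_2(self, x):
        return torch.relu(self.linear(x))

m = TwoHeads()
x = (torch.randn(1, 8),)

ep_fwd = export(m, x, strict=True)
m.forward = m.forward_2            # rebind so the second method gets traced
ep_fwd_2 = export(m, x, strict=True)

# Both exported programs reference the same parameter FQNs, which is what
# lets the lowered program store each shared tensor once.
print(ep_fwd.graph_signature.parameters)    # expect ['linear.weight', 'linear.bias']
print(ep_fwd_2.graph_signature.parameters)  # expect the same FQNs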