
Commit 4fab7c2

mcr229 authored and facebook-github-bot committed
fix bug with sequential backends (#10708)
Summary: There is a bug described in this PR comment: https://github.com/pytorch/executorch/pull/10584/files#r2070213706. This change adds tests and a fix to cover it.

Essentially, when sequential partitions go through preprocess_all, the get_item nodes produced by the first partition in the sequence are not correctly mapped to the arguments passed into the second partition, because the names of those nodes change (the original node becomes a get_item node). Instead of matching on names, we delete from the call arguments the nodes we already know must be removed from the input specs.

Additionally, there is an issue with validation: _validate fails while call_module nodes are still present in the graph. Since preprocess_multimethod lowers the call_submodule nodes one by one, calling _validate before all call_submodule nodes have been transformed into call_delegate nodes will fail. We therefore remove the _validate call from _unsafe_adjust_original_program and instead call _validate on the original program after all submodule nodes have been converted to call_delegate.

Differential Revision: D74226258
1 parent e196b50 commit 4fab7c2
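For illustration only (not part of the commit): a minimal, self-contained Python sketch of the arg-filtering idea described in the summary. StubNode and filter_call_delegate_args are hypothetical stand-ins, not torch.fx or ExecuTorch APIs; the real change lives in _insert_lowered_submodule in exir/backend/backend_api.py.

# Minimal sketch (hypothetical helper and stand-in node class, not the real
# torch.fx.Node) of filtering call_submodule args by known-deleted input specs
# instead of matching them by user_input name.
from typing import Any, Dict, List


class StubNode:
    # Stand-in for torch.fx.Node; only the .name attribute matters here.
    def __init__(self, name: str) -> None:
        self.name = name


def filter_call_delegate_args(
    call_sm_args: List[Any],
    input_specs_to_delete: Dict[str, Any],
) -> List[Any]:
    # Keep inputs in their original order and drop only those whose names are
    # already known to be deleted from the top-level input specs (e.g. owned
    # constants). This avoids matching against user_input names, which change
    # when an earlier partition rewrites a value into a get_item node.
    call_delegate_args = []
    for call_sm_input in call_sm_args:
        if (
            isinstance(call_sm_input, StubNode)
            and call_sm_input.name in input_specs_to_delete
        ):
            continue
        call_delegate_args.append(call_sm_input)
    return call_delegate_args


if __name__ == "__main__":
    args = [StubNode("getitem"), StubNode("c_const_copy_0"), StubNode("x")]
    kept = filter_call_delegate_args(args, {"c_const_copy_0": None})
    print([n.name for n in kept])  # ['getitem', 'x']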

File tree

4 files changed (+136, -24 lines):

exir/backend/backend_api.py
exir/backend/test/backend_with_preprocess_all_demo.py
exir/backend/test/test_to_backend_multi_method.py
exir/lowered_backend_module.py

exir/backend/backend_api.py

+12, -6
@@ -204,12 +204,16 @@ def _insert_lowered_submodule(
     owning_graph_module = call_submodule_node.graph.owning_module
     # call delegate args should only use user_inputs
     call_delegate_args = []
-    # Preserve input order as user_inputs
-    for inp_name in submodule_program.graph_signature.user_inputs:
-        for inp_node in call_submodule_node.all_input_nodes:
-            if inp_node.name == inp_name:
-                call_delegate_args.append(inp_node)
-                break
+    # names of input_specs to delete
+    input_specs_to_delete = toplevel_input_specs_to_delete
+    # Delete owned constants from the call_submodule_node args
+    for call_sm_input in call_submodule_node.args:
+        if (
+            isinstance(call_sm_input, torch.fx.Node)
+            and call_sm_input.name in input_specs_to_delete.keys()
+        ):
+            continue
+        call_delegate_args.append(call_sm_input)
 
 def generate_debug_handle(ep: ExportedProgram) -> int:
     """
@@ -324,6 +328,7 @@ def _partition_and_lower_one_graph_module(
             toplevel_input_specs_to_delete,
             toplevel_output_specs_to_delete,
         )
+    owning_program._validate()
 
     return tagged_graph_module
 
@@ -742,6 +747,7 @@ def to_backend(
     for method_name in method_to_edge_program.keys():
         if method_name in method_to_tagged_exported_program:
             tagged_exported_program = method_to_tagged_exported_program[method_name]
+            tagged_exported_program._validate()
             partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                 root=tagged_exported_program.graph_module,
                 graph=tagged_exported_program.graph_module.graph,

exir/backend/test/backend_with_preprocess_all_demo.py

+43, -12
@@ -21,10 +21,30 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 
 
+def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(exp_prog, node)
+        or is_buffer(exp_prog, node)
+        or is_lifted_tensor_constant(exp_prog, node)
+    )
+
+
+def get_total_num_ops_in_ep(edge_programs, supported_ops):
+    total_number_of_ops = 0
+    for edge_program in edge_programs.values():
+        for partitioned_program in edge_program:
+            for node in partitioned_program.graph.nodes:
+                if node.op == "call_function":
+                    if node.target in supported_ops:
+                        total_number_of_ops += 1
+    return total_number_of_ops
+
+
 def _preprocess_multimethod(
     edge_programs: Dict[str, List[ExportedProgram]],
     compile_specs: Dict[str, List[List[CompileSpec]]],
@@ -37,13 +57,7 @@ def _preprocess_multimethod(
     in testing for a partitioner which tags different partitions for different backends
     to be lowered to
     """
-    total_number_of_ops = 0
-    for edge_program in edge_programs.values():
-        for partitioned_program in edge_program:
-            for node in partitioned_program.graph.nodes:
-                if node.op == "call_function":
-                    if node.target in supported_ops:
-                        total_number_of_ops += 1
+    total_number_of_ops = get_total_num_ops_in_ep(edge_programs, supported_ops)
     all_processed_results = {key: [] for key in edge_programs.keys()}
 
     for method_name, partitioned_programs in edge_programs.items():
@@ -67,6 +81,8 @@
                         raise RuntimeError(
                             f"{node.op} {node.target.__name__} is not supported in backend {backend_name}"
                         )
+                if is_param_node(partitioned_program, node):
+                    processed_bytes += f"CONST{node.name}:"
 
             processed_bytes += "#"
             for cs in compile_spec_for_partition:
@@ -171,14 +187,30 @@ def preprocess_multimethod(
 
 
 class AddSinOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        return node.op == "call_function" and node.target in [
+        supported_targets = [
             exir_ops.edge.aten.add.Tensor,
             exir_ops.edge.aten.sin.default,
         ]
+        if node.op == "call_function" and node.target in supported_targets:
+            return True
+
+        if node.op == "placeholder" and is_param_node(self.original_program, node):
+            for user in node.users.keys():
+                if user.target in supported_targets:
+                    return True
+        return False
 
 
 class SubCosOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
             exir_ops.edge.aten.sub.Tensor,
@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
     """
 
     def __init__(self) -> None:
-        self.add_sin_support = any_chain(AddSinOperatorSupport())
-        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__
-
-        self.sub_cos_support = any_chain(SubCosOperatorSupport())
         self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__
+        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__
 
     def _partition_graph_module(
         self,
@@ -260,6 +289,8 @@ def _partition_graph_module(
         return partition_tags, start_idx_for_submodules
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        self.add_sin_support = any_chain(AddSinOperatorSupport(exported_program))
+        self.sub_cos_support = any_chain(SubCosOperatorSupport(exported_program))
         partition_tags, _ = self._partition_graph_module(exported_program.graph_module)
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags

exir/backend/test/test_to_backend_multi_method.py

+71
@@ -392,6 +392,77 @@ def forward(self, x):
         }
         self._test(test_set)
 
+    def test_multi_method_to_backend_sequential_delegates(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + a
+                b = b + z + a
+                b = b + y + a
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_edgeir_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_edgeir": (
+                seq_edgeir_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#5#aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
+    def test_multi_method_to_backend_constants(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.zeros(1)
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z * self.const
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + self.const + a
+                b = z + a + b
+                b = y + a + b
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_const_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_const": (
+                seq_const_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#6#CONSTc_const_copy_0:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
     def test_multi_method_to_backend_not_found(self):
         class SinModule(torch.nn.Module):
             def __init__(self):

exir/lowered_backend_module.py

+10, -6
@@ -381,7 +381,7 @@ def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
 
 
 def arrange_graph_placeholders(
-    gm: torch.fx.GraphModule, owning_program: ExportedProgram
+    gm: torch.fx.GraphModule, owning_program: ExportedProgram, tag
 ) -> torch.fx.GraphModule:
     """
     Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
@@ -411,9 +411,15 @@
         if node.op != "placeholder":
             continue
 
-        if node.name in graph_sign.inputs_to_parameters:
+        if (
+            node.name in graph_sign.inputs_to_parameters
+            and node.meta.get("delegation_tag", None) == tag
+        ):
             param_nodes.append(node)
-        elif node.name in graph_sign.inputs_to_buffers:
+        elif (
+            node.name in graph_sign.inputs_to_buffers
+            and node.meta.get("delegation_tag", None) == tag
+        ):
             buffer_nodes.append(node)
         else:
             input_nodes.append(node)
@@ -694,7 +700,7 @@ def create_exported_program_from_submodule(
         removed from the toplevel ExportedProgram.
     """
     # Arrange the submodule's placeholders in order
-    submodule = arrange_graph_placeholders(submodule, owning_program)
+    submodule = arrange_graph_placeholders(submodule, owning_program, tag)
 
     # TODO: we probably need to arrange the outputs wrt buffer mutations.
 
@@ -958,5 +964,3 @@ def _unsafe_adjust_original_program( # noqa: C901
                     if user_idx > idx:
                         user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                     break
-
-    original_program._validate()
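For illustration only (not part of the commit): a minimal sketch of the tag-aware placeholder bucketing that arrange_graph_placeholders now performs. StubPlaceholder and bucket_placeholders are hypothetical stand-ins for the torch.fx/ExportedProgram types; the point is that a parameter or buffer placeholder is grouped with the submodule's owned inputs only when its delegation_tag matches the partition tag being lowered, otherwise it stays an ordinary user input.

# Minimal sketch (hypothetical stand-in types) of tag-aware placeholder bucketing.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set


@dataclass
class StubPlaceholder:
    name: str
    meta: Dict[str, str] = field(default_factory=dict)


def bucket_placeholders(
    placeholders: List[StubPlaceholder],
    inputs_to_parameters: Set[str],
    inputs_to_buffers: Set[str],
    tag: Optional[str],
):
    # Parameters/buffers are only claimed by this partition when their
    # delegation_tag matches; everything else is treated as a user input.
    param_nodes, buffer_nodes, input_nodes = [], [], []
    for node in placeholders:
        if (
            node.name in inputs_to_parameters
            and node.meta.get("delegation_tag", None) == tag
        ):
            param_nodes.append(node)
        elif (
            node.name in inputs_to_buffers
            and node.meta.get("delegation_tag", None) == tag
        ):
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)
    return param_nodes, buffer_nodes, input_nodes


if __name__ == "__main__":
    phs = [
        StubPlaceholder("p_weight", {"delegation_tag": "tag0"}),
        StubPlaceholder("p_other", {"delegation_tag": "tag1"}),
        StubPlaceholder("x"),
    ]
    params, buffers, inputs = bucket_placeholders(
        phs, {"p_weight", "p_other"}, set(), "tag0"
    )
    print([n.name for n in params], [n.name for n in inputs])
    # ['p_weight'] ['p_other', 'x']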
