diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 83598940882..030ade687a8 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -25,7 +25,11 @@ from executorch.exir.tensor import TensorSpec from torch import fx -from torch.export.exported_program import ExportGraphSignature, InputKind +from torch.export.exported_program import ( + ConstantArgument, + ExportGraphSignature, + InputKind, +) from torch.fx import Node from torch.utils._pytree import tree_flatten @@ -338,16 +342,23 @@ def _do_user_inputs_exist(graph_signature: Optional[ExportGraphSignature]) -> bo if graph_signature is None: return False - return ( - len( - list( - filter( - lambda input: input.kind == InputKind.USER_INPUT, - graph_signature.input_specs, - ) - ) + user_inputs = list( + filter( + lambda input: input.kind == InputKind.USER_INPUT, + graph_signature.input_specs, ) - ) > 0 + ) + + # Return false if: + # - there are no inputs. + # - if user inputs are all prims (as this currently + # causes the memory planning verifier to blow up). + # Otherwise, return true. + return any( + not isinstance(input.arg, ConstantArgument) + or not isinstance(input.arg.value, (int, float, bool, str)) + for input in user_inputs + ) def get_graph_input_tensors( diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index b87ae2dfb58..6b895f27922 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -16,6 +16,7 @@ from executorch.exir import ExecutorchBackendConfig, to_edge from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( + _do_user_inputs_exist, filter_nodes, get_node_tensor_specs, greedy, @@ -307,6 +308,56 @@ def wrapper(self: "TestMemoryPlanning") -> None: return wrapper +class TestMemoryPlanningUserInputs(unittest.TestCase): + """ + Ensure that MemoryPlanning Verifer only assumes a model + has a user input if it has at least one tensor input. + """ + + def test_tensor_only_inputs(self): + class TensorModel(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + model = TensorModel() + inputs = (torch.randn(2), torch.randn(2)) + ep = export(model, inputs, strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertTrue(result) + + def test_mixed_inputs(self): + class MixedModel(torch.nn.Module): + def forward(self, x: torch.Tensor, y: int) -> torch.Tensor: + return x * y + + model = MixedModel() + inputs = (torch.randn(2), 3) + ep = export(model, inputs, strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertTrue(result) + + def test_primitive_only_inputs(self): + class PrimModel(torch.nn.Module): + def forward(self, x: int, y: float) -> float: + return x * y + + model = PrimModel() + inputs = (2, 3.0) + ep = export(model, inputs, strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertFalse(result) + + def test_no_inputs(self): + class NoInputModel(torch.nn.Module): + def forward(self) -> torch.Tensor: + return torch.tensor(1.0) + + model = NoInputModel() + ep = export(model, (), strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertFalse(result) + + class TestMemoryPlanning(unittest.TestCase): def verify_reuse( self,