Skip to content

Update MemoryPlanning Verifier to only assume model has user input if it has at least one tensor input #10617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions exir/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from executorch.exir.tensor import TensorSpec

from torch import fx
from torch.export.exported_program import ExportGraphSignature, InputKind
from torch.export.exported_program import (
ConstantArgument,
ExportGraphSignature,
InputKind,
)
from torch.fx import Node
from torch.utils._pytree import tree_flatten

Expand Down Expand Up @@ -338,16 +342,23 @@ def _do_user_inputs_exist(graph_signature: Optional[ExportGraphSignature]) -> bo
if graph_signature is None:
return False

return (
len(
list(
filter(
lambda input: input.kind == InputKind.USER_INPUT,
graph_signature.input_specs,
)
)
user_inputs = list(
filter(
lambda input: input.kind == InputKind.USER_INPUT,
graph_signature.input_specs,
)
) > 0
)

# Return false if:
# - there are no inputs.
# - if user inputs are all prims (as this currently
# causes the memory planning verifier to blow up).
# Otherwise, return true.
return any(
not isinstance(input.arg, ConstantArgument)
or not isinstance(input.arg.value, (int, float, bool, str))
for input in user_inputs
)


def get_graph_input_tensors(
Expand Down
51 changes: 51 additions & 0 deletions exir/tests/test_memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import (
_do_user_inputs_exist,
filter_nodes,
get_node_tensor_specs,
greedy,
Expand Down Expand Up @@ -307,6 +308,56 @@ def wrapper(self: "TestMemoryPlanning") -> None:
return wrapper


class TestMemoryPlanningUserInputs(unittest.TestCase):
    """
    Ensure that the MemoryPlanning verifier only assumes a model
    has a user input if it has at least one tensor input.
    """

    def _user_inputs_exist(self, module: torch.nn.Module, example_inputs: tuple) -> bool:
        # Export the module and run the user-input check on its signature.
        exported = export(module, example_inputs, strict=True)
        return _do_user_inputs_exist(graph_signature=exported.graph_signature)

    def test_tensor_only_inputs(self):
        # Two tensor inputs -> user inputs exist.
        class AddModule(torch.nn.Module):
            def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                return a + b

        self.assertTrue(
            self._user_inputs_exist(AddModule(), (torch.randn(2), torch.randn(2)))
        )

    def test_mixed_inputs(self):
        # One tensor plus one prim -> still counts as having user inputs.
        class ScaleModule(torch.nn.Module):
            def forward(self, a: torch.Tensor, k: int) -> torch.Tensor:
                return a * k

        self.assertTrue(self._user_inputs_exist(ScaleModule(), (torch.randn(2), 3)))

    def test_primitive_only_inputs(self):
        # All-prim inputs -> treated as having no user inputs.
        class ProductModule(torch.nn.Module):
            def forward(self, a: int, b: float) -> float:
                return a * b

        self.assertFalse(self._user_inputs_exist(ProductModule(), (2, 3.0)))

    def test_no_inputs(self):
        # No inputs at all -> no user inputs.
        class ConstModule(torch.nn.Module):
            def forward(self) -> torch.Tensor:
                return torch.tensor(1.0)

        self.assertFalse(self._user_inputs_exist(ConstModule(), ()))


class TestMemoryPlanning(unittest.TestCase):
def verify_reuse(
self,
Expand Down
Loading