From 51c07f9f69aedf884fc697f3ef8545cb0303e2a9 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 20 May 2024 14:19:34 +0100 Subject: [PATCH] [dynamo] Allow asserts to fail (#126661) Currently if an assertion is statically known to be false, dynamo converts it to `_assert_async` which inductor currently ignores. Instead this graph breaks to raise the original assertion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126661 Approved by: https://github.com/ezyang --- test/dynamo/test_misc.py | 15 +++++++++++++++ .../TestPythonRegistration.test_alias_analysis | 0 .../TestScript.test_unspecialized_any_binding | 0 torch/_dynamo/symbolic_convert.py | 8 +++++--- 4 files changed, 20 insertions(+), 3 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis delete mode 100644 test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 5d7f780457d093..f07021c3155853 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1358,6 +1358,21 @@ def f(x): self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3]))) + def test_assert(self): + @torch.compile + def fn1(x): + assert x.shape != x.shape + + with self.assertRaises(AssertionError): + a = torch.randn(10) + fn1(a) + + def fn2(x): + assert x.shape == x.shape + return x.abs() + + torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1) + def test_config_obj(self): class Cfg: def __init__(self): diff --git a/test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis b/test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding b/test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 093809703405f7..864e53777941e7 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -343,9 +343,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): ): error_msg: VariableTracker = self.pop() # Skip over things like `assert True` - if value.is_python_constant() and bool(value.as_python_constant()): - self.jump(inst) - return + if value.is_python_constant(): + if bool(value.as_python_constant()): + return self.jump(inst) + else: + jump_graph_break(self, inst, value) # TODO maybe should respect DtoH sync intention of users later?? # Manually insert torch._assert_async instead of python assert and jump over