From 415b485ed26f7c3237d741d97cc35d0a785a4e8d Mon Sep 17 00:00:00 2001 From: nikitaved Date: Fri, 5 Jul 2024 16:19:59 +0200 Subject: [PATCH] `iter(TensorProxy)` lookaside (#718) --- thunder/core/proxies.py | 9 +++++++ thunder/tests/test_interpreter.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index df6dce4540..4d85011056 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1316,6 +1316,15 @@ def __getattr__(self, attr: str, /): return method_or_value + def __iter__(self): + # NOTE: this implementation is equivalent to torch.Tensor.__iter__ + + if self.ndim == 0: + raise TypeError("iteration over a 0-dim tensor") + + unbound_tuple = self.unbind(0) + return iter(unbound_tuple) + # # Default attribute # diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index a7919325be..bb702421b4 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -1082,6 +1082,50 @@ def f(a, b, c): # } +def test_tensor_proxy_iter_lookaside(jit): + t0 = torch.rand(3, 3) + t1 = torch.rand(0, 3) + t2 = torch.rand(()) + + for x in (t0, t1): + + def f(x): + for i, xi in enumerate(x): + pass + return x + + jf = jit(f) + + assert f(x) is jf(x) + + def f(x): + res = 0 + for i, xi in enumerate(x): + res = xi + return res + + jf = jit(f) + + assert_close(jf(x), f(x)) + + with pytest.raises(TypeError, match="iteration over a 0-d tensor"): + jf(t2) + + with pytest.raises(TypeError, match="iteration over a 0-d tensor"): + f(t2) + + def f(x): + res = x + for xi in x: + res = res + xi.unsqueeze(0) + return res + + jf = jit(f) + + for x in (t0, t1): + assert_close(jf(x), f(x)) + + def test_calling_methods(jit): jitting = False