From b56dd80a34f97dec38dcc4795abece8fc46130f7 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Tue, 19 Nov 2024 18:40:45 +0200 Subject: [PATCH 1/7] raise if is_leaf and require_grad in inplace operations --- thunder/executors/torchex.py | 19 ++++------ .../tests/test_inplace_functionalization.py | 38 +++++++++---------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 7b363606e8..91f67829f5 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1,32 +1,24 @@ from __future__ import annotations import operator import importlib -from dataclasses import replace -from contextlib import ContextDecorator -from functools import wraps, partial -from inspect import signature -from itertools import groupby +from functools import partial from numbers import Number from typing import TYPE_CHECKING from collections.abc import Callable from collections.abc import Hashable, Sequence from collections.abc import Sequence from types import ModuleType -from enum import Enum, auto import torch -import math -from looseversion import LooseVersion +from thunder.core.compile_data import get_compile_data import thunder.core.dtypes as dtypes from thunder.core.dtypes import to_torch_dtype, to_dtype import thunder.core.devices as devices from thunder.core.devices import to_torch_device, to_device import thunder.core.prims as prims -from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace -from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype -from thunder.core.pytree import tree_flatten, tree_unflatten -from thunder.core.symbol import Symbol, BoundSymbol +from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, pytype +from thunder.core.symbol import Symbol from thunder.distributed.prims import DistributedReduceOps import thunder.distributed.prims as dist_prims import thunder.core.utils as utils @@ -2159,6 +2151,9 @@ def is_float_type(self, input): def _copy__impl(copy_from, copy_to): + cd = get_compile_data() + if cd is not None and cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad: + raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.") copy_to.copy_(copy_from) return copy_to diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 6f88f1f8eb..9b7973116e 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -476,31 +476,27 @@ def f(xs, ys, z): dtypes=NOTHING, ) def test_inplace_to_tensors_with_grad(executor, device, _): - @torch.no_grad def add_y(x, y): - x.add_(y, alpha=0.1) + # inplace operations requiring grad on leafs are illegal, trick to make z a non-leaf + z = torch.abs(x) * torch.sgn(x) + z.add_(y, alpha=0.1) - @torch.no_grad - def add_grad(x, y): - x.add_(x.grad, alpha=0.1) + jitted_f = executor.make_callable(add_y) + x = make_tensor((2, 2), device=device, dtype=torch.float32, requires_grad=True) + x.grad = make_tensor((2, 2), device=device, dtype=torch.float32) + y = make_tensor((2, 2), device=device, dtype=torch.float32) - for f in (add_y, add_grad): - jitted_f = executor.make_callable(f) - x = make_tensor((2, 2), device=device, dtype=torch.float32, requires_grad=True) - x.grad = make_tensor((2, 2), device=device, dtype=torch.float32) - y = make_tensor((2, 2), device=device, dtype=torch.float32) + x_ref = 
x.clone().detach().requires_grad_(True) + x_ref.grad = x.grad.clone().detach() + y_ref = y.clone().detach() - x_ref = x.clone().detach().requires_grad_(True) - x_ref.grad = x.grad.clone().detach() - y_ref = y.clone().detach() + res = jitted_f(x, y) + res_ref = add_y(x_ref, y_ref) - res = jitted_f(x, y) - res_ref = f(x_ref, y_ref) - - torch.testing.assert_close(x, x_ref) - torch.testing.assert_close(x.grad, x_ref.grad) - torch.testing.assert_close(y, y_ref) - torch.testing.assert_close(res, res_ref) + torch.testing.assert_close(x, x_ref) + torch.testing.assert_close(x.grad, x_ref.grad) + torch.testing.assert_close(y, y_ref) + torch.testing.assert_close(res, res_ref) @instantiate( @@ -551,6 +547,8 @@ def single_tensor_adam( jitted = executor.make_callable(single_tensor_adam) params, grads, exp_avgs, exp_avg_sqs = tensors + cd = thunder.compile_data(jitted) + cd.compile_options["torch_compile_fullgraph"] = False jitted(params, grads, exp_avgs, exp_avg_sqs, state_steps) torch.testing.assert_close(actual=tensors + [state_steps], expected=ref_tensors + [ref_state_steps]) From e4ad972529636a4412d54950743078ae26f98aad Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Wed, 20 Nov 2024 12:24:15 +0200 Subject: [PATCH 2/7] restore wraps --- thunder/executors/torchex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 56a27d9061..63f2b26621 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator import importlib -from functools import partial +from functools import partial, wraps from numbers import Number from typing import TYPE_CHECKING from collections.abc import Callable From 551de30c09bcee35c3acbd980c8da5762405b8cb Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Thu, 21 Nov 2024 16:01:14 +0200 Subject: [PATCH 3/7] add test and comment --- thunder/tests/test_inplace_copy.py | 8 ++++++++ thunder/tests/test_inplace_functionalization.py | 5 ++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 952df2faf0..cb50249d17 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -178,3 +178,11 @@ def func(T0): assert_close(a_ref, a) for o, o_ref in zip(o_thunder, o_eager): assert_close(o, o_ref) + + +@instantiate(dtypes=datatypes.float_math_dtypes) +def test_inplace_copy_of_leaf_requiring_grad_fails(executor, device, dtype): + tdtype = ttorch.to_torch_dtype(dtype) + a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=True) + with pytest.raises(RuntimeError): + a.copy_(a) diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 9b7973116e..ea21b3e28c 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -545,10 +545,9 @@ def single_tensor_adam( ref_state_steps = [torch.tensor(1, device=device) for _ in range(2)] single_tensor_adam(*ref_tensors, state_steps=ref_state_steps) - jitted = executor.make_callable(single_tensor_adam) + # torch.compile does not support accessing the ContextVariable compile data used in _copy__impl_ + jitted = executor.make_callable(single_tensor_adam, torch_compile_fullgraph=False) params, grads, exp_avgs, exp_avg_sqs = tensors - cd = thunder.compile_data(jitted) - cd.compile_options["torch_compile_fullgraph"] = False jitted(params, grads, exp_avgs, 
exp_avg_sqs, state_steps) torch.testing.assert_close(actual=tensors + [state_steps], expected=ref_tensors + [ref_state_steps]) From a179d12e7d22bc3a2f2cb0c066e3a301c77f45f8 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Fri, 22 Nov 2024 14:14:28 +0200 Subject: [PATCH 4/7] test thunder, not torch --- thunder/executors/torchex.py | 2 +- thunder/tests/test_inplace_copy.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 63f2b26621..8f57a2d362 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2183,7 +2183,7 @@ def is_float_type(self, input): def _copy__impl(copy_from, copy_to): cd = get_compile_data() - if cd is not None and cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad: + if (cd is None or cd.is_grad_enabled) and copy_to.is_leaf and copy_to.requires_grad: raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.") copy_to.copy_(copy_from) return copy_to diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index cb50249d17..a4d5ba4c83 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -7,7 +7,7 @@ import thunder import thunder.core.dtypes as datatypes import thunder.torch as ttorch -from thunder.tests.framework import instantiate, nvFuserExecutor +from thunder.tests.framework import instantiate, nvFuserExecutor, TorchExecutor @instantiate(dtypes=datatypes.all_dtypes - datatypes.float_8bit_dtypes) @@ -180,9 +180,14 @@ def func(T0): assert_close(o, o_ref) -@instantiate(dtypes=datatypes.float_math_dtypes) +@instantiate(executors=(TorchExecutor,), dtypes=datatypes.float_math_dtypes) def test_inplace_copy_of_leaf_requiring_grad_fails(executor, device, dtype): + def fn(x): + x.copy_(x) + + jitted_fn = executor.make_callable(fn) + tdtype = ttorch.to_torch_dtype(dtype) a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=True) with pytest.raises(RuntimeError): - a.copy_(a) + jitted_fn(a) From a26ed676b084ee5d1ce8fd3ba5a4c6a8d331f053 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Mon, 25 Nov 2024 19:32:40 +0200 Subject: [PATCH 5/7] add parameter to copy__meta --- thunder/core/prims.py | 2 ++ thunder/core/transforms.py | 2 +- thunder/executors/nvfuserex_impl.py | 2 ++ thunder/executors/torchex.py | 5 ++- .../tests/test_inplace_functionalization.py | 36 ++++++++++--------- thunder/torch/__init__.py | 3 +- 6 files changed, 29 insertions(+), 21 deletions(-) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index c17a28296c..29252f260a 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -4030,6 +4030,8 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_ def copy__meta( copy_from: TensorProxy, copy_to: TensorProxy, + *, + is_grad_enabled: bool = False, ): utils.check_type(copy_from, TensorProxy) utils.check_type(copy_to, TensorProxy) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 7b09ef26b2..15da750386 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1588,7 +1588,7 @@ def zeros_like(x): prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)), prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)), prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)), - prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()), + prims.PrimIDs.COPY_: lambda x, y, is_grad_enabled: (prims.copy_(x, y, is_grad_enabled=is_grad_enabled), 
tuple()), prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()), } diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 561f838a7d..c6278bc6fa 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -2054,6 +2054,7 @@ def var_mean( def _copy__check( copy_from: TensorProxy, copy_to: TensorProxy, + is_grad_enabled: bool, ) -> bool: return are_supported_tensors(copy_from, copy_to) @@ -2064,6 +2065,7 @@ def copy_( *, fd: FusionDefinition, lc_to_nv_map: dict, + is_grad_enabled: bool, ) -> Any: nvcopy_from = getnv(copy_from, fd, lc_to_nv_map) nvcopy_to = getnv(copy_to, fd, lc_to_nv_map) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 8f57a2d362..240a7514f4 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2181,9 +2181,8 @@ def is_float_type(self, input): einops._backends._type2backend[TensorProxy] = EinopsThunderBackend() -def _copy__impl(copy_from, copy_to): - cd = get_compile_data() - if (cd is None or cd.is_grad_enabled) and copy_to.is_leaf and copy_to.requires_grad: +def _copy__impl(copy_from, copy_to, *, is_grad_enabled): + if is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad: raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.") copy_to.copy_(copy_from) return copy_to diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index ea21b3e28c..4b0ca916df 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -476,27 +476,31 @@ def f(xs, ys, z): dtypes=NOTHING, ) def test_inplace_to_tensors_with_grad(executor, device, _): + @torch.no_grad + def add_grad(x, y): + x.add_(x.grad) + + @torch.no_grad def add_y(x, y): - # inplace operations requiring grad on leafs are illegal, trick to make z a non-leaf - z = torch.abs(x) * torch.sgn(x) - z.add_(y, alpha=0.1) + x.add_(y, alpha=0.1) - jitted_f = executor.make_callable(add_y) - x = make_tensor((2, 2), device=device, dtype=torch.float32, requires_grad=True) - x.grad = make_tensor((2, 2), device=device, dtype=torch.float32) - y = make_tensor((2, 2), device=device, dtype=torch.float32) + for fn in (add_grad, add_y): + jitted_f = executor.make_callable(fn) + x = make_tensor((2, 2), device=device, dtype=torch.float32, requires_grad=True) + x.grad = make_tensor((2, 2), device=device, dtype=torch.float32) + y = make_tensor((2, 2), device=device, dtype=torch.float32) - x_ref = x.clone().detach().requires_grad_(True) - x_ref.grad = x.grad.clone().detach() - y_ref = y.clone().detach() + x_ref = x.clone().detach().requires_grad_(True) + x_ref.grad = x.grad.clone().detach() + y_ref = y.clone().detach() - res = jitted_f(x, y) - res_ref = add_y(x_ref, y_ref) + res = jitted_f(x, y) + res_ref = fn(x_ref, y_ref) - torch.testing.assert_close(x, x_ref) - torch.testing.assert_close(x.grad, x_ref.grad) - torch.testing.assert_close(y, y_ref) - torch.testing.assert_close(res, res_ref) + torch.testing.assert_close(x, x_ref) + torch.testing.assert_close(x.grad, x_ref.grad) + torch.testing.assert_close(y, y_ref) + torch.testing.assert_close(res, res_ref) @instantiate( diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index a94ada1cc8..1bbded80f1 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1960,7 +1960,8 @@ def copysign_(a, b, /): @torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,)) def 
copy_(a, b, /): - return prims.copy_(b, a) + cd = get_compile_data() + return prims.copy_(b, a, is_grad_enabled=cd.is_grad_enabled if cd is not None else False) # TODO Implement div From 3fb1f53f00d7021748f0f614fba8d030ab947b0a Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Wed, 27 Nov 2024 17:13:43 +0200 Subject: [PATCH 6/7] apply function with compile data --- thunder/__init__.py | 4 ++-- thunder/core/prims.py | 2 -- thunder/core/transforms.py | 2 +- thunder/executors/nvfuserex_impl.py | 2 -- thunder/executors/torchex.py | 5 +++-- thunder/tests/test_core.py | 3 ++- thunder/tests/test_inplace_functionalization.py | 7 ++++--- thunder/torch/__init__.py | 3 +-- 8 files changed, 13 insertions(+), 15 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 54c94855dc..2b78a33bef 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -824,8 +824,8 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) check_storage_aliases(cache_entry, inps) - - result = cache_entry.computation_fn(*inps) + with compile_data_and_stats(cd, cs): + result = cache_entry.computation_fn(*inps) result = maybe_connect_to_autograd(cache_entry, result) result = maybe_call_epilogue(cache_entry, result, pro_to_epi) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 29252f260a..c17a28296c 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -4030,8 +4030,6 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_ def copy__meta( copy_from: TensorProxy, copy_to: TensorProxy, - *, - is_grad_enabled: bool = False, ): utils.check_type(copy_from, TensorProxy) utils.check_type(copy_to, TensorProxy) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 15da750386..7b09ef26b2 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1588,7 +1588,7 @@ def zeros_like(x): prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)), prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)), prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)), - prims.PrimIDs.COPY_: lambda x, y, is_grad_enabled: (prims.copy_(x, y, is_grad_enabled=is_grad_enabled), tuple()), + prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()), prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()), } diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index c6278bc6fa..561f838a7d 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -2054,7 +2054,6 @@ def var_mean( def _copy__check( copy_from: TensorProxy, copy_to: TensorProxy, - is_grad_enabled: bool, ) -> bool: return are_supported_tensors(copy_from, copy_to) @@ -2065,7 +2064,6 @@ def copy_( *, fd: FusionDefinition, lc_to_nv_map: dict, - is_grad_enabled: bool, ) -> Any: nvcopy_from = getnv(copy_from, fd, lc_to_nv_map) nvcopy_to = getnv(copy_to, fd, lc_to_nv_map) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 240a7514f4..64f150171a 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2181,8 +2181,9 @@ def is_float_type(self, input): einops._backends._type2backend[TensorProxy] = EinopsThunderBackend() -def _copy__impl(copy_from, copy_to, *, is_grad_enabled): - if is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad: +def _copy__impl(copy_from, copy_to): + cd = get_compile_data() + if cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad: raise RuntimeError("a leaf Variable that 
requires grad is being used in an in-place operation.") copy_to.copy_(copy_from) return copy_to diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index bf6e9bd7d3..4053a1d772 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -350,7 +350,8 @@ def step(self): optimizer = Optimizer([a, b]) cstep = executor.make_callable(optimizer.step) - cstep() + with torch.no_grad(): + cstep() expected_a = ref_a - 0.1 * a.grad assert_close(a, expected_a) diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 4b0ca916df..a33a7f20c2 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -478,11 +478,11 @@ def f(xs, ys, z): def test_inplace_to_tensors_with_grad(executor, device, _): @torch.no_grad def add_grad(x, y): - x.add_(x.grad) + return x.add_(x.grad) @torch.no_grad def add_y(x, y): - x.add_(y, alpha=0.1) + return x.add_(y, alpha=0.1) for fn in (add_grad, add_y): jitted_f = executor.make_callable(fn) @@ -494,7 +494,8 @@ def add_y(x, y): x_ref.grad = x.grad.clone().detach() y_ref = y.clone().detach() - res = jitted_f(x, y) + with torch.no_grad(): + res = jitted_f(x, y) res_ref = fn(x_ref, y_ref) torch.testing.assert_close(x, x_ref) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 1bbded80f1..a94ada1cc8 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1960,8 +1960,7 @@ def copysign_(a, b, /): @torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,)) def copy_(a, b, /): - cd = get_compile_data() - return prims.copy_(b, a, is_grad_enabled=cd.is_grad_enabled if cd is not None else False) + return prims.copy_(b, a) # TODO Implement div From cfd5143039b2035f1d83c33af52534300ba2f739 Mon Sep 17 00:00:00 2001 From: beverlylytle Date: Thu, 12 Dec 2024 17:41:23 +0200 Subject: [PATCH 7/7] pass grad_enabled bool instead of relying on compile data --- thunder/__init__.py | 3 +-- thunder/core/prims.py | 2 ++ thunder/core/transforms.py | 2 +- thunder/executors/nvfuserex_impl.py | 3 +++ thunder/executors/torchex.py | 5 ++--- thunder/tests/test_core.py | 3 +-- thunder/tests/test_inplace_functionalization.py | 3 +-- thunder/torch/__init__.py | 3 ++- 8 files changed, 13 insertions(+), 11 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 2b78a33bef..2d34c3b32a 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -824,8 +824,7 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) check_storage_aliases(cache_entry, inps) - with compile_data_and_stats(cd, cs): - result = cache_entry.computation_fn(*inps) + result = cache_entry.computation_fn(*inps) result = maybe_connect_to_autograd(cache_entry, result) result = maybe_call_epilogue(cache_entry, result, pro_to_epi) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index c17a28296c..fd5ecc83e9 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -4030,6 +4030,8 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_ def copy__meta( copy_from: TensorProxy, copy_to: TensorProxy, + *, + grad_enabled: bool = False, ): utils.check_type(copy_from, TensorProxy) utils.check_type(copy_to, TensorProxy) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 7b09ef26b2..83cb0d38b2 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1588,7 +1588,7 @@ def zeros_like(x): 
prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)), prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)), prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)), - prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()), + prims.PrimIDs.COPY_: lambda x, y, grad_enabled: (prims.copy_(x, y, grad_enabled=grad_enabled), tuple()), prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()), } diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 561f838a7d..98516f7417 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -2054,6 +2054,8 @@ def var_mean( def _copy__check( copy_from: TensorProxy, copy_to: TensorProxy, + *, + grad_enabled: bool, ) -> bool: return are_supported_tensors(copy_from, copy_to) @@ -2064,6 +2066,7 @@ def copy_( *, fd: FusionDefinition, lc_to_nv_map: dict, + grad_enabled: bool, ) -> Any: nvcopy_from = getnv(copy_from, fd, lc_to_nv_map) nvcopy_to = getnv(copy_to, fd, lc_to_nv_map) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 64f150171a..6597d6eafe 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2181,9 +2181,8 @@ def is_float_type(self, input): einops._backends._type2backend[TensorProxy] = EinopsThunderBackend() -def _copy__impl(copy_from, copy_to): - cd = get_compile_data() - if cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad: +def _copy__impl(copy_from, copy_to, grad_enabled): + if grad_enabled and copy_to.is_leaf and copy_to.requires_grad: raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.") copy_to.copy_(copy_from) return copy_to diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 4053a1d772..bf6e9bd7d3 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -350,8 +350,7 @@ def step(self): optimizer = Optimizer([a, b]) cstep = executor.make_callable(optimizer.step) - with torch.no_grad(): - cstep() + cstep() expected_a = ref_a - 0.1 * a.grad assert_close(a, expected_a) diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index a33a7f20c2..08cbeb6fbc 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -494,8 +494,7 @@ def add_y(x, y): x_ref.grad = x.grad.clone().detach() y_ref = y.clone().detach() - with torch.no_grad(): - res = jitted_f(x, y) + res = jitted_f(x, y) res_ref = fn(x_ref, y_ref) torch.testing.assert_close(x, x_ref) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index a94ada1cc8..7e1897471d 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1960,7 +1960,8 @@ def copysign_(a, b, /): @torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,)) def copy_(a, b, /): - return prims.copy_(b, a) + cd = get_compile_data() + return prims.copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False) # TODO Implement div
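

Illustrative usage of the final state of this series (a sketch, not part of the patches: it assumes the public thunder.jit entry point and default executors, whereas the tests above go through executor.make_callable):

    import torch
    import thunder

    def overwrite(x):
        # functionalized to prims.copy_; grad_enabled is recorded from compile data at trace time
        return x.copy_(x)

    @torch.no_grad
    def update(x, y):
        # under no_grad the copy_ prim records grad_enabled=False, so the in-place update is allowed
        return x.add_(y, alpha=0.1)

    leaf = torch.ones(4, 4, requires_grad=True)
    y = torch.ones(4, 4)

    # Writing in place into a leaf that requires grad raises, matching eager PyTorch:
    # "a leaf Variable that requires grad is being used in an in-place operation."
    try:
        thunder.jit(overwrite)(leaf)
    except RuntimeError as err:
        print(err)

    # Inside no_grad the same kind of update succeeds (the optimizer-style case from the tests).
    print(thunder.jit(update)(leaf, y))

Passing grad_enabled as an explicit argument of prims.copy_ (rather than reading the compile-data ContextVar inside _copy__impl at run time, as the intermediate patches did) keeps the runtime implementation free of context-variable lookups that torch.compile cannot trace.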