diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 7b09ef26b..05cefe4ec 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2217,6 +2217,21 @@ def embedding_backward(a, num_weights, padding_idx, scale_grad_by_freq, sparse,
     return gweight
 
 
+@register_augmented_forward("torch.nn.functional.logsigmoid")
+def log_sigmoid_aug_fwd(a):
+    from thunder.torch import logsigmoid
+
+    primal = logsigmoid(a)
+    return VJPDual(primal, (a, a))
+
+
+@register_backward("torch.nn.functional.logsigmoid")
+def log_sigmoid_backward(a, _, g):
+    from thunder.torch import log_sigmoid_backward
+
+    return log_sigmoid_backward(g, a, _)
+
+
 @register_augmented_forward("torch.cumsum")
 def cumsum_aug_fwd(a: Proxy, dim: int, *, dtype: None | dtypes.dtype = None) -> VJPDual:
     from thunder.torch import cumsum
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 8484d389d..df10fd9a4 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -835,11 +835,15 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
 celu = _register_torch_operation("celu", module=torch.nn.functional)
 elu = _register_torch_operation("elu", module=torch.nn.functional)
 gelu = _register_torch_operation("gelu", module=torch.nn.functional)
+hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
+hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
 leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
+logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional)
+log_sigmoid_backward = _register_torch_operation(
+    "torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward
+)
 relu = _register_torch_operation("relu", module=torch.nn.functional)
 relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
-hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
-hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
 selu = _register_torch_operation("selu", module=torch.nn.functional)
 silu = _register_torch_operation("silu", module=torch.nn.functional)
 
@@ -851,11 +855,15 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
 _register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
+_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
+_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
+_register_elementwise_unary_implementation(
+    ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
+)
+_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
-_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
-_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)
 
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 057b73942..3c1008f67 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -1683,14 +1683,8 @@ def gen(op, device, dtype, requires_grad):
     dtypes=(datatypes.floating,),
     sample_input_generator=elementwise_unary_generator,
     torch_reference=torch.nn.functional.logsigmoid,
-    test_directives=(
-        # test tols are too tight for these half precision tests
-        DecorateInfo(
-            pytest.mark.skip,
-            "test_core_vs_torch_consistency",
-            dtypes=(datatypes.float16, datatypes.bfloat16),
-        ),
-    ),
+    domain=(-1, 1),
+    test_directives=(),
 )
 
 elementwise_unary_ops.append(logsigmoid_opinfo)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 55eedf17c..4a95e8abc 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1813,13 +1813,24 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool =
 
 
 @torchsymbol(torch.nn.functional.logsigmoid, is_method=False)
-def logsigmoid(a: TensorProxy, /):
-    return log(sigmoid(a))
+def logsigmoid(a: TensorProxy, /) -> TensorLike:
+    return where(a > 0, -log1p(exp(-a)), a - log1p(exp(a)))
 
 
 _inplace_to_out_of_place[logsigmoid] = logsigmoid, -1
 
 
+@torchsymbol("log_sigmoid_backward", id="log_sigmoid_backward")
+def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, _: TensorProxy) -> TensorLike:
+    # The grad of log(sigmoid(a)) is sigmoid(-a); evaluate it through exp(-|a|) so the exponential never overflows.
+    exp_a = exp(-abs(a))
+    z = exp_a / (1 + exp_a)
+    return g * where(a > 0, z, 1 - z)
+
+
+_inplace_to_out_of_place[log_sigmoid_backward] = log_sigmoid_backward, -1
+
+
 # TODO Should this use clamp? -- Would that propagate NaNs properly?
 @torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
 def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
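
As a quick, standalone sanity check (not part of the patch), the snippet below compares the two closed-form expressions introduced in thunder/torch/__init__.py against torch.nn.functional.logsigmoid and its autograd gradient in float64. The helper names logsigmoid_stable and log_sigmoid_backward_stable are illustrative only.

import torch


def logsigmoid_stable(a: torch.Tensor) -> torch.Tensor:
    # Mirrors the new thunder.torch.logsigmoid: branch on the sign of a so exp() never overflows.
    return torch.where(a > 0, -torch.log1p(torch.exp(-a)), a - torch.log1p(torch.exp(a)))


def log_sigmoid_backward_stable(g: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    # Mirrors the new thunder.torch.log_sigmoid_backward: the grad of log(sigmoid(a)) is
    # sigmoid(-a), evaluated through exp(-|a|) for numerical stability.
    exp_a = torch.exp(-torch.abs(a))
    z = exp_a / (1 + exp_a)
    return g * torch.where(a > 0, z, 1 - z)


# Include large-magnitude inputs to exercise the overflow-prone region.
a = (50 * torch.randn(4096, dtype=torch.float64)).requires_grad_()
g = torch.randn_like(a)

ref = torch.nn.functional.logsigmoid(a)
(ref_grad,) = torch.autograd.grad(ref, a, grad_outputs=g)

assert torch.allclose(logsigmoid_stable(a), ref)
assert torch.allclose(log_sigmoid_backward_stable(g, a.detach()), ref_grad)

Writing the forward as where(a > 0, -log1p(exp(-a)), a - log1p(exp(a))) rather than log(sigmoid(a)) avoids intermediate over/underflow, which is presumably what allows the half-precision skip to be dropped from the opinfo above.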