new logsigmoid impl and custom backward
beverlylytle committed Dec 11, 2024
1 parent 9c84e23 commit bc9dda3
Showing 4 changed files with 47 additions and 14 deletions.
20 changes: 20 additions & 0 deletions thunder/core/transforms.py
@@ -2217,6 +2217,21 @@ def embedding_backward(a, num_weights, padding_idx, scale_grad_by_freq, sparse,
return gweight


@register_augmented_forward("torch.nn.functional.logsigmoid")
def log_sigmoid_aug_fwd(a):
from thunder.torch import logsigmoid, relu

primal = logsigmoid(a)
return VJPDual(primal, (a, a))


@register_backward("torch.nn.functional.logsigmoid")
def log_sigmoid_backward(a, _, g):
from thunder.torch import log_sigmoid_backward

return log_sigmoid_backward(g, a, _)


@register_augmented_forward("torch.cumsum")
def cumsum_aug_fwd(a: Proxy, dim: int, *, dtype: None | dtypes.dtype = None) -> VJPDual:
from thunder.torch import cumsum
@@ -2735,6 +2750,9 @@ def put_grad(v: Variable, val: Any) -> None:
raise NotImplementedError(f"Backward for {symbol.sym.id} is not implemented")

result = backward(*residuals, *cotangents)
print(f"res: {residuals}")
print(f"cot: {cotangents}")
print(result)
if isinstance(result, dict):
# If the backward returns a dict, we assume that it is a dict of
# forward arguments to the corresponding
Expand Down Expand Up @@ -2807,6 +2825,7 @@ def get_inexact_dtype_or_none(x):
gkwargs = tree_map(get_grad, trace.kwargs)
gkwargs = {k: v for k, v in gkwargs.items() if v is not None}
gargs, gkwargs = tree_map(get_inexact_dtype_or_none, (gargs, gkwargs))
print(f"gargs: {gargs}")
return gargs + (gkwargs,) if len(gkwargs) != 0 else gargs


@@ -2817,6 +2836,7 @@ def vjp_call(primals, cotangents, trace: Trace, **kwargs):
primals = (primals,)

result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
print(f"env: {env}")
check(
len(result) == len(cotangents) if isinstance(result, Sequence) else True,
lambda: f"Expected cotangents to be a sequence of length {len(result)}, got a sequence of length {len(cotangents)}",
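A note on the grad rule added above: the augmented forward saves the input `a` twice as residuals, and the backward forwards them as `log_sigmoid_backward(g, a, _)`, with the second copy apparently standing in for the `buffer` argument of the aten kernel (it is unused by the pure-Python formula). The underlying VJP rule is d/da logsigmoid(a) = sigmoid(-a). A minimal sanity check of that rule against PyTorch autograd, in plain torch rather than Thunder's tracing machinery:

```python
import torch

a = torch.randn(8, dtype=torch.float64, requires_grad=True)
g = torch.randn(8, dtype=torch.float64)  # cotangent

# Reference vector-Jacobian product from autograd.
(autograd_vjp,) = torch.autograd.grad(
    torch.nn.functional.logsigmoid(a), a, grad_outputs=g
)

# The rule the registered backward implements: grad = g * sigmoid(-a).
manual_vjp = g * torch.sigmoid(-a)

torch.testing.assert_close(autograd_vjp, manual_vjp)
```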
16 changes: 12 additions & 4 deletions thunder/executors/torchex.py
@@ -835,11 +835,15 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
celu = _register_torch_operation("celu", module=torch.nn.functional)
elu = _register_torch_operation("elu", module=torch.nn.functional)
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional)
log_sigmoid_backward = _register_torch_operation(
"torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward
)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)

@@ -851,11 +855,15 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(
ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
)
_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)

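The executor maps `ltorch.log_sigmoid_backward` onto `torch.ops.aten.log_sigmoid_backward`, whose signature is `(grad_output, self, buffer)`. Assuming the aten ops behave as in current PyTorch (where `log_sigmoid_forward` returns both the output and the buffer that the CPU backward kernel reads), a rough parity check between the aten kernel and the abs-based formula defined in `thunder/torch/__init__.py` might look like this sketch:

```python
import torch

a = torch.randn(16, dtype=torch.float64)
g = torch.randn(16, dtype=torch.float64)

# aten's forward returns (output, buffer); the CPU backward consumes the buffer.
_, buffer = torch.ops.aten.log_sigmoid_forward(a)
aten_grad = torch.ops.aten.log_sigmoid_backward(g, a, buffer)

# The abs-based formulation used by ltorch.log_sigmoid_backward.
exp_neg_abs = torch.exp(-a.abs())
z = exp_neg_abs / (1 + exp_neg_abs)
manual_grad = g * torch.where(a > 0, z, 1 - z)

torch.testing.assert_close(aten_grad, manual_grad)
```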
10 changes: 2 additions & 8 deletions thunder/tests/opinfos.py
@@ -1683,14 +1683,8 @@ def gen(op, device, dtype, requires_grad):
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_generator,
torch_reference=torch.nn.functional.logsigmoid,
test_directives=(
# test tols are too tight for these half precision tests
DecorateInfo(
pytest.mark.skip,
"test_core_vs_torch_consistency",
dtypes=(datatypes.float16, datatypes.bfloat16),
),
),
domain=(-1, 1),
test_directives=(),
)
elementwise_unary_ops.append(logsigmoid_opinfo)

15 changes: 13 additions & 2 deletions thunder/torch/__init__.py
@@ -1813,13 +1813,24 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool =


@torchsymbol(torch.nn.functional.logsigmoid, is_method=False)
def logsigmoid(a: TensorProxy, /):
return log(sigmoid(a))
def logsigmoid(a: TensorProxy, /) -> TensorLike:
return where(a > 0, -log1p(exp(-a)), a - log1p(exp(a)))


_inplace_to_out_of_place[logsigmoid] = logsigmoid, -1


# @torchsymbol("log_sigmoid_backward", id="log_sigmoid_backward")
def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, _: TensorProxy) -> TensorLike:
exp_a = exp(-abs(a))
z = exp_a / (1 + exp_a)
return g * where(a > 0, z, 1 - z)
# return g * where(a > 0, exp(-a) / (1 + exp(-a)), 1 - exp(a) / (1 + exp(a)))


_inplace_to_out_of_place[log_sigmoid_backward] = log_sigmoid_backward, -1


# TODO Should this use clamp? -- Would that propagate NaNs properly?
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
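The rewritten `logsigmoid` replaces `log(sigmoid(a))` with a per-sign form built from `log1p`, which avoids the underflow of `sigmoid(a)` at large negative inputs, and `log_sigmoid_backward` evaluates the derivative `sigmoid(-a)` via `exp(-|a|)` so neither branch overflows. A quick float32 illustration in plain PyTorch (not Thunder) of the stability difference:

```python
import torch

a = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])

# Old composition: sigmoid(-100) rounds to 0 in float32, so the log returns -inf.
naive = torch.log(torch.sigmoid(a))
print(naive)  # -inf at a = -100

# New formulation from logsigmoid above: finite everywhere, logsigmoid(-100) = -100.
stable = torch.where(a > 0, -torch.log1p(torch.exp(-a)), a - torch.log1p(torch.exp(a)))
print(stable)

torch.testing.assert_close(stable, torch.nn.functional.logsigmoid(a))
```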
