diff --git a/trident/operation/linear.py b/trident/operation/linear.py
index 557fba5..411f356 100644
--- a/trident/operation/linear.py
+++ b/trident/operation/linear.py
@@ -29,7 +29,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any):
         output = Linear.__forward(input, weight, bias, use_accelerator)
         util.pop_trace()
 
-        ctx.save_for_backward(input, weight, bias, output)
+        ctx.save_for_backward(input, weight, bias)
         ctx.use_accelerator = use_accelerator
 
         return output
@@ -37,12 +37,10 @@ def forward(ctx: Any, *args: Any, **kwargs: Any):
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Any):
         (grad_output,) = grad_outputs
-        input, weight, bias, output = ctx.saved_tensors
+        input, weight, bias = ctx.saved_tensors
 
         util.push_trace("Linear.__backward")
-        grad_input, grad_weight, grad_bias = Linear.__backward(
-            grad_output, output, input, weight, bias, ctx.use_accelerator
-        )
+        grad_input, grad_weight, grad_bias = Linear.__backward(grad_output, input, weight, bias, ctx.use_accelerator)
         util.pop_trace()
 
         return grad_input, grad_weight, grad_bias, None, None
@@ -81,7 +79,7 @@ def grid(meta):
         return output
 
     @staticmethod
-    def __backward(grad_output, output, input, weight, bias, use_accelerator):
+    def __backward(grad_output, input, weight, bias, use_accelerator):
         factory_kwargs = {"device": input.device, "dtype": input.dtype}
         num_batches, m_size, k_size = input.shape
         n_size, _ = weight.shape
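
The backward pass of a linear layer needs only grad_output, input, weight, and bias: grad_input = grad_output @ weight, grad_weight = grad_output^T @ input, and grad_bias is a reduction over grad_output. The forward output never enters these formulas, so dropping it from `ctx.save_for_backward` frees that activation between the forward and backward passes without changing the computed gradients. Below is a minimal sketch of the pattern in plain PyTorch; the `LinearFunction` name, 2-D shapes, and eager matmuls are illustrative stand-ins for trident's Triton kernels, not the actual implementation:

```python
import torch


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        output = input @ weight.t() + bias
        # Save only what backward reads. Saving `output` here would keep
        # an extra activation alive until backward for no benefit.
        ctx.save_for_backward(input, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight          # (m, n) @ (n, k) -> (m, k)
        grad_weight = grad_output.t() @ input      # (n, m) @ (m, k) -> (n, k)
        grad_bias = grad_output.sum(dim=0)         # reduce over the batch dim
        return grad_input, grad_weight, grad_bias


# Usage: gradients flow through the custom Function as usual.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(5, requires_grad=True)
LinearFunction.apply(x, w, b).sum().backward()
```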