From 728f14deeece2d4a41f7ebde9c97e5e8729ac015 Mon Sep 17 00:00:00 2001
From: Rui <179625410+rpsilva-aws@users.noreply.github.com>
Date: Wed, 22 Jan 2025 17:00:17 -0800
Subject: [PATCH] Introducing experimental gradient accumulation API (#8608)

---
 test/spmd/test_train_spmd_linear_model.py     |  32 ++
 .../utils/train_spmd_linear_model_grad_acc.py | 188 ++++++++
 .../experimental/gradient_accumulation.py     | 403 ++++++++++++++++++
 3 files changed, 623 insertions(+)
 create mode 100644 test/utils/train_spmd_linear_model_grad_acc.py
 create mode 100644 torch_xla/experimental/gradient_accumulation.py

diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py
index 08637490a3c..74d6dfe0c95 100644
--- a/test/spmd/test_train_spmd_linear_model.py
+++ b/test/spmd/test_train_spmd_linear_model.py
@@ -10,8 +10,13 @@
 parent_folder = os.path.dirname(os.path.dirname(__file__))
 sys.path.append(parent_folder)
+
+# TODO(rpsilva-aws): Unify the SPMD MLP training files.
 from utils.train_spmd_linear_model import train_and_evaluate
+from utils.train_spmd_linear_model_grad_acc import train_and_evaluate_grad_acc
 
+# CPU does not support optimization barriers, and hence we use this to disable
+# the gradient checkpointing A/B test run for it.
 SKIP_GRADIENT_CHECKPOINTING: bool = False
 
@@ -48,8 +53,35 @@ def test_basic(self):
                       baseline_losses, checkpointing_losses))
 
 
+class TestSPMDLinearModelGradientAccumulation(
+    test_xla_sharding_base.XlaShardingTest):
+
+  def test_gradient_accumulation_matches(self):
+    """Verify that gradient accumulation produces the same losses with and
+       without the XLA `While` op.
+    """
+
+    COMMON_GRAD_ACC_ARGS = ["--gradient_accumulation_steps", "8"]
+    print('Training loop with traditional gradient accumulation')
+    with extended_argv(COMMON_GRAD_ACC_ARGS):
+      baseline_grad_acc_losses = train_and_evaluate_grad_acc()
+
+    print('Training loop with XLA\'s `While` gradient accumulation')
+    with extended_argv(COMMON_GRAD_ACC_ARGS +
+                       ["--use_gradient_accumulation_loop"]):
+      loop_grad_acc_losses = train_and_evaluate_grad_acc()
+
+    # Verify that the model losses are not zero, and that the runs match.
+    assert all(loss != 0 for loss in baseline_grad_acc_losses)
+    assert all(
+        torch.allclose(baseline_loss, loop_loss)
+        for baseline_loss, loop_loss in zip(baseline_grad_acc_losses,
+                                            loop_grad_acc_losses))
+
+
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
+  # Relevant parser flag for the gradient checkpointing basic coverage.
parser.add_argument('--skip-gradient-checkpointing', action='store_true') parsed_args, remaining_argv = parser.parse_known_args() SKIP_GRADIENT_CHECKPOINTING = parsed_args.skip_gradient_checkpointing diff --git a/test/utils/train_spmd_linear_model_grad_acc.py b/test/utils/train_spmd_linear_model_grad_acc.py new file mode 100644 index 00000000000..706ae490404 --- /dev/null +++ b/test/utils/train_spmd_linear_model_grad_acc.py @@ -0,0 +1,188 @@ +import sys +from typing import Optional + +import numpy as np +import torch +from torch import nn +import torch.optim as optim + +import args_parse +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp +import torch_xla.distributed.parallel_loader as pl +import torch_xla.distributed.spmd as xs +import torch_xla.runtime as xr +import torch_xla.utils.utils as xu +from torch_xla.distributed.spmd import Mesh +from torch_xla.experimental.gradient_accumulation import gradient_accumulation +from torch_xla.utils.checkpoint import checkpoint + +MODEL_OPTS = { + '--sharding': { + 'choices': ['batch', 'megatron-lm', 'fsdp'], + 'nargs': '+', + 'default': [], + }, + '--input_dim': { + 'type': int, + 'default': 16834, + }, + '--train_dataset_len': { + 'type': int, + 'default': 1024 * 8, + }, + '--use_gradient_checkpointing': { + 'action': 'store_true', + }, + '--gradient_accumulation_steps': { + 'type': int, + 'default': 1, + }, + '--use_gradient_accumulation_loop': { + 'action': 'store_true', + } +} + +FLAGS = {} +PROFILER_SERVER = None + + +class SimpleLinear(nn.Module): + NUM_CLASSES = 3 + + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2), + nn.ReLU(), + nn.Linear(FLAGS.input_dim // 2, 3), + # # Add an additional 3x3 layer at the end to ensure the final layer + # # is not sharded. + nn.Linear(3, self.NUM_CLASSES), + ) + + def forward(self, x): + if FLAGS.use_gradient_checkpointing: + for n_l, layer in enumerate(self.layers): + # Apply gradient checkpointing for reduced memory footprint. + # This would result in increased computation cost. 
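+        # Note: the first layer (n_l == 0) is applied without checkpointing;
+        # only the subsequent layers are wrapped in `checkpoint`.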
+ if n_l > 0: + x = checkpoint(layer, x) + else: + x = layer(x) + else: + x = self.layers(x) + return x + + +def train(): + device = xm.xla_device() + num_devices = xr.global_runtime_device_count() + print(f'num_devices: {num_devices}') + # Define a mesh with all devices along one axis + mesh_shape = (num_devices, 1) + device_ids = np.arange(num_devices) + mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + + torch.manual_seed(42) + model = SimpleLinear().to(device) + print('===> Preparing data..') + batch_size = FLAGS.batch_size * FLAGS.gradient_accumulation_steps + train_loader = xu.SampleGenerator( + data=(torch.randn(batch_size, FLAGS.input_dim), + torch.randint( + 0, model.NUM_CLASSES, (batch_size,), dtype=torch.int64)), + sample_count=FLAGS.train_dataset_len // batch_size) + + if 'batch' in FLAGS.sharding: + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1))) + + if 'fsdp' in FLAGS.sharding: + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1))) + print('Sharding model weights') + # Shard the weights according to their 0th dim + xs.mark_sharding(model.layers[0].weight, mesh, (0, 1)) + xs.mark_sharding(model.layers[2].weight, mesh, (0, 1)) + + if 'megatron-lm' in FLAGS.sharding: + print('Sharding model weights') + # Shard the first layer's weights row-wise + xs.mark_sharding(model.layers[0].weight, mesh, (0, 1)) + # Shard the second layer's weights column-wise + xs.mark_sharding(model.layers[2].weight, mesh, (1, 0)) + + optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr) + + loss_fn = nn.CrossEntropyLoss() + + def train_step(input_id, label): + output = model(input_id) + loss = loss_fn(output, label) + return loss + + def train_loop_fn(data, target, running_loss): + if FLAGS.use_gradient_accumulation_loop: + running_loss, = gradient_accumulation(train_step, (data, target), model, + None) + else: + for i in range(FLAGS.gradient_accumulation_steps): + loss = train_step(data[i], target[i]) + loss /= FLAGS.gradient_accumulation_steps + running_loss += loss.detach() + loss.backward() + return running_loss + + losses = [] + for epoch in range(FLAGS.num_epochs): + model.train() + training_step = 0 + running_loss = torch.zeros(1, dtype=torch.float32, device=device) + for (data, target) in train_loader: + with xp.StepTrace('train_linear_model'): + with xp.Trace('build_graph'): + data = (data.reshape(FLAGS.gradient_accumulation_steps, -1, + *data.shape[1:])).to(device) + target = (target.reshape(FLAGS.gradient_accumulation_steps, + -1)).to(device) + # Ensure the appropriate sharding specs with the reshaped gradient + # gradient accumulation leading dimension. 
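+          # With specs (None, 0, 1) and (None, 0), the leading accumulation
+          # dimension is replicated, while the per-step batch dimension stays
+          # sharded across the mesh.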
+ if "batch" in FLAGS.sharding or "fsdp" in FLAGS.sharding: + xs.mark_sharding(data, mesh, (None, 0, 1)) + xs.mark_sharding(target, mesh, (None, 0)) + running_loss = train_loop_fn(data, target, running_loss) + training_step += FLAGS.gradient_accumulation_steps + optimizer.step() + xm.mark_step() + losses.append(running_loss.clone().detach()) + if training_step % FLAGS.log_steps == 0: + print( + f"Epoch {epoch} step {training_step} loss {running_loss.cpu().item()}" + ) + optimizer.zero_grad() + running_loss.zero_() + + return losses, model + + +def train_and_evaluate_grad_acc(): + default_config = { + 'batch_size': 128, + 'num_epochs': 1, + 'lr': 0.1, + 'log_steps': 8, + 'opts': MODEL_OPTS.items() + } + + global PROFILER_SERVER, FLAGS + FLAGS = args_parse.parse_common_options(**default_config) + if FLAGS.profile: + PROFILER_SERVER = xp.start_server(FLAGS.profiler_port) + xr.use_spmd(auto=FLAGS.auto_spmd) + print('Start training loop...') + losses, m = train() + t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device()) + m(t).cpu() + return [loss.cpu() for loss in losses] diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py new file mode 100644 index 00000000000..fcbdface796 --- /dev/null +++ b/torch_xla/experimental/gradient_accumulation.py @@ -0,0 +1,403 @@ +import torch +import torch_xla +import torch_xla.core.xla_builder as xb + +from typing import Any, Callable, Sequence, Tuple, Optional, List, Dict +from dataclasses import dataclass + + +@dataclass(frozen=True) +class GradientAccumulationContext: + """Context for the gradient accumulation instructions. + Attributes: + * num_gradient_steps: Number of steps to accumulate gradients over + * num_iterable_tensors: Number of input tensors to iterate over + * num_carried_tensors: Number of tensors carried between iterations + * num_model_params: Number of model parameters + * num_internal_tensors: Number of internal tensors used (default: 2) + + Note: `num_internal_tensors` should only be changed if we create new internal + tensors. + """ + num_gradient_steps: int + num_iterable_tensors: int + num_carried_tensors: int + num_model_params: int + num_internal_tensors: int = 2 + + +def gradient_accumulation( + train_step: Callable[..., Any], + iterable_tensors: Sequence[torch.Tensor], + model: torch.nn.Module, + carried_tensors: Optional[Tuple[torch.Tensor, ...]] = None +) -> Tuple[torch.Tensor, ...]: + """Accumulates gradients over multiple training steps using XLA's `While` + operator to iterate over the leading dimension of the iterable tensors. + The backward computation of the model is implicitly executed following the + train_step operations. + + Notes: + + The model tracing will happen entirely within the loop. Hence, it is + assumed that `train_step` is purposefully encapsulated inside of the + loop. Hence, it is not recommended to have any operation involving the + model parameters outside of `train_step`. + + Args: + train_step: Training function that takes iterable tensors and carried + tensors, and returns either a loss tensor or a tuple of (loss, + *carried_outputs). The iterable tensor inputs to this function should + disregard the leading dimension. + + iterable_tensors: Input tensors to iterate over. All tensors must have the + same first dimension size which determines number of iterations. The + underlying loop in the gradient accumulation will iterate through the + leading dimension of these tensors. + + model: PyTorch model whose parameters will be updated. 
Note that the entire + model computation will be traced and generated from within the loop. + + carried_tensors: Optional tensors passed and updated between iterations. + + Returns: + (accumulated_loss, carried_tensor0, carried_tensor1, ...): A tuple including + the `accumulated_loss` and the same unpacked `carried_tensors` that were + provided as inputs. In addition, the model parameter gradients, if + applicable, contain the accumulated gradients. + + Example: + + >>> # Note: This is a partial example, since it is dependent on the + >>> # training model. Please refer to existing tests. + >>> + >>> from torch_xla.experimental.gradient_accumulation import ( + >>> gradient_accumulation + >>> ) + >>> + >>> def train_step(input, label, other_tensor): + >>> output = model(input_id) + >>> loss = loss_fn(output, label) + >>> updated_other_tensor += 10 + >>> return loss, updated_other_tensor + >>> + >>> some_tensor = torch.tensor(10).to(device) + >>> for (data, target) in loader: + >>> # Assuming data's and target's first iterable dimension is 5. + >>> # >> data.shape = [5, 128, 16834] + >>> # >> label.shape = [5, 128] + >>> running_loss, some_tensor = gradient_accumulation( + >>> train_step, + >>> (data, target), + >>> model, + >>> (some_tensor,) + >>> ) + >>> print(some_tensor) # Should be 60 + >>> print(running_loss) # Should be the accumulated loss across all 5 + >>> # iteration steps + >>> optimizer.step() # Should update all weights with the accumulated + >>> # parameter weights + """ + # Validate that the arguments minimally suffice our requirements + if not iterable_tensors: + raise ValueError("iterable_tensors cannot be empty") + + accumulation_steps = iterable_tensors[0].size(0) + for i, tensor in enumerate(iterable_tensors): + if not isinstance(tensor, torch.Tensor): + raise ValueError(f"Element {i} of iterable_tensors is not a tensor") + if tensor.numel() == 0: + raise ValueError(f"Element {i} of iterable_tensors is empty") + if tensor.size(0) != accumulation_steps: + raise ValueError( + f"Element {i} of iterable_tensors has inconsistent first dimension") + carried_tensors = carried_tensors or tuple() + return _gradient_accumulation(accumulation_steps, train_step, + iterable_tensors, model, carried_tensors) + + +class XlaBuildHelper: + """Helper class for tracking the parameters for the XLA while computations.""" + + def __init__(self, name: str): + self._builder = xb.create_builder(name) + self._params: List[xb.Op] = [] + self._param_tensors: List[torch.Tensor] = [] + + def add_param(self, val: torch.Tensor, idx: Optional[int] = None) -> int: + if idx is None: + idx = len(self._params) + param = xb.mkparam(self._builder, idx, xb.tensor_shape(val)) + self._params.append(param) + self._param_tensors.append(val) + return idx + + @property + def params(self) -> Tuple[xb.Op, ...]: + return tuple(self._params) + + @property + def param_tensors(self) -> Tuple[torch.Tensor, ...]: + return tuple(self._param_tensors) + + @property + def num_params(self) -> int: + return len(self._params) + + +def _gradient_accumulation_impl(context, body_fn, iterable_tensors, params, + grads, carried_tensors): + builder = XlaBuildHelper('grad_acc') + device = torch_xla.device() + + def _prepare_fake_tensors( + iterable_tensors: Sequence[torch.Tensor], + carried_tensors: Sequence[torch.Tensor] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + fake_iterable_tensors = [] + for iter_tensor in iterable_tensors: + original_size = iter_tensor.size() + fake_iterable_tensors.append( + 
torch.empty(original_size[1:], + dtype=iter_tensor.dtype).to(iter_tensor.device)) + + fake_carried_tensors = [] + for carried_input in carried_tensors: + fake_carried_tensors.append( + torch.empty(carried_input.size(), dtype=carried_input.dtype).to( + carried_input.device).requires_grad_(carried_input.requires_grad)) + return fake_iterable_tensors, fake_carried_tensors + + # TODO - Fake the model once we are able to create placeholder tensors. + fake_iterable_tensors, fake_carried_tensors = _prepare_fake_tensors( + iterable_tensors, carried_tensors) + init_iterator = torch.tensor(0, dtype=torch.int32, device=device) + init_loss = torch.tensor(0, dtype=torch.float32, device=device) + + body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors, + *fake_carried_tensors, *params, *grads) + body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors), + tuple(fake_carried_tensors), tuple(params), + tuple(grads)) + + ( + graph_input_tensor_ids, + graph_input_xla_values, + ) = torch_xla._XLAC._get_tensors_xla_device_data_node( + list(body_result) + list(body_fn_inputs)) + + body_fn_input_tensor_ids = [ + torch_xla._XLAC._xla_get_tensor_id(i) for i in body_fn_inputs + ] + uncaptured_input_tensor_ids = tuple( + v for i, v in zip(graph_input_tensor_ids, graph_input_xla_values) + if i not in body_fn_input_tensor_ids) + + body_ctx = torch_xla._XLAC.lowering.LoweringContext() + body_ctx.set_name_string("bodyctx") + body_ctx.build(body_result + uncaptured_input_tensor_ids) + body_hlo = body_ctx.hlo() + body_computation = xb.computation_from_module_proto("bodycomputation", + body_hlo) + + builder.add_param(init_iterator) + builder.add_param(init_loss) + + def _build_parameter_mapping( + builder: XlaBuildHelper, + context: GradientAccumulationContext, + body_fn_inputs: Tuple[torch.Tensor, ...], + uncaptured_input_tensor_ids: Tuple[torch.Tensor, ...], + iterable_tensors: Sequence[torch.Tensor], + fake_iterable_tensors: Sequence[torch.Tensor], + carried_tensors: Tuple[torch.Tensor, ...], + fake_carried_tensors: Tuple[torch.Tensor, ...], + params: List[torch.Tensor], + grads: List[torch.Tensor], + ) -> Dict[int, int]: + param_mapping = {} + + def add_to_mapping(val: torch.Tensor, + fake_val: Optional[torch.Tensor] = None): + idx = builder.add_param(val) + param_id = body_ctx.tensor_parameter_id( + fake_val if fake_val is not None else val) + if param_id != -1: + param_mapping[param_id] = idx + + # Process iterable tensors and carried inputs + for val, fake_val in zip(iterable_tensors, fake_iterable_tensors): + add_to_mapping(val, fake_val) + for val, fake_val in zip(carried_tensors, fake_carried_tensors): + add_to_mapping(val, fake_val) + + # Process params, grads, and uncaptured input tensor ids + for tensor_list in (params, grads, uncaptured_input_tensor_ids): + for val in tensor_list: + add_to_mapping(val) + + # Handle any additional hoisted variables + hoisted_vars = body_ctx.device_parameter_id_tensor_mapping() + for v in body_fn_inputs + uncaptured_input_tensor_ids: + param_id = body_ctx.tensor_parameter_id(v) + hoisted_vars.pop(param_id, None) + + # TODO(rpsilva-aws): Derived from `experimental/scan.py`. Unify the RNG and + # hoisted paths. 
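+    # If the traced step uses the RNG, the single scalar seed is swapped below
+    # for a per-iteration seed tensor, so that each loop iteration draws
+    # different random numbers.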
+ seed_info_id = torch_xla._XLAC._get_seed_info_id() + seed_parameter_id = None + if seed_info_id in graph_input_tensor_ids: + seed_idx = graph_input_tensor_ids.index(seed_info_id) + seed_parameter_id = body_ctx.tensor_parameter_id( + graph_input_xla_values[seed_idx]) + assert seed_parameter_id != -1, "`fn` uses random seed, but random seed is not \ + a parameter to the traced HLO graph" + + # Replace the single seed value with a tensor of seeds, one per iteration. + seed_tensor = hoisted_vars[seed_parameter_id] + assert seed_tensor.dtype == torch.int64 + hoisted_vars[seed_parameter_id] = torch.randint( + 0, + 2**62, (context.num_gradient_steps,), + dtype=torch.int64, + device=device) + + for param_id, tensor in hoisted_vars.items(): + idx = builder.add_param(tensor) + param_mapping[param_id] = idx + return param_mapping, seed_parameter_id + + param_mapping, seed_parameter_id = _build_parameter_mapping( + builder, context, body_fn_inputs, uncaptured_input_tensor_ids, + iterable_tensors, fake_iterable_tensors, carried_tensors, + fake_carried_tensors, params, grads) + + def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, + *while_params: xb.Op): + + def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: + indices = [idx] + [idx.zeros_like() for _ in range(xs.shape().rank - 1)] + slice_shape = list(xs.shape().sizes) + slice_shape[0] = 1 + sliced = xs.dynamic_slice(indices, slice_shape) + return sliced.reshape(list(xs.shape().sizes)[1:]) + + # TODO(rpsilva-aws): Derived from `experimental/scan.py`. Unify the RNG + # path. + def replace_rng_seed(curr_iter: xb.Op, *while_params: xb.Op): + """Slices the pre-generated seed tensor for the current iteration.""" + if seed_parameter_id is None: + return while_params + idx = param_mapping[seed_parameter_id] + replaced = list(while_params) + replaced[idx] = dynamic_slice(replaced[idx], curr_iter) + return replaced + + def call_fn_computation(*while_params: xb.Op) -> xb.Op: + fn_inputs = [ + while_params[param_mapping[i]] for i in range(len(param_mapping)) + ] + return xb.Op.call(body_computation, fn_inputs) + + iterable_tensors = while_params[:context.num_iterable_tensors] + idx = curr_iter + sliced_iterables = [ + dynamic_slice(iter_tensor, idx) for iter_tensor in iterable_tensors + ] + + # Call the computation with current values + result = call_fn_computation( + idx, curr_loss, + *replace_rng_seed(idx, *sliced_iterables, + *while_params[context.num_iterable_tensors:])) + + # Extract the carried tensors and accumulated gradients. 
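+    # The body computation returns (iterator, loss, *iterable, *carried,
+    # *params, *grads); everything after the iterable tensors is threaded
+    # through to the next iteration.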
+ carried_tensors_and_gradients = [ + result.get_tuple_element(i) for i in range( + context.num_internal_tensors + context.num_iterable_tensors, + result.shape().tuple_size()) + ] + one = xb.Op.scalar(idx.builder(), 1, dtype=xb.Type.S32) + updated_loss = curr_loss + result.get_tuple_element(1) + return (curr_iter + one, updated_loss, *iterable_tensors, + *carried_tensors_and_gradients) + + def _cond_fn(curr_iter: xb.Op, *rest): + return curr_iter < xb.Op.scalar( + curr_iter.builder(), context.num_gradient_steps, dtype=xb.Type.S32) + + def _compute_output_indices( + context: GradientAccumulationContext) -> List[int]: + # Start with loss index + indices = [1] + # Add indices for carried tensors + carried_start = context.num_internal_tensors + context.num_iterable_tensors + carried_end = carried_start + context.num_carried_tensors + indices.extend(range(carried_start, carried_end)) + # Add indices for accumulated gradients + grad_start = carried_end + context.num_model_params + grad_end = grad_start + context.num_model_params + indices.extend(range(grad_start, grad_end)) + return indices + + w = xb.Op.mkwhile(builder.params, _cond_fn, _body_fn_wrapper) + outputs = [w.get_tuple_element(i) for i in _compute_output_indices(context)] + op = xb.Op.tuple(outputs) + computation = op.build('grad_acc_loop_torch_func') + result = torch_xla._XLAC._xla_user_computation('xla::_op_grad_acc_loop', + builder.param_tensors, + computation) + return result + + +def _gradient_accumulation(accumulation_steps, train_step, iterable_tensors, + model, carried_tensors): + model_parameters = list(model.parameters()) + context = GradientAccumulationContext(accumulation_steps, + len(iterable_tensors), + len(carried_tensors), + len(model_parameters)) + + def body_fn(iteri: torch.Tensor, _: torch.Tensor, + iterable_tensors: Tuple[torch.Tensor, ...], + carried_tensors: Tuple[torch.Tensor, + ...], params: Tuple[torch.Tensor, ...], + grads: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: + result = train_step(*iterable_tensors, *carried_tensors) + + if not context.num_carried_tensors: + loss = result + else: + loss, *carried_tensors = result + loss /= context.num_gradient_steps + gradients = torch.autograd.grad(loss, model_parameters) + acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)] + return (iteri, loss, *iterable_tensors, *carried_tensors, *params, + *acc_grads) + + init_grads = [] + # Initialize the gradients to zero. + for param in model_parameters: + if not param.requires_grad: + continue + if param.grad: + grad = param.grad + else: + grad = torch.zeros(param.size()).to(param.device).requires_grad_(False) + param_sharding = torch_xla._XLAC._get_xla_op_sharding(param) + if param_sharding: + # Match the gradient sharding to the parameter's. + torch_xla._XLAC._xla_mark_sharding(grad, param_sharding) + init_grads.append(grad) + + # Apply gradients to parameters + result = _gradient_accumulation_impl(context, body_fn, iterable_tensors, + model_parameters, init_grads, + carried_tensors) + + for param, grad in zip(model_parameters, + result[1 + context.num_carried_tensors:]): + if param.requires_grad: + param.grad = grad + + return (result[0], *result[1:context.num_carried_tensors + 1])
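Reviewer note (not part of the patch): below is a minimal usage sketch of the new API, assuming `model`, `loss_fn`, `optimizer`, and a data `loader` already exist and SPMD has been enabled; these names are illustrative only. It mirrors the docstring example and the MLP test above: the host batch is reshaped so that the leading dimension equals the number of accumulation steps, `gradient_accumulation` runs `train_step` over that dimension inside a single XLA `While` op while accumulating the parameter gradients, and the optimizer is stepped once afterwards.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.gradient_accumulation import gradient_accumulation

    # Assumed to be defined elsewhere: model, loss_fn, optimizer, loader.
    device = xm.xla_device()
    steps = 8  # number of gradient accumulation steps

    def train_step(data, target):
      # Receives one slice along the leading (accumulation) dimension.
      output = model(data)
      return loss_fn(output, target)

    for data, target in loader:
      # Reshape to [steps, micro_batch, ...] so dim 0 is iterated by the loop.
      data = data.reshape(steps, -1, *data.shape[1:]).to(device)
      target = target.reshape(steps, -1).to(device)
      # With no carried tensors, the result is a 1-tuple holding the
      # accumulated loss; the parameters' .grad fields hold the accumulated
      # gradients.
      (loss,) = gradient_accumulation(train_step, (data, target), model)
      optimizer.step()
      xm.mark_step()
      optimizer.zero_grad()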