@@ -34,7 +34,7 @@ def scan_layers(layers: Iterable[torch.nn.Module],
3434
3535 input_data: The input to be given to the first layer from `layers`.
3636
37- partition_fn: (Optional[Callable]) The graph parition function passed to AOTAutograd.
37+ partition_fn: (Optional[Callable]) The graph partition function passed to AOTAutograd.
3838 Since this function uses AOTAutograd to trace `fn`, you may override what computation
3939 happen in the forward and backward passes by specifying different partition functions.
4040 `default_partition` implies no activation checkpointing. You may specify
@@ -76,16 +76,12 @@ def scan_layers(layers: Iterable[torch.nn.Module],
7676 stacked_buffers = tree_map (lambda * tensors : torch .stack (tensors , dim = 0 ),
7777 * buffers_list )
7878
79- # Use the first layer as the example/template layer.
80- from copy import deepcopy
81- example_layer = deepcopy (first_layer )
82-
8379 # Define the function to apply at each step
8480 def one_layer (carry , params_buffers ):
8581 # Apply the current layer's weights and biases to the example layer,
8682 # then run the resulting layer.
8783 output = torch .func .functional_call ( # type: ignore
88- example_layer , params_buffers , carry , strict = True )
84+ first_layer , params_buffers , carry , strict = True )
8985 return output , None
9086
9187 stacked_params_buffers = (stacked_params , stacked_buffers )
0 commit comments