diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 4d6a7feb6..1437a2222 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -241,6 +241,7 @@ def __init__( use_torchao_fp8_allgather: bool = False, use_torchao_fp8_precompute_scale_for_fsdp: bool = False, fp8_shard_intermediate_activation: bool = False, + enable_compiled_autograd: bool = False, ): seed = 1337 torch.manual_seed(seed) @@ -275,6 +276,7 @@ def __init__( self.dump_thunder_traces = dump_thunder_traces self.dump_memory_snapshot = dump_memory_snapshot self.fp8_shard_intermediate_activation = fp8_shard_intermediate_activation + self.enable_compiled_autograd = enable_compiled_autograd if use_torchao_fp8_linear: @@ -669,11 +671,12 @@ def train(self): input_ids, targets = next(self.train_data_iter) input_ids = input_ids.to(self.device) targets = targets.to(self.device) - if self.use_te_fp8_autocast: - with te.fp8_autocast(): + with torch._dynamo.utils.maybe_enable_compiled_autograd(self.enable_compiled_autograd): + if self.use_te_fp8_autocast: + with te.fp8_autocast(): + logits = self.model(input_ids) + else: logits = self.model(input_ids) - else: - logits = self.model(input_ids) logits = logits.reshape(-1, logits.size(-1)) targets = targets.reshape(-1) loss = ( @@ -685,11 +688,12 @@ def train(self): input_ids, targets = next(self.train_data_iter) input_ids = input_ids.to(self.device) targets = targets.to(self.device) - if self.use_te_fp8_autocast: - with te.fp8_autocast(): + with torch._dynamo.utils.maybe_enable_compiled_autograd(self.enable_compiled_autograd): + if self.use_te_fp8_autocast: + with te.fp8_autocast(): + logits = self.model(input_ids) + else: logits = self.model(input_ids) - else: - logits = self.model(input_ids) # This information is accurate only in the case when torch.compile # uses a single graph for the entire forward pass In the case of # torch.compile using multiple graphs, the saved_tensors will be