add option to enable CompiledAutograd
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar committed Dec 10, 2024
1 parent 087637f commit fa70057
Showing 1 changed file with 12 additions and 8 deletions: thunder/benchmarks/benchmark_litgpt.py
@@ -241,6 +241,7 @@ def __init__(
         use_torchao_fp8_allgather: bool = False,
         use_torchao_fp8_precompute_scale_for_fsdp: bool = False,
         fp8_shard_intermediate_activation: bool = False,
+        enable_compiled_autograd: bool = False,
     ):
         seed = 1337
         torch.manual_seed(seed)
@@ -275,6 +276,7 @@ def __init__(
         self.dump_thunder_traces = dump_thunder_traces
         self.dump_memory_snapshot = dump_memory_snapshot
         self.fp8_shard_intermediate_activation = fp8_shard_intermediate_activation
+        self.enable_compiled_autograd = enable_compiled_autograd
 
         if use_torchao_fp8_linear:

@@ -669,11 +671,12 @@ def train(self):
                 input_ids, targets = next(self.train_data_iter)
                 input_ids = input_ids.to(self.device)
                 targets = targets.to(self.device)
-                if self.use_te_fp8_autocast:
-                    with te.fp8_autocast():
+                with torch._dynamo.utils.maybe_enable_compiled_autograd(self.enable_compiled_autograd):
+                    if self.use_te_fp8_autocast:
+                        with te.fp8_autocast():
+                            logits = self.model(input_ids)
+                    else:
                         logits = self.model(input_ids)
-                else:
-                    logits = self.model(input_ids)
                 logits = logits.reshape(-1, logits.size(-1))
                 targets = targets.reshape(-1)
                 loss = (
@@ -685,11 +688,12 @@
                 input_ids, targets = next(self.train_data_iter)
                 input_ids = input_ids.to(self.device)
                 targets = targets.to(self.device)
-                if self.use_te_fp8_autocast:
-                    with te.fp8_autocast():
+                with torch._dynamo.utils.maybe_enable_compiled_autograd(self.enable_compiled_autograd):
+                    if self.use_te_fp8_autocast:
+                        with te.fp8_autocast():
+                            logits = self.model(input_ids)
+                    else:
                         logits = self.model(input_ids)
-                else:
-                    logits = self.model(input_ids)
                 # This information is accurate only in the case when torch.compile
                 # uses a single graph for the entire forward pass. In the case of
                 # torch.compile using multiple graphs, the saved_tensors will be
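For context, the wrapper introduced above is meant to let PyTorch's Compiled Autograd capture and compile the backward pass when the new flag is set. Below is a minimal, self-contained sketch of the pattern; the toy model, shapes, loss, and the placement of backward() are assumptions for illustration and are not part of this commit.

import torch

# Toy stand-ins for the benchmark's model and batch (illustrative assumptions).
model = torch.compile(torch.nn.Linear(32, 32))
input_ids = torch.randn(4, 32)

enable_compiled_autograd = True  # mirrors the new constructor flag

# maybe_enable_compiled_autograd is a no-op when the flag is False; when True,
# it enables Compiled Autograd for the duration of the block.
with torch._dynamo.utils.maybe_enable_compiled_autograd(enable_compiled_autograd):
    logits = model(input_ids)
    loss = logits.sum()
    loss.backward()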

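Roughly speaking (an approximation rather than PyTorch's actual implementation, and the helper's exact signature may differ across PyTorch versions), torch._dynamo.utils.maybe_enable_compiled_autograd behaves like a conditional wrapper around torch._dynamo.compiled_autograd.enable:

import contextlib
import torch

@contextlib.contextmanager
def maybe_enable_compiled_autograd_sketch(should_enable: bool):
    # No-op path: leave eager autograd untouched.
    if not should_enable:
        yield
        return

    # Compile each backward graph captured by Compiled Autograd like any other FX graph.
    def compiler_fn(gm):
        return torch.compile(gm, backend="inductor", fullgraph=True)

    # Enable Compiled Autograd for the duration of the block.
    with torch._dynamo.compiled_autograd.enable(compiler_fn):
        yield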