From f206afa9ad76d8c862393325ab8832cf0cadb113 Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Tue, 19 Nov 2024 12:24:30 +0100
Subject: [PATCH] fix: using te and fsdp leads to multiple device found error (#1453)

---
 thunder/benchmarks/benchmark_litgpt.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 2aa78cbe13..8bcaf575eb 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -114,11 +114,13 @@ def _resursively_swap_linear_layers_for_te(module: torch.nn.Module) -> None:
 
             if isinstance(m, torch.nn.Linear):
                 has_bias = m.bias is not None
-                new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=device)
+                # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
+                new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=str(device))
                 setattr(module, n, new_linear)
 
             if swap_layernorm and isinstance(m, torch.nn.LayerNorm):
-                new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=device)
+                # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device)
+                new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=str(device))
                 setattr(module, n, new_layernorm)
 
     initial_params_cnt = parameters_cnt(model)
@@ -366,11 +368,6 @@ def __init__(
         self.model = self.init_model()
         print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
 
-        if self.use_te_fp8_autocast:
-            is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
-            swap_linear_layers_for_te(self.model, device, swap_layernorm=not is_wo_layernorm)
-            self.model.to(torch.bfloat16)
-
         # Setup the distributed algorithm choices
         if distributed_first := (self.compile in ("eager", "inductor") or "dynamo" in self.compile):
             self.model = self.setup_distributed(self.model)
@@ -407,8 +404,14 @@ def init_model(self):
         init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device
         with init_device:
             model = GPT(self.config)
-        model.to(dtype=torch.bfloat16)
+
+        # Handle fp8 related Linear layer swapping (for torchao or TransformerEngine)
         model = self._torchao_fp8_handler.convert_model_to_fp8(model)
+        if self.use_te_fp8_autocast:
+            is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm"
+            swap_linear_layers_for_te(model, init_device, swap_layernorm=not is_wo_layernorm)
+
+        model.to(dtype=torch.bfloat16)
         return model
 
     def setup_distributed(self, model):
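
The first hunk works around TransformerEngine mishandling `torch.device` objects by passing the device as a string (`str(device)`). Below is a minimal sketch of the same recursive swap outside the benchmark harness, assuming `transformer_engine` is installed; `replace_linear_with_te` is a hypothetical helper name, not part of thunder or TransformerEngine:

```python
import torch
import transformer_engine.pytorch as te


def replace_linear_with_te(module: torch.nn.Module, device: torch.device) -> None:
    for name, child in module.named_children():
        # Recurse first so nested containers get their Linear layers swapped too.
        replace_linear_with_te(child, device)
        if isinstance(child, torch.nn.Linear):
            # str(device) rather than the torch.device object, to sidestep the
            # TransformerEngine device-handling bug the patch comments mention.
            new_linear = te.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=str(device),
            )
            setattr(module, name, new_linear)
```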
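The second and third hunks move the TE swap out of `__init__` and into `init_model`, so it runs on the init device (the meta device under FSDP modes), before the model is cast to bfloat16, and before any distributed wrapping; performing the swap after distributed setup is what left parameters on more than one device. A sketch of the resulting order of operations, with hypothetical names (`build_model`; `GPT` as in litgpt; `replace_linear_with_te` from the sketch above):

```python
import torch


def build_model(config, init_device: torch.device, use_te_fp8: bool) -> torch.nn.Module:
    # 1) Materialize the module structure on the init device
    #    (the meta device when an FSDP mode is selected).
    with init_device:
        model = GPT(config)

    # 2) Swap layers while all parameters still live on one device,
    #    before distributed wrapping can shard or move them.
    if use_te_fp8:
        replace_linear_with_te(model, init_device)

    # 3) Cast once, after all swaps, so the swapped-in layers are bf16 too.
    model.to(dtype=torch.bfloat16)
    return model  # distributed wrapping (e.g. FSDP) happens after this
```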