diff --git a/parse_sweep.py b/parse_sweep.py new file mode 100644 index 00000000..9039b5d1 --- /dev/null +++ b/parse_sweep.py @@ -0,0 +1,57 @@ +""" +Input: a subdirectory containing the logs from various experiments +Output: a csv file with loss values, peak memory usage, throughout from each experiment +""" + +import csv +import os +import re + +import fire + +OUTPUT_FOLDER = '/home/vasiliy/local/tmp/torchtitan_outputs' + +# example: +# [rank0]:[INFO | root ]: step: 10 loss: 7.8774 memory: 0.44GiB(0.47%) tps: 997,458 mfu: 1.50% +# note that number of spaces between terms can vary +regex = r"- step:[ ]+([\d]+).*loss:[ ]+([\d\.]+).*memory:[ ]+([\d\.]+)GiB.*tps: ([\d\,]+).*mfu.*" + +def log_to_maybe_data(line): + res = re.search(regex, line) + if res is not None: + step, loss, memory_gib, wps = res.group(1), res.group(2), res.group(3), res.group(4) + return int(step), float(loss), float(memory_gib), int(wps.replace(',', '')) + else: + return None + +def run( + subfolder_prefix: str, + results_filename: str, +): + subfolder_prefix = str(subfolder_prefix) + + results = [['experiment', 'step', 'loss', 'memory_gib', 'tps']] + + for entry in os.scandir(OUTPUT_FOLDER): + if entry.is_dir() and subfolder_prefix in entry.path: + print(entry) + log_fname = f"{entry.path}/logs.txt" + short_path = entry.path.replace(f"{OUTPUT_FOLDER}/", '') + + with open(log_fname, 'r') as f: + lines = f.readlines() + for l in lines: + res = log_to_maybe_data(l) + if res is not None: + print(l.strip('\n')) + print(res) + results.append([short_path, *res]) + + with open(results_filename, 'w') as f: + writer = csv.writer(f) + writer.writerows(results) + + print('done') + +if __name__ == '__main__': + fire.Fire(run) diff --git a/test/test_te.py b/test/test_te.py new file mode 100644 index 00000000..41a48731 --- /dev/null +++ b/test/test_te.py @@ -0,0 +1,195 @@ +import copy + +import torch +import torch.nn as nn + +# path hack, TODO remove +import sys +sys.path.insert(0, '/home/vasiliy/local/torchtitan/') +import torchtitan.te_utils as te_utils +from torchtitan.models.norms import build_norm +from torchtitan.models.llama.model import FeedForward, Attention, ModelArgs, precompute_freqs_cis + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +# torch.use_deterministic_algorithms(True) +torch.manual_seed(0) + +fp8_format = Format.HYBRID +fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") +maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) + +def test_linear_module_swap(): + x = torch.randn(32, 32, device='cuda') + + m = nn.Sequential(nn.Linear(32, 32)).cuda() + te_utils.swap_linear_to_te_linear(m) + print(m) + m = torch.compile(m) + + with maybe_te_float8_ctx: + y = m(x) + y.sum().backward() + + print('done') + +# Subsection of TransformerBlock with only the ffn norm and the ffn +class NormFFNBlock(nn.Module): + def __init__(self, dim, hidden_dim, multiple_of): + super().__init__() + self.ffn_norm = build_norm("rmsnorm", dim, eps=1e-12) + self.feed_forward = FeedForward(dim, hidden_dim, multiple_of, None) + + def forward(self, h): + out = h + self.feed_forward(self.ffn_norm(h)) + return out + +class NormAttnBlock(nn.Module): + def __init__(self, model_args): + super().__init__() + self.attention_norm = build_norm("rmsnorm", model_args.dim, eps=1e-12) + self.attention = Attention(model_args) + self.model_args = model_args + self.freqs_cis = precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ).cuda() + + def forward(self, x): + x = self.attention_norm(x) + x = self.attention(x, self.freqs_cis) + return x + +def SQNR(x, y): + return 20 * torch.log10( + torch.linalg.norm(x) / torch.linalg.norm(x - y) + ) + +def test_norm_attn_rewrite(): + dim = 256 + model_args = ModelArgs() + m = NormAttnBlock(model_args).cuda().bfloat16() + m_copy = copy.deepcopy(m) + te_utils.swap_norm_attn_to_te_friendly_norm_attn(m_copy) + print(m) + + x = torch.randn(1, 128, model_args.dim).cuda().bfloat16() + x_copy = copy.deepcopy(x) + + y = m(x) + + y_copy = m_copy(x_copy) + + print(torch.allclose(y, y_copy)) + print(SQNR(y, y_copy)) + + te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(m_copy) + print(m) + y_copy2 = m_copy(x_copy) + print(torch.allclose(y_copy, y_copy2)) + print(SQNR(y_copy, y_copy2)) + + + +def test_norm_ln_ffn_rewrite(): + dim = 256 + hidden_dim = 512 + multiple_of = 1 + + x = torch.randn(1, 128, 256).cuda().bfloat16() + x_copy = copy.deepcopy(x) + + m = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16() + m_copy = copy.deepcopy(m) + print(m) + + y = m(x) + y.sum().backward() + + te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(m_copy) + print(m_copy) + + y_copy = m_copy(x_copy) + y_copy.sum().backward() + + # TODO: debug why not an exact match + print(torch.allclose(y, y_copy)) + print(SQNR(y, y_copy)) + + # TODO test w13 + # assert torch.allclose(m.ffn.w2.grad, m_copy.ffn.w2.grad, atol=0, rtol=0) + + te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(m_copy) + print(m_copy) + + y_copy2 = m_copy(x_copy) + print(torch.allclose(y_copy, y_copy2)) + print(SQNR(y_copy, y_copy2)) + +def test_norm_mlp_ffn_rewrite(): + dim = 256 + hidden_dim = 512 + multiple_of = 1 + + x = torch.randn(1, 128, 256).cuda().bfloat16() + x_copy = copy.deepcopy(x) + + m = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16() + m_copy = copy.deepcopy(m) + print(m) + + y = m(x) + y.sum().backward() + + te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(m_copy) + print(m_copy) + + y_copy = m_copy(x_copy) + y_copy.sum().backward() + + # TODO: debug why not an exact match + print(torch.allclose(y, y_copy)) + print(SQNR(y, y_copy)) + + # TODO test w13 + # assert torch.allclose(m.ffn.w2.grad, m_copy.ffn.w2.grad, atol=0, rtol=0) + + te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_mlp(m_copy) + print(123) + print(m_copy) + + y_copy2 = m_copy(x_copy) + print(torch.allclose(y_copy, y_copy2)) + print(SQNR(y_copy, y_copy2)) + +# works, so a bug in the swap above? +def test_split_linear(): + M, K, N = 32, 64, 128 + # M, K, N = 4, 6, 8 + + x = torch.randn(M, K) + + fc1 = nn.Linear(K, N, bias=False) + fc2 = nn.Linear(K, N, bias=False) + + fc3 = nn.Linear(K, N * 2, bias=False) + fc3.weight = torch.nn.Parameter( + torch.cat([copy.deepcopy(fc1.weight), copy.deepcopy(fc2.weight)], dim=0) + ) + + y1 = fc1(x) + y2 = fc2(x) + y3 = fc3(x) + y3_1, y3_2 = torch.split(y3, fc3.out_features // 2, dim=-1) + + assert torch.allclose(y1, y3_1) + assert torch.allclose(y2, y3_2) + + +if __name__ == '__main__': + test() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 814bd80f..9bcf624b 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -373,6 +373,77 @@ def __init__(self): action="store_true", help="Whether to compile the model", ) + self.parser.add_argument( + "--training.compile_ln_mlp", + action="store_true", + help="Whether to compile only the LNMLP blocks", + ) + self.parser.add_argument( + "--training.compile_ln_linear", + action="store_true", + help="Whether to compile only the LNLinear blocks", + ) + self.parser.add_argument( + "--training.compile_linear", + action="store_true", + help="Whether to compile only the LNLinear blocks", + ) + self.parser.add_argument( + "--training.horizontally_fuse_fcs", + action="store_true", + help=""" + If true, fuses ffn.fc1 and ffn.fc3 into ffn.fc13. Note that this is required + to use te.LayerNormLinear for FFNs. + TODO also implement this for attention. + """, + ) + self.parser.add_argument( + "--training.te_swap_linear", + action="store_true", + help=""" + If true, swaps torch.nn.Linear with te.Linear + (not for land) + + Note: + * requires training.te_float8_autocast to use float8 + """, + ) + self.parser.add_argument( + "--training.te_swap_ln_linear", + action="store_true", + help=""" + If true, swaps NormFeedForward.norm_w13 from + nn.Sequential(RMSNorm, nn.Linear) to te.LayerNormLinear + (not for land) + + Note: + * requires training.horizontally_fuse_fcs to enable this swap + * this swap happens strictly before `training.te_swap_linear` if both are enabled + * requires training.te_float8_autocast to use float8 + """, + ) + self.parser.add_argument( + "--training.te_swap_ln_mlp", + action="store_true", + help=""" + If true, swaps `NormFeedForward` to te.LayerNormMLP + (not for land) + + Note: + * requires training.horizontally_fuse_fcs to enable this swap + * this swap happens strictly before `training.te_swap_linear` if both are enabled + * this swap happens strictly before `training.te_swap_ln_linear` if both are enabled + * requires training.te_float8_autocast to use float8 + """, + ) + self.parser.add_argument( + "--training.te_float8_autocast", + action="store_true", + help=""" + If true, enables TE's float8 autocast context manager + (not for land) + """, + ) self.parser.add_argument( "--training.gc_freq", type=int, diff --git a/torchtitan/float8.py b/torchtitan/float8.py index 1dd0d0bb..21d25d1b 100644 --- a/torchtitan/float8.py +++ b/torchtitan/float8.py @@ -31,6 +31,7 @@ def _is_sm89_or_later(): class Float8Handler: def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = False + self.job_config = job_config float8_config = job_config.float8 if not float8_config.enable_float8_linear: @@ -92,16 +93,28 @@ def convert_to_float8_training(self, model: nn.Module): from torchao.float8 import convert_to_float8_training + if self.job_config.training.compile_ln_linear: + # only convert compiled regions to float8 + module_filter_fn=lambda mod, fqn: (fqn != "output" and "norm_" in fqn) + elif self.job_config.training.compile_ln_mlp: + # only convert compiled regions to float8 + module_filter_fn=lambda mod, fqn: (fqn != "output" and "feed_forward" in fqn) + else: + module_filter_fn=lambda mod, fqn: fqn != "output" + # Mutates the model inplace replacing instances of nn.Linear with Float8Linear convert_to_float8_training( model, config=self.config, - module_filter_fn=lambda mod, fqn: fqn != "output", + module_filter_fn=module_filter_fn, + # module_filter_fn=lambda mod, fqn: fqn != "output", + # module_filter_fn=lambda mod, fqn: fqn != "output" and "norm_w13" in fqn, ) logger.info( "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" f"{self.config.enable_fsdp_float8_all_gather}" ) + print(model) def precompute_float8_dynamic_scale_for_fsdp( self, model: Union[nn.Module, List[nn.Module]] diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index 887a96cd..61080843 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -29,10 +29,12 @@ } llama3_configs = { - "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000), + # "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000), + "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16, rope_theta=500000), "8B": ModelArgs( dim=4096, n_layers=32, + # n_layers=1, n_heads=32, n_kv_heads=8, ffn_dim_multiplier=1.3, diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index a3bae18a..ae4a3a46 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -329,7 +329,12 @@ def init_weights(self): for norm in (self.attention_norm, self.ffn_norm): norm.reset_parameters() self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) + if 'LayerNormMLP' in str(type(self.feed_forward)): + torch.nn.init.ones_(self.feed_forward.layer_norm_weight) + torch.nn.init.trunc_normal_(self.feed_forward.fc1_weight, mean=0.0, std=self.weight_init_std) + torch.nn.init.trunc_normal_(self.feed_forward.fc2_weight, mean=0.0, std=self.weight_init_std) + else: + self.feed_forward.init_weights(self.weight_init_std) class Transformer(nn.Module): diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 4d4c60bc..8b6ca6df 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -22,6 +22,7 @@ from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, + apply_activation_checkpointing, ) from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -75,7 +76,7 @@ def parallelize_llama( "fused_rmsnorm is not compatible with torch.compile yet. " "Please use rmsnorm or layernorm." ) - apply_compile(model) + apply_compile(model, job_config) if ( parallel_dims.dp_shard_enabled @@ -243,8 +244,20 @@ def apply_tp( } +import transformer_engine.pytorch as te +rng_seed = 1234 +torch.manual_seed(rng_seed) +torch.cuda.manual_seed(rng_seed) +CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker() +CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed) + + +def get_cuda_rng_tracker(): + return CUDA_RNG_STATES_TRACKER + + def _apply_ac_to_transformer_block(module: nn.Module, ac_config): - valid_ac_modes = ("full", "selective") + valid_ac_modes = ("full", "selective", "full_te") if ac_config.mode not in valid_ac_modes: raise ValueError( f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" @@ -252,6 +265,23 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config): if ac_config.mode == "full": return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + elif ac_config.mode == "full_te": + # copy-paste from https://github.com/NVIDIA/TransformerEngine/blob/64126aa8c469b2a97ace01f925f3d5786d5fd1bb/examples/pytorch/fsdp/fsdp.py, apply_fsdp_checkpointing + # note: + # LLaMa 3 8B on 8 H100s with this option: + # 42.27 GiB, 4880 tps, strictly worse than PT-D's full AC. Have not done debugging + # on the cause yet. + + wrapper = lambda m: ptd_checkpoint_wrapper( + m, + checkpoint_fn=te.distributed.checkpoint, + use_reentrant=False, + get_rng_state_tracker=get_cuda_rng_tracker, + ) + def check_fn(submodule): + return True + apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) + return module assert ac_config.mode == "selective", f"{ac_config.mode}" use_op_sac = ac_config.selective_ac_option == "op" @@ -314,16 +344,55 @@ def apply_ac(model: nn.Module, ac_config): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") -def apply_compile(model: nn.Module): +def apply_compile(model: nn.Module, job_config): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = torch.compile(transformer_block, fullgraph=True) - model.layers.register_module(layer_id, transformer_block) + if job_config.training.compile_ln_mlp: + def _apply_compile(mod): + for name, child in mod.named_children(): + # hacky check, but good enough for this use case + # if isinstance(child, torch.nn.Sequential) and len(child) == 2: + if name == 'feed_forward': + new_child = torch.compile(child) + setattr(mod, name, new_child) + else: + _apply_compile(child) + + logger.info("Compiling each LNMLP with torch.compile") + _apply_compile(model) + elif job_config.training.compile_ln_linear: + def _apply_compile(mod): + for name, child in mod.named_children(): + # hacky check, but good enough for this use case + if isinstance(child, torch.nn.Sequential) and len(child) == 2: + new_child = torch.compile(child) + setattr(mod, name, new_child) + else: + _apply_compile(child) + + logger.info("Compiling each LNLinear with torch.compile") + _apply_compile(model) + elif job_config.training.compile_linear: + def _apply_compile(mod): + for name, child in mod.named_children(): + # hacky check, but good enough for this use case + if isinstance(child, torch.nn.Linear): + new_child = torch.compile(child) + setattr(mod, name, new_child) + else: + _apply_compile(child) + + logger.info("Compiling each Linear with torch.compile") + _apply_compile(model) + else: + for layer_id, transformer_block in model.layers.named_children(): + # transformer_block = torch.compile(transformer_block, fullgraph=True) + transformer_block = torch.compile(transformer_block, fullgraph=False) + model.layers.register_module(layer_id, transformer_block) - logger.info("Compiling each TransformerBlock with torch.compile") + logger.info("Compiling each TransformerBlock with torch.compile") def apply_fsdp( diff --git a/torchtitan/te_utils.py b/torchtitan/te_utils.py new file mode 100644 index 00000000..386fee2e --- /dev/null +++ b/torchtitan/te_utils.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for testing TransformerEngine + +Note: I attempted to hack in DTensor-based TP/SP to te.Linear in the +link below, and gave up for now as it seemed to be a lot of remaining work. +We can power through that if needed later. +* https://gist.github.com/vkuzo/64d5362b63dd6c76410464e020d9a35f + +Note: I looked into using te.LayerNormLinear, and that would require changing +how Attention and FFN are defined in torchtitan to use a single gemm for +attn.kqv and ffn.w1_w3. Punting for now but we can do this later if needed. +""" + +import contextlib +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchtitan.models.llama.model import apply_rotary_emb, repeat_kv + +# import transformer_engine as te +import transformer_engine.pytorch as te + +from transformer_engine.common.recipe import Format, DelayedScaling +te_fp8_format = Format.HYBRID +te_fp8_recipe = DelayedScaling(fp8_format=te_fp8_format, amax_history_len=16, amax_compute_algo="max") + +def swap_linear_to_te_linear(model, fqn=''): + for name, child in model.named_children(): + new_fqn = f"{fqn}.{name}" + # if isinstance(child, torch.nn.Linear) and new_fqn != 'output' and 'norm_' in new_fqn: + if isinstance(child, torch.nn.Linear) and name != 'output': + te_linear = te.Linear(child.in_features, child.out_features, bias=child.bias is not None) + te_linear.weight = child.weight + te_linear.bias = child.bias + setattr(model, name, te_linear) + else: + swap_linear_to_te_linear(child, new_fqn) + +class ResettableIdentity(nn.Identity): + def reset_parameters(self): + pass + + +class NormFeedForward(torch.nn.Module): + """ + A replacement for ffn_norm -> ffn which is TE swap friendly + """ + + def __init__(self, ffn_norm, ffn): + super().__init__() + # self.ffn_norm = ffn_norm + + # fuse w1 and w3, TE assumes this optimization is applied + w13_in_feat = ffn.w1.in_features + w13_out_feat = ffn.w1.out_features * 2 + with torch.device("meta"): + w13 = nn.Linear(w13_in_feat, w13_out_feat, bias=False) + w13.weight = torch.nn.Parameter( + torch.cat([ffn.w1.weight, ffn.w3.weight], dim=0).contiguous() + ) + + # wrapped in a sequential for easy swap to either te.LayerNorm or + # torch.compiling just this wrapper + self.norm_w13 = nn.Sequential(ffn_norm, w13) + + self.w2 = ffn.w2 + self.split_dim = getattr(self.norm_w13, "1").out_features // 2 + + def forward(self, x): + # x = self.ffn_norm(x) + # x = self.w13(x) + x = self.norm_w13(x) + w1_out, w3_out = torch.split( + x, + self.split_dim, + dim=-1, + ) + out = self.w2(F.silu(w1_out) * w3_out) + return out + + def init_weights(self, init_std: float): + if isinstance(self.norm_w13, te.LayerNormLinear): + torch.nn.init.ones_(self.norm_w13.layer_norm_weight) + + # slight difference from llama/model.py - init every weight to init_std + for linear in (self.w2, self.norm_w13): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + else: + getattr(self.norm_w13, "0").reset_parameters() + + # slight difference from llama/model.py - init every weight to init_std + for linear in (self.w2, getattr(self.norm_w13, "1")): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class NormAttention(torch.nn.Module): + """ + A replacement for attn_norm -> attn which is TE swap friendly + """ + def __init__(self, attn_norm, attn): + super().__init__() + + # fuse attn.qkv, TE assumes this optimization is applied + self.split_dim = attn.wq.out_features + with torch.device("meta"): + wqkv = nn.Linear(attn.wq.in_features, attn.wq.out_features * 3, bias=False) + wqkv.weight = torch.nn.Parameter( + torch.cat([attn.wq.weight, attn.wk.weight, attn.wv.weight], dim=0).contiguous() + ) + + self.norm_wqkv = nn.Sequential(attn_norm, wqkv) + self.wo = attn.wo + + self.n_heads = attn.n_heads + self.n_kv_heads = attn.n_kv_heads + self.n_rep = attn.n_rep + self.head_dim = attn.head_dim + + def forward(self, x, freqs_cis): + bs, seqlen, _ = x.shape + # xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + x = self.norm_wqkv(x) + xq, xk, xv = torch.split( + x, + [ + self.n_heads * self.head_dim, + self.n_kv_heads * self.head_dim, + self.n_kv_heads * self.head_dim, + ], + dim=-1, + ) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + def init_weights(self, init_std: float): + if isinstance(self.norm_wqkv, te.LayerNormLinear): + torch.nn.init.ones_(self.norm_wqkv.layer_norm_weight) + + # slight difference from llama/model.py - init every weight to init_std + for linear in (self.wo, self.norm_wqkv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + else: + getattr(self.norm_wqkv, "0").reset_parameters() + + # slight difference from llama/model.py - init every weight to init_std + for linear in (self.wo, getattr(self.norm_wqkv, "1")): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + + +def swap_norm_ffn_to_te_friendly_norm_ffn(parent_module) -> None: + """ + `parent_module` is a module with the following structure: + + parent_module + ffn_norm: LayerNorm|RMSNorm + ffn: FeedForward + w1 + w2 + w3 + + this function will rewrite the graph without changing numerics to the following structure + + parent_module + ffn_norm: ResettableIdentity + feed_forward: NormFeedForward + norm_w13: Sequential + 0: LayerNorm|RMSNorm + 1: Linear (fused w1 and w3) + w2: Linear + + this is done to then make it easier to then swap to te.LayerNormLinear + """ + if hasattr(parent_module, "ffn_norm") and hasattr(parent_module, "feed_forward"): + parent_module.feed_forward = NormFeedForward( + parent_module.ffn_norm, + parent_module.feed_forward, + ) + parent_module.ffn_norm = ResettableIdentity() + else: + for name, child in parent_module.named_children(): + swap_norm_ffn_to_te_friendly_norm_ffn(child) + +def swap_norm_attn_to_te_friendly_norm_attn(parent_module): + """ + `parent_module` is a module with the following structure: + + parent_module + attention_norm: LayerNorm|RMSNorm + attention: Attention + wq + wk + wv + wo + + this function will rewrite the graph without changing numerics to the following structure + + parent_module + attention_norm: ResettableIdentity + attention: NormAttention + norm_wqkv: Sequential + 0: LayerNorm|RMSNorm + 1: Linear (fused wq, wk, wv) + wo: Linear + + this is done to then make it easier to then swap to te.LayerNormLinear + """ + if hasattr(parent_module, "attention_norm") and hasattr(parent_module, "attention"): + parent_module.attention = NormAttention( + parent_module.attention_norm, + parent_module.attention, + ) + parent_module.attention_norm = ResettableIdentity() + else: + for name, child in parent_module.named_children(): + swap_norm_attn_to_te_friendly_norm_attn(child) + +def swap_te_friendly_norm_ffn_to_te_layernorm_linear(parent_module): + """ + In `NormFeedForward`, swaps `norm_w13` with `te.LayerNormLinear` + In `NormAttention`, swaps `norm_wqkv` with `te.LayerNormLinear` + """ + + if isinstance(parent_module, NormFeedForward): + + te_ln_linear = te.LayerNormLinear( + parent_module.norm_w13[1].in_features, + parent_module.norm_w13[1].out_features, + bias=False, + normalization='RMSNorm', + ) + + te_ln_linear.layer_norm_weight = parent_module.norm_w13[0].weight + te_ln_linear.weight = parent_module.norm_w13[1].weight + parent_module.norm_w13 = te_ln_linear + + elif isinstance(parent_module, NormAttention): + + te_ln_linear = te.LayerNormLinear( + parent_module.norm_wqkv[1].in_features, + parent_module.norm_wqkv[1].out_features, + bias=False, + normalization='RMSNorm', + ) + te_ln_linear.layer_norm_weight = parent_module.norm_wqkv[0].weight + te_ln_linear.weight = parent_module.norm_wqkv[1].weight + parent_module.norm_wqkv = te_ln_linear + + else: + for name, child in parent_module.named_children(): + swap_te_friendly_norm_ffn_to_te_layernorm_linear(child) + +def _monkey_patched_te_layernorm_mlp_init_weights(self): + torch.nn.init.ones_(self.layer_norm_weight) + torch.nn.init.trunc_normal_(self.fc1_weight, mean=0.0, std=init_std) + torch.nn.init.trunc_normal_(self.fc2_weight, mean=0.0, std=init_std) + +def swap_te_friendly_norm_ffn_to_te_layernorm_mlp(parent_module): + """ + Swaps `NormFeedForward` with `te.LayerNormMLP` + """ + + for name, child in parent_module.named_children(): + if isinstance(child, NormFeedForward): + te_ln_mlp = te.LayerNormMLP( + child.norm_w13[1].in_features, + child.norm_w13[1].out_features, + bias=False, + normalization='RMSNorm', + activation='swiglu', + ) + te_ln_mlp.layer_norm_weight = child.norm_w13[0].weight + te_ln_mlp.fc1_weight = child.norm_w13[1].weight + te_ln_mlp.fc2_weight = child.w2.weight + setattr(parent_module, name, te_ln_mlp) + + else: + swap_te_friendly_norm_ffn_to_te_layernorm_mlp(child) + + +def get_maybe_fp8_autocast(job_config): + # not for land - set up TransformerEngine fp8 autocast + # Note: te.fp8_autocast has to be created at every training iteration. + # If we try to create it once and reuse, we get this error: + # https://gist.github.com/vkuzo/d9840328c8bdc2901b8d04aa570ecb5b + maybe_te_float8_ctx = contextlib.nullcontext() + if job_config.training.te_float8_autocast: + assert ( + job_config.training.te_swap_linear + or job_config.training.te_swap_ln_linear + or job_config.training.te_swap_ln_mlp + ) + maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=te_fp8_recipe) + return maybe_te_float8_ctx diff --git a/train.py b/train.py index 58c2bcba..ce423502 100644 --- a/train.py +++ b/train.py @@ -19,6 +19,7 @@ from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor, build_metric_logger +import torchtitan.te_utils as te_utils from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import ( @@ -28,6 +29,8 @@ ) from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling from torchtitan.utils import device_module, device_type +import transformer_engine as te_main +import torchao # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -35,6 +38,9 @@ def main(job_config: JobConfig): init_logger() logger.info(f"Starting job: {job_config.job.description}") + logger.info(f"PyTorch version: {torch.__version__}") + logger.info(f"TransformerEngine version: {te_main.__version__}") + logger.info(f"torchao version: {torchao.__version__}") # used for colorful printing color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor @@ -113,11 +119,24 @@ def main(job_config: JobConfig): with torch.device("meta"): model = model_cls.from_model_args(model_config) + if job_config.training.horizontally_fuse_fcs: + # note: this is required for te.LayerNormLinear + te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(model) + te_utils.swap_norm_attn_to_te_friendly_norm_attn(model) + # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear based on float8 configs float8_handler.convert_to_float8_training(model) + # not for land - set up TransformerEngine + if job_config.training.te_swap_ln_mlp: + te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_mlp(model) + if job_config.training.te_swap_ln_linear: + te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(model) + if job_config.training.te_swap_linear: + te_utils.swap_linear_to_te_linear(model) + # log model size model_param_count = utils.get_num_params(model) num_flop_per_token = utils.get_num_flop_per_token( @@ -244,6 +263,8 @@ def loss_fn(pred, labels): checkpoint.reset() + print(model) + # train loop logger.info( f"Training starts at step {train_state.step + 1}, " @@ -285,7 +306,11 @@ def loss_fn(pred, labels): else None ) + # not for land - set up TransformerEngine fp8 autocast + maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config) + if parallel_dims.pp_enabled: + assert not job_config.training.use_te, "unsupported" # Pipeline Parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 @@ -307,12 +332,13 @@ def loss_fn(pred, labels): else: # Non-PP forward / backward with train_context(optional_context_parallel_ctx): - pred = model(input_ids) - loss = loss_fn(pred, labels) - # pred.shape=(bs, seq_len, vocab_size) - # need to free to before bwd to avoid peaking memory - del pred - loss.backward() + with maybe_te_float8_ctx: + pred = model(input_ids) + loss = loss_fn(pred, labels) + # pred.shape=(bs, seq_len, vocab_size) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() # clip gradients utils.clip_grad_norm_( diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 2d55a36d..06a0027e 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -6,14 +6,14 @@ dump_folder = "./outputs" description = "Llama 3 70B training" [profiling] -enable_profiling = true -save_traces_folder = "profile_trace" -profile_freq = 100 +# enable_profiling = true +# save_traces_folder = "profile_trace" +# profile_freq = 100 [metrics] -log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +# log_freq = 10 +# enable_tensorboard = true +# save_tb_folder = "tb" [model] name = "llama3" @@ -35,7 +35,7 @@ data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = false -dataset = "c4" +dataset = "c4_test" [experimental] context_parallel_degree = 1 diff --git a/vasiliy_sweep.sh b/vasiliy_sweep.sh new file mode 100755 index 00000000..f94de0df --- /dev/null +++ b/vasiliy_sweep.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# terminate on first error +set -e + +# sweep over various important torchtitan + TE experiments + +OUTPUT_FOLDER=/home/vasiliy/local/tmp/torchtitan_outputs +OUTPUT_LOGFILE=logs.txt + +# need to loop over: +# 1. AC (none, full, selective with op) +# 2. experiment branches (TE and PT) + +for AC_SETTING in none selective full +do + + for NAME in baseline te_linear_f8 te_ln_linear_f8 pt_f8 pt_f8_fsdp_f8 + do + + if [ $NAME == "baseline" ]; then + EXTRA_ARGS="" + elif [ $NAME == "te_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_float8_autocast" + elif [ $NAME == "te_ln_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_swap_ln_linear --training.te_float8_autocast" + elif [ $NAME == "pt_f8" ]; then + EXTRA_ARGS="--float8.enable_float8_linear" + elif [ $NAME == "pt_f8_fsdp_f8" ]; then + EXTRA_ARGS="--float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" + else + # should not get here + exit 1 + fi + + OUTPUT_SUBFOLDER="20241204_v2_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + + # create the subdir if does not exist, `tee` needs this + mkdir -p $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER + + CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh $EXTRA_ARGS \ + --job.dump_folder $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER \ + --training.compile \ + --training.horizontally_fuse_fcs \ + --activation_checkpoint.mode $AC_SETTING \ + --activation_checkpoint.selective_ac_option 2 \ + --training.steps 200 \ + --profiling.profile_freq 100 2>&1 | tee $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER/$OUTPUT_LOGFILE + + done + +done diff --git a/vasiliy_sweep_regional.sh b/vasiliy_sweep_regional.sh new file mode 100755 index 00000000..5e7c1c74 --- /dev/null +++ b/vasiliy_sweep_regional.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# terminate on first error +set -e + +# sweep over various important torchtitan + TE experiments + +OUTPUT_FOLDER=/home/vasiliy/local/tmp/torchtitan_outputs +OUTPUT_LOGFILE=logs.txt + +# need to loop over: +# 1. AC (none, full, selective with op) +# 2. experiment branches (TE and PT) + +for AC_SETTING in none full +do + + for NAME in baseline te_ln_linear_f8 pt_f8 pt_f8_fsdp_f8 + do + + if [ $NAME == "baseline" ]; then + EXTRA_ARGS="--training.compile --training.compile_ln_linear" + elif [ $NAME == "te_ln_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_ln_linear --training.te_float8_autocast" + elif [ $NAME == "pt_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_ln_linear --float8.enable_float8_linear" + elif [ $NAME == "pt_f8_fsdp_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_ln_linear --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" + else + # should not get here + exit 1 + fi + + # OUTPUT_SUBFOLDER="20241204_v3_regional_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + # fixed - only enable compile for non-TE + OUTPUT_SUBFOLDER="20241204_v5_regional_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + + # create the subdir if does not exist, `tee` needs this + mkdir -p $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER + + CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh $EXTRA_ARGS \ + --job.dump_folder $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER \ + --training.horizontally_fuse_fcs \ + --activation_checkpoint.mode $AC_SETTING \ + --activation_checkpoint.selective_ac_option 2 \ + --training.steps 200 \ + --profiling.profile_freq 100 2>&1 | tee $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER/$OUTPUT_LOGFILE + + done + +done diff --git a/vasiliy_sweep_regional_linear.sh b/vasiliy_sweep_regional_linear.sh new file mode 100755 index 00000000..8fa38718 --- /dev/null +++ b/vasiliy_sweep_regional_linear.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# terminate on first error +set -e + +# sweep over various important torchtitan + TE experiments + +OUTPUT_FOLDER=/home/vasiliy/local/tmp/torchtitan_outputs +OUTPUT_LOGFILE=logs.txt + +# need to loop over: +# 1. AC (none, full, selective with op) +# 2. experiment branches (TE and PT) + +for AC_SETTING in none full +do + + for NAME in baseline te_ln_linear_f8 pt_f8 pt_f8_fsdp_f8 + do + + if [ $NAME == "baseline" ]; then + EXTRA_ARGS="--training.compile --training.compile_linear" + elif [ $NAME == "te_ln_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_float8_autocast" + elif [ $NAME == "pt_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_linear --float8.enable_float8_linear" + elif [ $NAME == "pt_f8_fsdp_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_linear --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" + else + # should not get here + exit 1 + fi + + # v6 contained an error + OUTPUT_SUBFOLDER="20241204_v7_regional_linear_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + + # create the subdir if does not exist, `tee` needs this + mkdir -p $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER + + CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh $EXTRA_ARGS \ + --job.dump_folder $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER \ + --training.horizontally_fuse_fcs \ + --activation_checkpoint.mode $AC_SETTING \ + --activation_checkpoint.selective_ac_option 2 \ + --training.steps 200 \ + --profiling.profile_freq 100 2>&1 | tee $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER/$OUTPUT_LOGFILE + + done + +done diff --git a/vasiliy_sweep_regional_linear_v2.sh b/vasiliy_sweep_regional_linear_v2.sh new file mode 100755 index 00000000..9e189c4f --- /dev/null +++ b/vasiliy_sweep_regional_linear_v2.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# terminate on first error +set -e + +# fixup from vasiliy_sweep_regional_linear: +# one of the experiments had a transient error, re-running + +OUTPUT_FOLDER=/home/vasiliy/local/tmp/torchtitan_outputs +OUTPUT_LOGFILE=logs.txt + +# need to loop over: +# 1. AC (none, full, selective with op) +# 2. experiment branches (TE and PT) + +# for AC_SETTING in none full +for AC_SETTING in none +do + + # for NAME in baseline te_ln_linear_f8 pt_f8 pt_f8_fsdp_f8 + for NAME in pt_f8 + do + + if [ $NAME == "baseline" ]; then + EXTRA_ARGS="--training.compile --training.compile_linear" + elif [ $NAME == "te_ln_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_float8_autocast" + elif [ $NAME == "pt_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_linear --float8.enable_float8_linear" + elif [ $NAME == "pt_f8_fsdp_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_linear --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" + else + # should not get here + exit 1 + fi + + # v8 for fixup + OUTPUT_SUBFOLDER="20241204_v8_regional_linear_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + + # create the subdir if does not exist, `tee` needs this + mkdir -p $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER + + CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh $EXTRA_ARGS \ + --job.dump_folder $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER \ + --training.horizontally_fuse_fcs \ + --activation_checkpoint.mode $AC_SETTING \ + --activation_checkpoint.selective_ac_option 2 \ + --training.steps 200 \ + --profiling.profile_freq 100 2>&1 | tee $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER/$OUTPUT_LOGFILE + + done + +done diff --git a/vasiliy_sweep_regional_ln_mlp.sh b/vasiliy_sweep_regional_ln_mlp.sh new file mode 100755 index 00000000..65016bf2 --- /dev/null +++ b/vasiliy_sweep_regional_ln_mlp.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# terminate on first error +set -e + +# TODO(next) build this + +# sweep over various important torchtitan + TE experiments + +OUTPUT_FOLDER=/home/vasiliy/local/tmp/torchtitan_outputs +OUTPUT_LOGFILE=logs.txt + +# need to loop over: +# 1. AC (none, full, selective with op) +# 2. experiment branches (TE and PT) + +# for AC_SETTING in none full +for AC_SETTING in full +do + + # for NAME in baseline te_ln_mlp_f8 pt_f8 pt_f8_fsdp_f8 + for NAME in te_ln_mlp_f8 pt_f8 + do + + if [ $NAME == "baseline" ]; then + EXTRA_ARGS="--training.compile --training.compile_ln_mlp" + elif [ $NAME == "te_ln_mlp_f8" ]; then + EXTRA_ARGS="--training.te_swap_ln_mlp --training.te_float8_autocast" + elif [ $NAME == "pt_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_ln_mlp --float8.enable_float8_linear" + elif [ $NAME == "pt_f8_fsdp_f8" ]; then + EXTRA_ARGS="--training.compile --training.compile_ln_mlp --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" + else + # should not get here + exit 1 + fi + + # OUTPUT_SUBFOLDER="20241204_v3_regional_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + # fixed - only enable compile for non-TE + # OUTPUT_SUBFOLDER="20241209_v10_regional_mlp_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + OUTPUT_SUBFOLDER="20241209_v11_regional_mlp_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + + # create the subdir if does not exist, `tee` needs this + mkdir -p $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER + + CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh $EXTRA_ARGS \ + --job.dump_folder $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER \ + --training.horizontally_fuse_fcs \ + --activation_checkpoint.mode $AC_SETTING \ + --activation_checkpoint.selective_ac_option 2 \ + --training.steps 200 \ + --profiling.profile_freq 100 2>&1 | tee $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER/$OUTPUT_LOGFILE + + done + +done diff --git a/vasiliy_sweep_v2.sh b/vasiliy_sweep_v2.sh new file mode 100755 index 00000000..2e39782d --- /dev/null +++ b/vasiliy_sweep_v2.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# terminate on first error +set -e + +# sweep over various important torchtitan + TE experiments +# save as v1, but adds layernorm_mlp only (and doesn't re-run the other sweeps) + +OUTPUT_FOLDER=/home/vasiliy/local/tmp/torchtitan_outputs +OUTPUT_LOGFILE=logs.txt + +# need to loop over: +# 1. AC (none, full, selective with op) +# 2. experiment branches (TE and PT) + +# for AC_SETTING in none selective full +for AC_SETTING in full +do + + # for NAME in baseline te_linear_f8 te_ln_linear_f8 pt_f8 pt_f8_fsdp_f8 + # for NAME in te_ln_mlp_f8 baseline + for NAME in te_ln_mlp_f8 + do + + if [ $NAME == "baseline" ]; then + EXTRA_ARGS="" + elif [ $NAME == "te_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_float8_autocast" + elif [ $NAME == "te_ln_linear_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_swap_ln_linear --training.te_float8_autocast" + elif [ $NAME == "te_ln_mlp_f8" ]; then + EXTRA_ARGS="--training.te_swap_linear --training.te_swap_ln_linear --training.te_swap_ln_mlp --training.te_float8_autocast" + elif [ $NAME == "pt_f8" ]; then + EXTRA_ARGS="--float8.enable_float8_linear" + elif [ $NAME == "pt_f8_fsdp_f8" ]; then + EXTRA_ARGS="--float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" + else + # should not get here + exit 1 + fi + + OUTPUT_SUBFOLDER="20241206_v2_ln_mlp_only_llama3_8b_name_${NAME}_ac_${AC_SETTING}" + + # create the subdir if does not exist, `tee` needs this + mkdir -p $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER + + CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh $EXTRA_ARGS \ + --job.dump_folder $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER \ + --training.compile \ + --training.horizontally_fuse_fcs \ + --activation_checkpoint.mode $AC_SETTING \ + --activation_checkpoint.selective_ac_option 2 \ + --training.steps 200 \ + --profiling.profile_freq 100 2>&1 | tee $OUTPUT_FOLDER/$OUTPUT_SUBFOLDER/$OUTPUT_LOGFILE + + done + +done