Skip to content

Commit

Permalink
[not for land] TE experiments
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te
```

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo committed Dec 23, 2024
1 parent 87e2c09 commit f3ca7fa
Show file tree
Hide file tree
Showing 16 changed files with 1,109 additions and 23 deletions.
57 changes: 57 additions & 0 deletions parse_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Input: a subdirectory containing the logs from various experiments
Output: a csv file with loss values, peak memory usage, throughput from each experiment
"""

import csv
import os
import re

import fire

# Root directory that `run` scans for per-experiment sub-directories, each of
# which is expected to hold a `logs.txt` file.
# NOTE: this is a machine-specific absolute path.
OUTPUT_FOLDER = '/home/vasiliy/local/tmp/torchtitan_outputs'

# Example of a metrics line this script extracts data from:
#   [rank0]:[INFO | root ]: step: 10 loss: 7.8774 memory: 0.44GiB(0.47%) tps: 997,458 mfu: 1.50%
# The number of spaces between terms can vary, so every `key:` is followed by
# `[ ]+`. The pattern anchors on the `step:` keyword itself rather than on a
# logger-specific separator in front of it, so that it matches the example
# above as well as lines of the form `... - INFO - step: 10 ...`.
regex = r"step:[ ]+([\d]+).*loss:[ ]+([\d\.]+).*memory:[ ]+([\d\.]+)GiB.*tps:[ ]+([\d\,]+).*mfu.*"

# Compiled once at import time because `log_to_maybe_data` runs on every log line.
_METRICS_RE = re.compile(regex)


def log_to_maybe_data(line):
    """Extract training metrics from a single log line.

    Args:
        line: one line of text from an experiment log.

    Returns:
        A tuple ``(step, loss, memory_gib, tps)`` typed as
        ``(int, float, float, int)`` when the line is a metrics line, with the
        thousands separators removed from ``tps``; ``None`` otherwise.
    """
    match = _METRICS_RE.search(line)
    if match is None:
        return None
    step, loss, memory_gib, tps = match.groups()
    # tps is logged with thousands separators, e.g. "997,458"
    return int(step), float(loss), float(memory_gib), int(tps.replace(',', ''))

def run(
    subfolder_prefix: str,
    results_filename: str,
):
    """Collect per-step metrics from experiment logs into one CSV file.

    Looks at the immediate sub-directories of ``OUTPUT_FOLDER``, keeps those
    whose directory name contains ``subfolder_prefix``, parses ``logs.txt`` in
    each of them with ``log_to_maybe_data`` and writes one CSV row per metrics
    line found.

    Args:
        subfolder_prefix: text that selects the experiment directories. It is
            matched against the directory name only, so that fragments of
            ``OUTPUT_FOLDER`` itself never select every experiment.
        results_filename: path of the CSV file to write (overwritten).
    """
    # The value comes from the command line and may not arrive as a str
    # (e.g. a purely numeric prefix), so normalize it before matching.
    subfolder_prefix = str(subfolder_prefix)

    results = [['experiment', 'step', 'loss', 'memory_gib', 'tps']]

    # os.scandir yields entries in arbitrary order; sort by name so that the
    # row order of the CSV is reproducible. The `with` closes the iterator.
    with os.scandir(OUTPUT_FOLDER) as scan:
        entries = sorted(scan, key=lambda e: e.name)

    for entry in entries:
        if not (entry.is_dir() and subfolder_prefix in entry.name):
            continue
        print(entry)

        log_fname = os.path.join(entry.path, 'logs.txt')
        if not os.path.isfile(log_fname):
            # An experiment that produced no log should not abort the sweep.
            print(f'skipping {entry.name}: no logs.txt found')
            continue

        with open(log_fname, 'r') as f:
            for line in f:
                res = log_to_maybe_data(line)
                if res is not None:
                    print(line.strip('\n'))
                    print(res)
                    results.append([entry.name, *res])

    # newline='' is what the csv module requires of files handed to a writer;
    # without it extra blank rows appear on platforms that translate '\n'.
    with open(results_filename, 'w', newline='') as f:
        csv.writer(f).writerows(results)

    print('done')

if __name__ == '__main__':
    # Expose `run` as the command line entry point; its parameters
    # (subfolder_prefix, results_filename) are supplied on the command line.
    fire.Fire(run)
195 changes: 195 additions & 0 deletions test/test_te.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import copy

import torch
import torch.nn as nn

# path hack, TODO remove
import sys
sys.path.insert(0, '/home/vasiliy/local/torchtitan/')
import torchtitan.te_utils as te_utils
from torchtitan.models.norms import build_norm
from torchtitan.models.llama.model import FeedForward, Attention, ModelArgs, precompute_freqs_cis

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# torch.use_deterministic_algorithms(True)
# Fixed seed so that the random inputs and weight initializations created by
# the tests below are reproducible from run to run.
torch.manual_seed(0)

# Transformer Engine float8 recipe shared by the tests in this file: delayed
# scaling with a 16-entry amax history reduced with "max".
# NOTE(review): Format.HYBRID presumably selects different fp8 formats for the
# forward and backward passes - confirm against the TE documentation.
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
# Autocast context built once at import time; entered by test_linear_module_swap.
# NOTE(review): if fp8_autocast is a generator-based context manager this
# object can only be entered once - confirm before reusing it in another test.
maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)

def test_linear_module_swap():
    """Smoke test for ``te_utils.swap_linear_to_te_linear``.

    Builds a one-layer ``nn.Sequential``, applies the swap in place, compiles
    the result and runs one forward/backward pass on CUDA with the forward
    inside the module-level TE float8 autocast context. There are no
    assertions; the test passes if nothing raises.
    """
    inputs = torch.randn(32, 32, device='cuda')

    model = nn.Sequential(nn.Linear(32, 32))
    model = model.cuda()
    te_utils.swap_linear_to_te_linear(model)
    print(model)
    compiled = torch.compile(model)

    # Only the forward pass needs to run under the autocast context.
    with maybe_te_float8_ctx:
        out = compiled(inputs)
    out.sum().backward()

    print('done')

# Subsection of TransformerBlock with only the ffn norm and the ffn
class NormFFNBlock(nn.Module):
    """Norm followed by the feed-forward network, added back onto the input.

    Computes ``h + feed_forward(ffn_norm(h))``. The attribute names
    ``ffn_norm`` and ``feed_forward`` are kept as-is because the swap helpers
    exercised by the tests operate on this module.
    """

    def __init__(self, dim, hidden_dim, multiple_of):
        """Create the RMSNorm (eps=1e-12) and the FeedForward sub-modules."""
        super().__init__()
        self.ffn_norm = build_norm("rmsnorm", dim, eps=1e-12)
        # The trailing None is the fourth positional FeedForward argument
        # (presumably the optional ffn dim multiplier - confirm).
        self.feed_forward = FeedForward(dim, hidden_dim, multiple_of, None)

    def forward(self, h):
        """Return ``h`` plus the feed-forward output of the normalized ``h``."""
        normed = self.ffn_norm(h)
        return h + self.feed_forward(normed)

class NormAttnBlock(nn.Module):
    """Attention norm followed by attention (no residual connection).

    Computes ``attention(attention_norm(x), freqs_cis)``.
    """

    def __init__(self, model_args):
        """Build the norm, the attention module and the rotary frequencies.

        Args:
            model_args: model configuration; ``dim``, ``n_heads``,
                ``max_seq_len`` and ``rope_theta`` are read here, and the
                whole object is passed on to ``Attention``.
        """
        super().__init__()
        self.attention_norm = build_norm("rmsnorm", model_args.dim, eps=1e-12)
        self.attention = Attention(model_args)
        self.model_args = model_args

        head_dim = model_args.dim // model_args.n_heads
        # Frequencies are precomputed for the full max_seq_len.
        freqs_cis = precompute_freqs_cis(
            head_dim,
            model_args.max_seq_len,
            model_args.rope_theta,
        )
        # Assigned as an ordinary attribute and moved to CUDA explicitly here.
        self.freqs_cis = freqs_cis.cuda()

    def forward(self, x):
        """Normalize ``x`` and run attention on it with the stored frequencies."""
        return self.attention(self.attention_norm(x), self.freqs_cis)

def SQNR(x, y):
    """Return the signal-to-quantization-noise ratio of ``y`` w.r.t. ``x`` in dB.

    Computed as ``20 * log10(||x|| / ||x - y||)``; higher means ``y`` is
    closer to the reference ``x``.
    """
    signal = torch.linalg.norm(x)
    noise = torch.linalg.norm(x - y)
    return 20 * torch.log10(signal / noise)

def test_norm_attn_rewrite():
    """Compare a NormAttnBlock against its TE-rewritten copies.

    A deep copy of the block is rewritten in two steps with the ``te_utils``
    swap helpers; after each step the output is compared with the previous
    one via ``torch.allclose`` and ``SQNR``. The results are printed, not
    asserted. Requires CUDA.
    """
    model_args = ModelArgs()
    m = NormAttnBlock(model_args).cuda().bfloat16()
    m_copy = copy.deepcopy(m)
    te_utils.swap_norm_attn_to_te_friendly_norm_attn(m_copy)
    # `m` is untouched by the swap above, so this shows the original block.
    print(m)

    x = torch.randn(1, 128, model_args.dim).cuda().bfloat16()
    x_copy = copy.deepcopy(x)

    y = m(x)

    y_copy = m_copy(x_copy)

    print(torch.allclose(y, y_copy))
    print(SQNR(y, y_copy))

    # NOTE(review): this applies the *ffn* swap helper to an attention block;
    # confirm that is the intended helper.
    te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(m_copy)
    # Show the module that was actually rewritten; printing `m` here would
    # only repeat the unmodified original.
    print(m_copy)
    y_copy2 = m_copy(x_copy)
    print(torch.allclose(y_copy, y_copy2))
    print(SQNR(y_copy, y_copy2))



def test_norm_ln_ffn_rewrite():
    """Compare a NormFFNBlock with its TE-friendly and LayerNormLinear rewrites.

    The reference block and a deep copy see the same input. The copy is
    rewritten in two steps with the ``te_utils`` swap helpers, and after each
    step ``torch.allclose`` and ``SQNR`` against the previous output are
    printed (nothing is asserted). Requires CUDA.
    """
    dim, hidden_dim, multiple_of = 256, 512, 1

    inp = torch.randn(1, 128, dim).cuda().bfloat16()
    inp_copy = copy.deepcopy(inp)

    reference = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16()
    rewritten = copy.deepcopy(reference)
    print(reference)

    ref_out = reference(inp)
    ref_out.sum().backward()

    # Step 1: rewrite into the TE-friendly layout.
    te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(rewritten)
    print(rewritten)

    friendly_out = rewritten(inp_copy)
    friendly_out.sum().backward()

    # TODO: debug why not an exact match
    print(torch.allclose(ref_out, friendly_out))
    print(SQNR(ref_out, friendly_out))

    # TODO: also compare weight gradients between the reference and the
    # rewritten block (w2 first, then w13)

    # Step 2: rewrite the TE-friendly layout using the layernorm-linear helper.
    te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(rewritten)
    print(rewritten)

    te_out = rewritten(inp_copy)
    print(torch.allclose(friendly_out, te_out))
    print(SQNR(friendly_out, te_out))

def test_norm_mlp_ffn_rewrite():
    """Compare a NormFFNBlock with its TE-friendly and LayerNormMLP rewrites.

    Same procedure as the layernorm-linear variant of this test, except that
    the second step uses ``swap_te_friendly_norm_ffn_to_te_layernorm_mlp``.
    Results are printed, not asserted. Requires CUDA.
    """
    dim, hidden_dim, multiple_of = 256, 512, 1

    inp = torch.randn(1, 128, dim).cuda().bfloat16()
    inp_copy = copy.deepcopy(inp)

    reference = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16()
    rewritten = copy.deepcopy(reference)
    print(reference)

    ref_out = reference(inp)
    ref_out.sum().backward()

    # Step 1: rewrite into the TE-friendly layout.
    te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(rewritten)
    print(rewritten)

    friendly_out = rewritten(inp_copy)
    friendly_out.sum().backward()

    # TODO: debug why not an exact match
    print(torch.allclose(ref_out, friendly_out))
    print(SQNR(ref_out, friendly_out))

    # TODO: also compare weight gradients between the reference and the
    # rewritten block (w2 first, then w13)

    # Step 2: rewrite the TE-friendly layout using the layernorm-mlp helper.
    te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_mlp(rewritten)
    print(123)
    print(rewritten)

    te_out = rewritten(inp_copy)
    print(torch.allclose(friendly_out, te_out))
    print(SQNR(friendly_out, te_out))

# works, so a bug in the swap above?
def test_split_linear():
    """Check that one fused linear reproduces two separate linears.

    The fused layer's weight is the row-wise concatenation of the two
    individual weights, so splitting its output in half along the last
    dimension must give back the two individual outputs.
    """
    rows, in_features, out_features = 32, 64, 128

    x = torch.randn(rows, in_features)

    first = nn.Linear(in_features, out_features, bias=False)
    second = nn.Linear(in_features, out_features, bias=False)

    fused = nn.Linear(in_features, out_features * 2, bias=False)
    stacked = torch.cat(
        [copy.deepcopy(first.weight), copy.deepcopy(second.weight)], dim=0
    )
    fused.weight = torch.nn.Parameter(stacked)

    first_out = first(x)
    second_out = second(x)
    fused_out = fused(x)
    first_half, second_half = torch.split(
        fused_out, fused.out_features // 2, dim=-1
    )

    assert torch.allclose(first_out, first_half)
    assert torch.allclose(second_out, second_half)


if __name__ == '__main__':
    # Run every test defined in this file, in definition order. All but
    # test_split_linear need a CUDA device.
    test_linear_module_swap()
    test_norm_attn_rewrite()
    test_norm_ln_ffn_rewrite()
    test_norm_mlp_ffn_rewrite()
    test_split_linear()
71 changes: 71 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,77 @@ def __init__(self):
action="store_true",
help="Whether to compile the model",
)
self.parser.add_argument(
"--training.compile_ln_mlp",
action="store_true",
help="Whether to compile only the LNMLP blocks",
)
self.parser.add_argument(
"--training.compile_ln_linear",
action="store_true",
help="Whether to compile only the LNLinear blocks",
)
self.parser.add_argument(
"--training.compile_linear",
action="store_true",
help="Whether to compile only the LNLinear blocks",
)
self.parser.add_argument(
"--training.horizontally_fuse_fcs",
action="store_true",
help="""
If true, fuses ffn.fc1 and ffn.fc3 into ffn.fc13. Note that this is required
to use te.LayerNormLinear for FFNs.
TODO also implement this for attention.
""",
)
self.parser.add_argument(
"--training.te_swap_linear",
action="store_true",
help="""
If true, swaps torch.nn.Linear with te.Linear
(not for land)
Note:
* requires training.te_float8_autocast to use float8
""",
)
self.parser.add_argument(
"--training.te_swap_ln_linear",
action="store_true",
help="""
If true, swaps NormFeedForward.norm_w13 from
nn.Sequential(RMSNorm, nn.Linear) to te.LayerNormLinear
(not for land)
Note:
* requires training.horizontally_fuse_fcs to enable this swap
* this swap happens strictly before `training.te_swap_linear` if both are enabled
* requires training.te_float8_autocast to use float8
""",
)
self.parser.add_argument(
"--training.te_swap_ln_mlp",
action="store_true",
help="""
If true, swaps `NormFeedForward` to te.LayerNormMLP
(not for land)
Note:
* requires training.horizontally_fuse_fcs to enable this swap
* this swap happens strictly before `training.te_swap_linear` if both are enabled
* this swap happens strictly before `training.te_swap_ln_linear` if both are enabled
* requires training.te_float8_autocast to use float8
""",
)
self.parser.add_argument(
"--training.te_float8_autocast",
action="store_true",
help="""
If true, enables TE's float8 autocast context manager
(not for land)
""",
)
self.parser.add_argument(
"--training.gc_freq",
type=int,
Expand Down
15 changes: 14 additions & 1 deletion torchtitan/float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _is_sm89_or_later():
class Float8Handler:
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False
self.job_config = job_config

float8_config = job_config.float8
if not float8_config.enable_float8_linear:
Expand Down Expand Up @@ -92,16 +93,28 @@ def convert_to_float8_training(self, model: nn.Module):

from torchao.float8 import convert_to_float8_training

if self.job_config.training.compile_ln_linear:
# only convert compiled regions to float8
module_filter_fn=lambda mod, fqn: (fqn != "output" and "norm_" in fqn)
elif self.job_config.training.compile_ln_mlp:
# only convert compiled regions to float8
module_filter_fn=lambda mod, fqn: (fqn != "output" and "feed_forward" in fqn)
else:
module_filter_fn=lambda mod, fqn: fqn != "output"

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=self.config,
module_filter_fn=lambda mod, fqn: fqn != "output",
module_filter_fn=module_filter_fn,
# module_filter_fn=lambda mod, fqn: fqn != "output",
# module_filter_fn=lambda mod, fqn: fqn != "output" and "norm_w13" in fqn,
)
logger.info(
"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
f"{self.config.enable_fsdp_float8_all_gather}"
)
print(model)

def precompute_float8_dynamic_scale_for_fsdp(
self, model: Union[nn.Module, List[nn.Module]]
Expand Down
4 changes: 3 additions & 1 deletion torchtitan/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
}

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
# "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
dim=4096,
n_layers=32,
# n_layers=1,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
Expand Down
7 changes: 6 additions & 1 deletion torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,12 @@ def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
self.feed_forward.init_weights(self.weight_init_std)
if 'LayerNormMLP' in str(type(self.feed_forward)):
torch.nn.init.ones_(self.feed_forward.layer_norm_weight)
torch.nn.init.trunc_normal_(self.feed_forward.fc1_weight, mean=0.0, std=self.weight_init_std)
torch.nn.init.trunc_normal_(self.feed_forward.fc2_weight, mean=0.0, std=self.weight_init_std)
else:
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
Expand Down
Loading

0 comments on commit f3ca7fa

Please sign in to comment.