Commit

latest update on coord check implementation
lintangsutawika committed Jan 24, 2024
1 parent 16d04b1 commit e7b7bf6
Showing 2 changed files with 71 additions and 51 deletions.
89 changes: 48 additions & 41 deletions megatron/mup_substitute.py
@@ -10,6 +10,8 @@
import torch
import torch.nn.functional as F

from megatron import print_rank_0

# from mup import coord_check as mup_coord_check
from megatron.training import train_step

@@ -30,7 +32,7 @@ def _get_coord_data(
filter_module_by_name=None,
fix_data=True,
cuda=True,
nseeds=1,
nseeds=3,
output_fdict=None,
input_fdict=None,
param_fdict=None,
@@ -43,40 +45,47 @@ def word_embedding_coord_check_hook(module, input, output):
with torch.no_grad():
word_embedding_act_abs_mean_list.append(output.abs().mean().item())

word_embedding_act_abs_mean_list = []
_seeds = []
_steps = []
remove_hooks = []

for i in range(nseeds):
torch.manual_seed(i)
for width, model in models.items():
model = model()
model.train()
# optimizer = optcls(model)
optimizer, _ = optcls(model, neox_args)
optimizer = optcls(model)
# optimizer, _ = optcls(model, neox_args)

for step in range(nsteps + 1):
word_embedding_act_abs_mean_list = []
remove_hooks = []

# add hooks
# for name, module in model.named_modules():
# if name.endswith(".embedding.word_embeddings"):
# print("yess")
# import sys; sys.exit
# remove_hook.append(
# module.register_forward_hook(word_embedding_coord_check_hook))

# # if filter_module_by_name and not filter_module_by_name(name):
# # continue
# # pass
# # remove_hooks.append(
# # module.register_forward_hook(
# # mup_coord_check._record_coords(
# # df,
# # width,
# # name,
# # step + 1,
# # output_fdict=output_fdict,
# # input_fdict=input_fdict,
# # param_fdict=param_fdict,
# # )
# # )
# # )
for name, module in model.named_modules():
if name.endswith(".word_embeddings"):
remove_hooks.append(
module.register_forward_hook(word_embedding_coord_check_hook))

_steps.append(step)
_seeds.append(i)


# if filter_module_by_name and not filter_module_by_name(name):
# continue
# pass
# remove_hooks.append(
# module.register_forward_hook(
# mup_coord_check._record_coords(
# df,
# width,
# name,
# step + 1,
# output_fdict=output_fdict,
# input_fdict=input_fdict,
# param_fdict=param_fdict,
# )
# )
# )

# train for a step
loss_dict, skipped_iter = train_step(
@@ -91,14 +100,13 @@ def word_embedding_coord_check_hook(module, input, output):
# remove hooks
for handle in remove_hooks:
handle.remove()

print("word_embedding_act_abs_mean_list")
print(word_embedding_act_abs_mean_list)
import gc

del model
gc.collect()

for _i, _j, _k in zip(_seeds, _steps, word_embedding_act_abs_mean_list):
print_rank_0(f"{_i} {_j} {_k}")

return pd.DataFrame(df)
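
Note: the loop above registers a forward hook on every module whose name ends in ".word_embeddings", records the mean absolute value of its output at each training step, and collects the records into a pandas DataFrame. A minimal, self-contained sketch of that pattern is below; it uses a toy nn.Embedding and synthetic data as stand-ins (the model, optimizer, and column names here are illustrative, not the NeoX training path).

import pandas as pd
import torch
import torch.nn as nn

def collect_coord_data(widths, nseeds=3, nsteps=4, vocab=1000):
    records = []
    for seed in range(nseeds):
        torch.manual_seed(seed)
        for width in widths:
            emb = nn.Embedding(vocab, width)       # stand-in for .word_embeddings
            opt = torch.optim.Adam(emb.parameters(), lr=1e-3)
            acts = []
            # forward hook: record |output| mean, as in word_embedding_coord_check_hook
            handle = emb.register_forward_hook(
                lambda module, inputs, output: acts.append(output.abs().mean().item())
            )
            for step in range(nsteps + 1):
                tokens = torch.randint(0, vocab, (8, 16))
                loss = emb(tokens).pow(2).mean()   # stand-in loss / train_step
                opt.zero_grad()
                loss.backward()
                opt.step()
                records.append(dict(seed=seed, width=width, step=step, act_abs_mean=acts[-1]))
            handle.remove()
    return pd.DataFrame(records)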


@@ -211,15 +219,14 @@ def get_trainable(model):
params.append(p)
return params

# if optimizer == "sgd":
# optcls = lambda model: SGD(get_trainable(model), lr=lr)
# elif optimizer == "adam":
# optcls = lambda model: Adam(get_trainable(model), lr=lr)
# elif optimizer == "adamw":
# optcls = lambda model: AdamW(get_trainable(model), lr=lr)
# elif optimizer is None:
# raise ValueError("optimizer should be sgd|adam|adamw or a custom function")
optcls = optimizer
if optimizer == "sgd":
optcls = lambda model: SGD(get_trainable(model), lr=lr)
elif optimizer == "adam":
optcls = lambda model: Adam(get_trainable(model), lr=lr)
elif optimizer == "adamw":
optcls = lambda model: AdamW(get_trainable(model), lr=lr)
elif optimizer is None:
raise ValueError("optimizer should be sgd|adam|adamw or a custom function")

data = _get_coord_data(
neox_args, timers, lr_scheduler, models, dataloader, optcls, **kwargs
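Note: the string-to-factory branch above ("sgd" | "adam" | "adamw") builds optcls, which _get_coord_data then calls once per width and per seed as optimizer = optcls(model). A short usage sketch under that assumption (toy model and learning rate, purely illustrative):

import torch.nn as nn
from torch.optim import Adam

def get_trainable(model):
    # mirror of get_trainable above: only parameters that require gradients
    return [p for p in model.parameters() if p.requires_grad]

optcls = lambda model: Adam(get_trainable(model), lr=1e-3)   # the "adam" branch
model = nn.Linear(128, 128)                                  # stand-in for models[width]()
optimizer = optcls(model)                                    # as called inside _get_coord_data
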
33 changes: 23 additions & 10 deletions megatron/training.py
@@ -55,7 +55,7 @@
CharCounter,
)
from megatron.model.gpt2_model import cross_entropy
from eval_tasks import run_eval_harness
# from eval_tasks import run_eval_harness


def mup_weights_reinit(neox_args, model):
@@ -124,7 +124,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
sys.exit(1)


def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator):
def mup_coord_check(neox_args, timers, train_data_iterator):
from megatron.mup_substitute import get_coord_data
# from mup.coord_check import plot_coord_data

@@ -133,7 +133,7 @@ def gen():
old_hidden_size = neox_args.hidden_size
neox_args.hidden_size = hidden_size

model, optimizer, _ = setup_model_and_optimizer(
model, optimizer, lr_scheduler = setup_model_and_optimizer(
neox_args=neox_args, use_cache=False
)

@@ -145,24 +145,35 @@ def gen():

models = {}

# Hidden size needs to be divisible by num attention heads
for hidden_size in (neox_args.num_attention_heads * (2**p) for p in range(2, 9)):
models[hidden_size] = lazy_model(hidden_size)
# # Hidden size needs to be divisible by num attention heads
# for hidden_size in (neox_args.num_attention_heads * (2**p) for p in range(2, 9)):
# models[hidden_size] = lazy_model(hidden_size)

# optimizer, _ = get_optimizer(model, neox_args)
# 128
# 256
# 512
# 1024
# 2048
# 4096
# 8192

models[neox_args.hidden_size] = lazy_model(neox_args.hidden_size)

print_rank_0("df_up")
neox_args.use_mup = True
df_up = get_coord_data(
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=True, optimizer=get_optimizer
neox_args, timers, None, models, train_data_iterator, mup=True, optimizer="adam"
)
print_rank_0("df_sp")
neox_args.use_mup = False
df_sp = get_coord_data(
neox_args, timers, lr_scheduler, models, train_data_iterator, mup=False, optimizer=get_optimizer
neox_args, timers, None, models, train_data_iterator, mup=False, optimizer="adam"
)

# plot_coord_data(df_up, save_to=f"coord_check_up.{torch.distributed.get_rank()}.jpg")
# plot_coord_data(df_sp, save_to=f"coord_check_sp.{torch.distributed.get_rank()}.jpg")


print_rank_0("Saved coord check plots... exiting")
sys.exit(1)
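
Note: the commented-out sweep in this hunk picks widths that remain divisible by the attention head count (the current code registers only the configured hidden size). A small illustration of that generator, assuming num_attention_heads = 32 (illustrative value):

num_attention_heads = 32
widths = [num_attention_heads * (2 ** p) for p in range(2, 9)]
print(widths)  # [128, 256, 512, 1024, 2048, 4096, 8192] -- matches the widths listed above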

@@ -207,7 +218,9 @@ def pretrain(neox_args):

if neox_args.use_mup and neox_args.coord_check:
print_rank_0("Do muP Coord Check")
mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)
mup_coord_check(neox_args, timers, train_data_iterator)
else:
pass

# Print setup timing.
print_rank_0("done with setups ...")
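Note: the plot_coord_data calls remain commented out above. A hedged sketch of one way to visualize the two DataFrames with matplotlib, assuming they carry width, step, and act_abs_mean columns (hypothetical column names; adjust to whatever _get_coord_data actually records). Under muP the per-step curves should stay roughly flat as width grows; under standard parameterization they typically drift upward.

import matplotlib.pyplot as plt

def plot_coord_check(df, title, save_to):
    fig, ax = plt.subplots()
    for step, grp in df.groupby("step"):
        means = grp.groupby("width")["act_abs_mean"].mean()   # average over seeds
        ax.plot(means.index, means.values, marker="o", label=f"step {step}")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("width (hidden size)")
    ax.set_ylabel("mean |activation|")
    ax.set_title(title)
    ax.legend()
    fig.savefig(save_to)

# plot_coord_check(df_up, "muP coord check", "coord_check_up.jpg")
# plot_coord_check(df_sp, "SP coord check", "coord_check_sp.jpg")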
