zero out the weight in bert init run (tinygrad#9076)
`DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 EVAL_BS=6 GPUS=6 MODEL=bert python3 examples/mlperf/model_train.py` no longer OOMs. I think the buffers holding the randomly initialized weights caused the OOM.
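
As a standalone illustration of the zeroing pattern this commit uses (the `Linear` toy model and its shapes are hypothetical stand-ins for the BERT parameters; the real loop is in the diff below):

from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

model = Linear(784, 10)  # hypothetical stand-in for the BERT model
for p in get_parameters(model):
  # overwrite each parameter with realized zeros; .contiguous() gives the
  # zeros their own buffer instead of a folded constant, and .realize()
  # runs the assign eagerly, so (per the commit message) the random-init
  # weight buffers do not need to stay resident
  p.assign(Tensor.zeros_like(p).contiguous()).realize()

In the commit, this loop only runs on the non-MLPerf init/benchmark path (RUNMLPERF unset), where real checkpoint weights are never loaded.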
chenyuxyz authored Feb 14, 2025
1 parent 82ad0d2 commit b58e7b1
Showing 2 changed files with 10 additions and 5 deletions.
5 changes: 2 additions & 3 deletions examples/mlperf/helpers.py
@@ -208,7 +208,7 @@ def get_mlperf_bert_config():
     "vocab_size": 30522
   }
 
-def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
+def get_mlperf_bert_model():
   from extra.models import bert
   from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
@@ -220,8 +220,7 @@ def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
   config = get_mlperf_bert_config()
   if getenv("DISABLE_DROPOUT", 0):
     config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
-  model = BertForPretraining(**config)
-  return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
+  return BertForPretraining(**config)
 
 def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
10 changes: 8 additions & 2 deletions examples/mlperf/model_train.py
@@ -683,8 +683,14 @@ def train_bert():
 
   # ** init model **
 
-  model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
-
+  model = get_mlperf_bert_model()
+  if RUNMLPERF:
+    model.load_from_pretrained(init_ckpt)
+  else:
+    # for init, zero out all weights
+    for p in get_parameters(model):
+      p = p.assign(Tensor.zeros_like(p).contiguous()).realize()
+
   parameters = get_parameters(model)
   for p in parameters:
     p.to_(GPUS)
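
For reference, a sketch of the call-site change implied by the new `get_mlperf_bert_model` signature (the checkpoint path is a hypothetical placeholder):

# before this commit: checkpoint loading was folded into the helper
model = get_mlperf_bert_model("/path/to/init_ckpt")  # hypothetical path

# after: construct first, then load explicitly only when a checkpoint is needed
model = get_mlperf_bert_model()
model.load_from_pretrained("/path/to/init_ckpt")  # hypothetical path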
