I am trying to adapt Llama3 for long context, i.e., 128k. I am training on a v5-256 and am trying to follow the procedure explained in https://arxiv.org/pdf/2407.14482, which states:

"We set the batch size to 32 to fit 4 million tokens in a batch and use a learning rate of 3e-5 to train 2000 steps (8B tokens in total)."
I have prepared a dataset with 128k context using the HF dataset input (thanks @aireenmei).
My challenge, however, is that setting per_device_batch_size=1 gives me a global batch size of 256. This is far too high, and I get OOM errors. I want to split batches across devices. I attempted setting num_pipeline_microbatches=8, but this does not seem to have any effect.
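To make the arithmetic explicit (assuming the global batch size is simply per_device_batch_size times the number of chips, which is how I read the config; treat that as an assumption rather than documented behaviour):

```python
# Assumed relation: global_batch = per_device_batch_size * num_devices.
num_devices = 256            # v5-256
per_device_batch_size = 1
global_batch = per_device_batch_size * num_devices
print(global_batch)          # 256 sequences of 128k tokens -> ~33M tokens per step

# The paper's recipe targets a global batch of 32 (roughly 4M tokens at 128k
# context), which would need a per-device batch of 32 / 256 = 0.125 if
# fractional per_device_batch_size values are supported (unverified).
print(32 / num_devices)      # 0.125
```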
Are there other ways of accomplishing this? I understand gradient accumulation is not implemented (and I am not sure whether it would work here).
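For context, here is a minimal JAX sketch of what I mean by gradient accumulation: gradients from several microbatches are averaged before a single optimizer update, so only one microbatch's activations need to be live at a time. This is purely illustrative and not MaxText code; the loss function, shapes, and names are placeholders.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Placeholder loss: a real setup would run the model's forward pass here.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

def accumulated_grads(params, microbatches):
    """Average gradients over a leading microbatch axis before one update."""
    def step(carry, microbatch):
        loss, grads = jax.value_and_grad(loss_fn)(params, microbatch)
        return carry, (loss, grads)

    # scan processes one microbatch at a time, so peak activation memory is
    # that of a single microbatch rather than the full effective batch.
    _, (losses, grads) = jax.lax.scan(step, None, microbatches)
    mean_grads = jax.tree_util.tree_map(lambda g: g.mean(axis=0), grads)
    return losses.mean(), mean_grads

# Toy usage: 8 microbatches of 4 examples each behave like one batch of 32.
params = {"w": jnp.ones((16, 1))}
microbatches = {
    "x": jnp.ones((8, 4, 16)),
    "y": jnp.zeros((8, 4, 1)),
}
loss, grads = jax.jit(accumulated_grads)(params, microbatches)
```

Memory stays roughly constant in the number of microbatches, which is why I suspect something like this would be needed for 128k-token sequences, but I may be missing an existing mechanism in the codebase.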