Long Context #801

Open
peregilk opened this issue Jul 28, 2024 · 2 comments

@peregilk

I am trying to adapt Llama 3 for long context (128k tokens). I am training on a v5-256 and am trying to follow the procedure described in https://arxiv.org/pdf/2407.14482. In short, it states:

"We set the batch size to 32 to fit 4 million tokens in a batch and use a learning rate of 3e-5 to train 2000 steps (8B tokens in total)."
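As a quick check of the quoted numbers, here is a small sketch assuming that "128k" means 128 × 1024 = 131,072 tokens per sequence (an assumption; the paper's exact sequence length is not restated here):

```python
# Tokens per optimizer step and total training tokens for the quoted recipe.
# Assumption: "128k" context = 128 * 1024 = 131,072 tokens per sequence.
SEQ_LEN = 128 * 1024   # tokens per sequence (assumed)
BATCH_SIZE = 32        # sequences per global batch (from the paper)
STEPS = 2000           # training steps (from the paper)

tokens_per_step = BATCH_SIZE * SEQ_LEN
total_tokens = tokens_per_step * STEPS

print(f"tokens per step: {tokens_per_step:,}")  # 4,194,304 (~4M)
print(f"total tokens:    {total_tokens:,}")     # 8,388,608,000 (~8B)
```

Both figures line up with the "4 million tokens in a batch" and "8B tokens in total" in the quote.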

I have prepared a dataset with a 128k context length using the HF dataset (thanks, @aireenmei).

My challenge, however, is that setting per_device_batch_size=1 gives me a global batch size of 256. That is eight times the batch size of 32 used in the paper, and I get OOM errors. I want each batch to be split across devices, i.e. fewer than one sequence per device. I tried setting num_pipeline_microbatches=8, but this does not seem to work.
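The OOM follows from simple arithmetic. A minimal sketch, assuming the v5-256 slice exposes 256 devices, that the global batch is per_device_batch_size multiplied by the device count (consistent with the figure of 256 reported here), and 131,072-token sequences; global_batch_size is a hypothetical helper:

```python
# Global batch when every device holds per_device_batch_size sequences.
# Assumptions: 256 devices on the v5-256 slice, 131,072-token sequences.
NUM_DEVICES = 256
SEQ_LEN = 128 * 1024

def global_batch_size(per_device_batch_size: float, num_devices: int) -> float:
    """Hypothetical helper: sequences per optimizer step across the slice."""
    return per_device_batch_size * num_devices

gbs = global_batch_size(1, NUM_DEVICES)
print(gbs)            # 256 sequences
print(gbs * SEQ_LEN)  # 33554432 tokens per step (~33.5M)
print(gbs / 32)       # 8.0: eight times the paper's batch of 32
```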

Are there other ways of accomplishing this? I understand that gradient accumulation is not implemented (and I am not sure it would work here anyway).
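For reference, gradient accumulation splits one global batch into equal micro-batches, computes the gradient of each in turn, and averages them before a single optimizer update; in a real model only one micro-batch's activations need to be held in memory at a time. The sketch below is a generic pure-Python illustration on a hypothetical toy linear model (grad_mse and accumulated_grad are illustrative names), not MaxText's implementation:

```python
# Gradient accumulation on a toy 1-D linear model y_hat = w * x with mean-squared error.
# Hypothetical data and helpers; this illustrates the technique only.

def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def accumulated_grad(w, xs, ys, num_microbatches):
    """Average the gradients of equally sized micro-batches, processed one at a time."""
    n = len(xs)
    assert n % num_microbatches == 0, "batch must split evenly"
    size = n // num_microbatches
    total = 0.0
    for i in range(num_microbatches):
        start, end = i * size, (i + 1) * size
        total += grad_mse(w, xs[start:end], ys[start:end])
    return total / num_microbatches

xs = [float(i) for i in range(32)]  # a "global batch" of 32 examples
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, num_microbatches=8)
print(abs(full - accum) < 1e-9)  # True: 8 averaged micro-batch gradients match the full batch
```

The equality relies on the loss being a mean and the micro-batches being the same size; with unequal micro-batches the average would need to be weighted by size.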

@aireenmei
Collaborator

Have you tried setting per_device_batch_size<1?
See this example in our unit test: https://github.com/google/maxtext/blob/main/.github/workflows/UnitTests.yml#L102-L103
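Applied to this setup, the suggestion comes down to the following arithmetic, assuming 256 devices and the paper's target global batch of 32; required_per_device_batch_size is a hypothetical helper, and the resulting value would presumably be passed as per_device_batch_size=0.125:

```python
# Per-device batch size needed to hit a target global batch on a given slice.
# Assumptions: 256 devices on the v5-256 slice, target global batch of 32 (from the paper).
from fractions import Fraction

def required_per_device_batch_size(target_global_batch: int, num_devices: int) -> Fraction:
    """Hypothetical helper: exact ratio of global batch to device count."""
    return Fraction(target_global_batch, num_devices)

pdbs = required_per_device_batch_size(32, 256)
print(pdbs, float(pdbs))  # 1/8 0.125
print(1 / pdbs)           # 8: devices per sequence in the global batch (arithmetic only)
```

Using Fraction keeps the ratio exact, which makes it easy to spot device counts that do not divide the target batch evenly.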

@peregilk
Author

Thanks, I had not noticed this. I will test it.

@gobbleturk added the "bug" (Something isn't working) and "feature request" labels on Sep 17, 2024
@shralex removed the "bug" (Something isn't working) label on Sep 17, 2024