
gradient accumulation #787

Open. Wants to merge 1 commit into main from the gradient branch.
Conversation

tonyjohnchen (Collaborator):
Adding a new feature, gradient accumulation, so that the weights are only updated once every gradient_accumulation_steps steps.

Example command without using gradient accumulation:

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=${RUN_NAME} enable_checkpointing=false async_checkpointing=false per_device_batch_size=1 skip_first_n_steps_for_profiler=5 steps=30 dataset_type=synthetic profiler=xplane

Example command using gradient accumulation:

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=${RUN_NAME} enable_checkpointing=false async_checkpointing=false per_device_batch_size=1 skip_first_n_steps_for_profiler=5 steps=30 dataset_type=synthetic profiler=xplane gradient_accumulation_steps=10

[Attached results: Result1, Result2]


@anfals (Collaborator) left a comment:


Double

[Three review comments on MaxText/train.py, now outdated and resolved]
Collaborator:

Have you tested that this gives the same loss as a large batch? For example, you could run locally on just a v4-8 with per_device_batch_size=1, gradient_accumulation_steps=5 for 100 steps, and with per_device_batch_size=5, gradient_accumulation_steps=1 for 100 steps; the two runs should give near-identical loss.
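For reference, a possible pair of commands for that check, reusing the flags from the PR description (the run names here are just placeholders, not taken from the PR):

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=accum_pdb1_ga5 enable_checkpointing=false async_checkpointing=false per_device_batch_size=1 gradient_accumulation_steps=5 dataset_type=synthetic steps=100

python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=baseline_pdb5_ga1 enable_checkpointing=false async_checkpointing=false per_device_batch_size=5 gradient_accumulation_steps=1 dataset_type=synthetic steps=100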

Collaborator:

Agree with the test above to verify the result.

Also trying to learn more here. After some research, it seems we should do the following (see the sketch after this list):

  1. Accumulated gradients += gradients from the current accumulation step.
  2. Averaged gradients = accumulated gradients / number of accumulation steps (rather than averaging as we go, or computing gradients only on the last microbatch).
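To make that accumulate-then-average ordering concrete, here is a minimal JAX/optax sketch. It is not the PR's actual MaxText/train.py change; loss_fn, microbatches, and optimizer are assumed placeholders (any loss_fn(params, batch) -> scalar and any optax optimizer would fit):

  import jax
  import jax.numpy as jnp
  import optax

  def accumulated_update(params, opt_state, microbatches, loss_fn, optimizer):
      """One weight update built from several microbatch gradients."""
      accumulation_steps = len(microbatches)
      # 1. Accumulated gradients += gradients from each accumulation step.
      grads_sum = jax.tree_util.tree_map(jnp.zeros_like, params)
      loss_sum = 0.0
      for microbatch in microbatches:
          loss, grads = jax.value_and_grad(loss_fn)(params, microbatch)
          grads_sum = jax.tree_util.tree_map(jnp.add, grads_sum, grads)
          loss_sum += loss
      # 2. Averaged gradients = accumulated gradients / number of accumulation steps.
      grads_avg = jax.tree_util.tree_map(lambda g: g / accumulation_steps, grads_sum)
      # Apply a single optimizer step for the whole accumulation window.
      updates, opt_state = optimizer.update(grads_avg, opt_state, params)
      params = optax.apply_updates(params, updates)
      return params, opt_state, loss_sum / accumulation_steps

In a real train step the inner loop would usually be a jax.lax.scan inside a jitted function rather than a Python loop, but the accumulation and averaging order is the same.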

[One review comment on MaxText/train.py, resolved]
@tonyjohnchen force-pushed the gradient branch 3 times, most recently from b39e98b to d4d99e3 on July 19, 2024 at 22:00.