-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement saving FSDP with LoRA #295
Conversation
102e94c
to
340326f
Compare
This pull request has merge conflicts that must be resolved before it can be |
limit_all_gathers=True, | ||
mixed_precision_policy=MixedPrecision( | ||
param_dtype=torch.bfloat16, | ||
reduce_dtype=torch.bfloat16, | ||
buffer_dtype=torch.bfloat16, | ||
), | ||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, | ||
backward_prefetch=BackwardPrefetch.BACKWARD_POST, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the impact of making this change for non-lora usage?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a performance/memory tradeoff. We should have it be configurable if possible, but I can limit it to only be this option when LoRA is used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Backward prefetch vs. postfetch shouldn't be impacting correctness of LoRA, but not using prefetch could hurt default training times. I thing prefetch should be the default for non-lora cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 James, I'll create a follow-up issue to have this as a configurable setting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tests/smoketest.sh
Outdated
@@ -2,7 +2,7 @@ | |||
set -eux -o pipefail | |||
|
|||
# ############### Read-only parameters ############### | |||
MODEL_NAME="instructlab/granite-7b-lab" | |||
MODEL_NAME="/home/ec2-user/.cache/huggingface/hub/models--instructlab--granite-7b-lab/snapshots/4fb6a018d68ab813b95c7f470e424a70f2f7e561" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
won't always be on ec2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed it
This pull request has merge conflicts that must be resolved before it can be |
LoRA models when training with FSDP as the distributed backend. This is accomplished by creating a copy of the LoRA model on the CPU, loading in the state dict after gathering it from the distributed model, and saving after merging the adapters back into the original model. Afterwards, the CPU copy is discarded and training continues. Signed-off-by: Oleg S <[email protected]>
This commit adds a smoketest for testing LoRA + FSDP. Signed-off-by: Oleg S <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for adding the test!
Additionally introuce a max_seq_len parameter to support testing on lower-end hardware. Signed-off-by: Oleg S <[email protected]>
Currently we cannot save LoRA models with FSDP, this PR addresses this limitation by instantiating a copy of the model on CPU, loading in the LoRA settings, loading the state dict after it has been gathered, and finally performing the same save as we do elsewhere throughout the codebase.
Resolves #241