Add Draft LoRA scripts #138
base: main
Conversation
Summary of Changes
Hello @shuaills, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the Eagle3 framework by integrating support for LoRA (Low-Rank Adaptation) based training of draft models. The primary goal is to enable more memory-efficient and faster fine-tuning of large language models by only training a small number of additional parameters. This is achieved through a new dedicated training script, a configurable LoRA setup, and necessary modifications to the model architecture for PEFT compatibility. Additionally, the PR introduces utility scripts for model download and upload to Hugging Face, and expands data preparation to include a new dataset type, streamlining the overall development and deployment workflow for fine-tuned models.
Highlights
- LoRA Training Script: Introduced a new `train_eagle3_lora_online.py` script that enables training of the Eagle3 draft model using LoRA. This script supports loading a base draft model, applying a LoRA configuration, and training it efficiently with distributed data parallelism (FSDP).
- LoRA Configuration: Added a dedicated LoRA configuration file (`configs/draft_lora_trainable_config.json`) that specifies the LoRA parameters and target modules (e.g., `q_proj`, `v_proj`) for the draft model, allowing for flexible and configurable LoRA fine-tuning (see the sketch after this list).
- PEFT Compatibility: Modified `Eagle3DraftModel` and `Llama3EagleDraftModel` to ensure full compatibility with the PEFT (Parameter-Efficient Fine-tuning) library. This includes inheriting from `GenerationMixin` and implementing necessary methods like `get_input_embeddings` and `prepare_inputs_for_generation`.
- Model Management Utilities: Included new utility scripts such as `download.py` for downloading models from Hugging Face and `push_to_hf.sh` for pushing trained models or LoRA adapters back to the Hugging Face Hub, streamlining model lifecycle management.
- Dataset Support: Expanded data preparation capabilities by adding support for the `synth_summarize` dataset in `scripts/prepare_data.py`, along with its specific processing logic, making it available for training workflows.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces scripts and configurations for LoRA training of a draft model. The changes are extensive and add significant new capabilities. My review has identified several critical security vulnerabilities related to hardcoded API keys in the example scripts, which must be addressed. Additionally, there are instances of code duplication and other maintainability issues in the Python scripts that should be resolved. I have provided specific suggestions to fix these issues.
```bash
--cache-dir $ROOT_DIR/cache \
--skip-vocab-mapping \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
```
A hardcoded `wandb-key` has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: `: "${WANDB_API_KEY:?WANDB_API_KEY is not set}"`

Suggested change:

```diff
- --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+ --wandb-key "${WANDB_API_KEY}" \
```
```bash
--cache-dir $ROOT_DIR/cache \
--skip-vocab-mapping \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
```
A hardcoded `wandb-key` has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: `: "${WANDB_API_KEY:?WANDB_API_KEY is not set}"`

Suggested change:

```diff
- --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+ --wandb-key "${WANDB_API_KEY}" \
```
```bash
--attention-backend flex_attention
--cache-dir $ROOT_DIR/cache \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
```
A hardcoded `wandb-key` has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: `: "${WANDB_API_KEY:?WANDB_API_KEY is not set}"`

Suggested change:

```diff
- --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+ --wandb-key "${WANDB_API_KEY}" \
```
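On the Python side, the same idea can be applied where the training script parses its arguments. This is a hedged sketch, assuming an argparse-based CLI with a `--wandb-key` flag; the actual parser in this PR is not shown here:

```python
import argparse
import os

parser = argparse.ArgumentParser()
# Fall back to the WANDB_API_KEY environment variable instead of a hardcoded key.
parser.add_argument(
    "--wandb-key",
    type=str,
    default=os.environ.get("WANDB_API_KEY"),
    help="wandb API key; defaults to the WANDB_API_KEY environment variable",
)
args = parser.parse_args()

if args.wandb_key is None:
    raise SystemExit("Set WANDB_API_KEY or pass --wandb-key explicitly")
```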
```python
    elif args.dataset == "opc":
        ds = load_dataset(
            "OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
        )["train"]
        proc_fn = process_opc_sft_stage1
```
```python
import os

from huggingface_hub import snapshot_download


def download_model(model_id, local_dir):
    print(f"downloading model: {model_id}")
    print(f"will save to: {local_dir}")

    try:
        snapshot_download(
            repo_id=model_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print("download success!")
    except Exception as e:
        print(f"error: {e}")


if __name__ == "__main__":
    model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
    save_directory = f"./{model_identifier.replace('/', '_')}"

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    download_model(model_id=model_identifier, local_dir=save_directory)
```
The error handling in `download_model` could be improved. When an exception occurs, the script prints to standard output and exits with status 0, which can make it difficult for calling scripts to detect failures. It's better practice to print errors to `stderr` and exit with a non-zero status code. Additionally, the success message contains a non-ASCII exclamation mark, which should be replaced for consistency and to avoid potential encoding issues. This suggestion also adds the required `sys` import to the top of the file.
```python
import os
import sys

from huggingface_hub import snapshot_download


def download_model(model_id, local_dir):
    print(f"downloading model: {model_id}")
    print(f"will save to: {local_dir}")

    try:
        snapshot_download(
            repo_id=model_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print("Download successful!")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
    save_directory = f"./{model_identifier.replace('/', '_')}"

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    download_model(model_id=model_identifier, local_dir=save_directory)
```
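As a quick illustration of why the non-zero exit status matters (the caller below is hypothetical, not part of this PR), a wrapper can now detect download failures directly:

```python
import subprocess

# check=True raises subprocess.CalledProcessError if download.py exits with a
# non-zero status, so failures no longer pass silently.
subprocess.run(["python", "download.py"], check=True)
```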
"--dataset", | ||
type=str, | ||
choices=["ultrachat", "sharegpt", "opc"], | ||
choices=["ultrachat", "sharegpt", "opc", "synth_summarize", "opc"], |
```python
if lora_state_dict:
    # Save LoRA weights
    import safetensors.torch as st
```
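For context on this hunk, here is a minimal sketch of filtering LoRA-only tensors out of a model's state dict and writing them with safetensors. The helper name, key filter, and output path are assumptions for illustration, not the PR's actual implementation:

```python
import safetensors.torch as st
import torch


def save_lora_weights(model: torch.nn.Module, output_path: str) -> None:
    # PEFT names LoRA parameters with a "lora_" prefix inside each adapted module.
    lora_state_dict = {
        name: param.detach().cpu().contiguous()
        for name, param in model.state_dict().items()
        if "lora_" in name
    }
    if lora_state_dict:
        # Save only the LoRA weights; the base model weights stay untouched.
        st.save_file(lora_state_dict, output_path)


# Example (hypothetical): save_lora_weights(model, "adapter_model.safetensors")
```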
```python
if not hasattr(self, "_lora_logged"):
    print(
        "self.midlayer.self_attn.q_proj type:",
        type(self.midlayer.self_attn.q_proj),
    )
    print(
        "self.midlayer.self_attn.k_proj type:",
        type(self.midlayer.self_attn.k_proj),
    )
    print(
        "self.midlayer.self_attn.v_proj type:",
        type(self.midlayer.self_attn.v_proj),
    )
    print(
        "self.midlayer.self_attn.o_proj type:",
        type(self.midlayer.self_attn.o_proj),
    )
    self._lora_logged = True
```
cc @shuaills is still working on this and this PR is not ready yet. I will leave it open.
Motivation
This PR adds training scripts for LoRA. This is a draft PR; do not merge.
Modifications
Benchmark & Profiling
Accuracy with/without LoRA

Checklist