Conversation

@shuaills (Collaborator) commented Aug 13, 2025

Motivation

This PR adds training scripts for LoRA. This is a draft PR; do not merge.

Modifications

Benchmark & Profiling

Accuracy with and without LoRA:
[image: accuracy comparison of the draft model trained with and without LoRA]

Checklist

@gemini-code-assist (bot, Contributor) left a comment


Summary of Changes

Hello @shuaills, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Eagle3 framework by integrating support for LoRA (Low-Rank Adaptation) based training of draft models. The primary goal is to enable more memory-efficient and faster fine-tuning of large language models by only training a small number of additional parameters. This is achieved through a new dedicated training script, a configurable LoRA setup, and necessary modifications to the model architecture for PEFT compatibility. Additionally, the PR introduces utility scripts for model download and upload to Hugging Face, and expands data preparation to include a new dataset type, streamlining the overall development and deployment workflow for fine-tuned models.

Highlights

  • LoRA Training Script: Introduced a new train_eagle3_lora_online.py script that enables training of the Eagle3 draft model using LoRA. This script supports loading a base draft model, applying a LoRA configuration, and training it efficiently with distributed data parallelism (FSDP).
  • LoRA Configuration: Added a dedicated LoRA configuration file (configs/draft_lora_trainable_config.json) that specifies the LoRA parameters and target modules (e.g., q_proj, v_proj) for the draft model, allowing for flexible and configurable LoRA fine-tuning; see the sketch after this list.
  • PEFT Compatibility: Modified the Eagle3DraftModel and Llama3EagleDraftModel to ensure full compatibility with the PEFT (Parameter-Efficient Fine-tuning) library. This includes inheriting from GenerationMixin and implementing necessary methods like get_input_embeddings and prepare_inputs_for_generation.
  • Model Management Utilities: Included new utility scripts such as download.py for downloading models from Hugging Face and push_to_hf.sh for pushing trained models or LoRA adapters back to Hugging Face Hub, streamlining the model lifecycle management.
  • Dataset Support: Expanded data preparation capabilities by adding support for the synth_summarize dataset in scripts/prepare_data.py, along with its specific processing logic, making it available for training workflows.
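
To make the configuration concrete, here is a minimal sketch of how such a LoRA setup could be applied to the draft model with the PEFT library. The rank, alpha, and dropout values and the model path are illustrative assumptions, not values taken from this PR; only the target modules (q_proj, v_proj) come from the highlights above.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Assumed hyperparameters; the PR's configs/draft_lora_trainable_config.json
# may use different values.
lora_config = LoraConfig(
    r=16,                                 # assumed LoRA rank
    lora_alpha=32,                        # assumed scaling factor
    lora_dropout=0.05,                    # assumed dropout
    target_modules=["q_proj", "v_proj"],  # modules named in the config above
    bias="none",
)

# Placeholder path; the PR's actual draft model class is Eagle3DraftModel,
# which this PR makes PEFT-compatible (GenerationMixin, get_input_embeddings, etc.).
base_model = AutoModelForCausalLM.from_pretrained("path/to/draft-model")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters should be trainable
```

Per the first highlight, the new training script then trains this wrapped model under FSDP; the wrapping step above is independent of the distributed setup.
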
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

¹ Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request introduces scripts and configurations for LoRA training of a draft model. The changes are extensive and add significant new capabilities. My review has identified several critical security vulnerabilities related to hardcoded API keys in the example scripts, which must be addressed. Additionally, there are instances of code duplication and other maintainability issues in the Python scripts that should be resolved. I have provided specific suggestions to fix these issues.

```bash
--cache-dir $ROOT_DIR/cache \
--skip-vocab-mapping \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
```

critical

A hardcoded `--wandb-key` has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: `: "${WANDB_API_KEY:?WANDB_API_KEY is not set}"`

Suggested change:

```diff
- --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+ --wandb-key "${WANDB_API_KEY}" \
```

```bash
--cache-dir $ROOT_DIR/cache \
--skip-vocab-mapping \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
```

critical

A hardcoded `--wandb-key` has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: `: "${WANDB_API_KEY:?WANDB_API_KEY is not set}"`

Suggested change:

```diff
- --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+ --wandb-key "${WANDB_API_KEY}" \
```

```bash
--attention-backend flex_attention
--cache-dir $ROOT_DIR/cache \
--wandb \
--wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
```

critical

A hardcoded `--wandb-key` has been found in the script. Committing secrets like API keys to version control is a significant security risk. This key should be removed and loaded from a secure source, such as an environment variable. You should also add a check at the beginning of the script to ensure the environment variable is set, for example: `: "${WANDB_API_KEY:?WANDB_API_KEY is not set}"`

Suggested change:

```diff
- --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+ --wandb-key "${WANDB_API_KEY}" \
```

Comment on lines +188 to +192
```python
elif args.dataset == "opc":
    ds = load_dataset(
        "OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
    )["train"]
    proc_fn = process_opc_sft_stage1
```

high

This `elif` block for `args.dataset == "opc"` is a duplicate of the one at lines 174-178. It should be removed to avoid redundant code and potential inconsistencies.

Comment on lines +1 to +28
```python
import os

from huggingface_hub import snapshot_download


def download_model(model_id, local_dir):
    print(f"downloading model: {model_id}")
    print(f"will save to: {local_dir}")

    try:
        snapshot_download(
            repo_id=model_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print("download success!")
    except Exception as e:
        print(f"error: {e}")


if __name__ == "__main__":
    model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
    save_directory = f"./{model_identifier.replace('/', '_')}"

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    download_model(model_id=model_identifier, local_dir=save_directory)
```

medium

The error handling in `download_model` could be improved. When an exception occurs, the script prints to standard output and exits with status 0, which can make it difficult for calling scripts to detect failures. It is better practice to print errors to stderr and exit with a non-zero status code. Additionally, the success message contains a non-ASCII exclamation mark, which should be replaced for consistency and to avoid potential encoding issues. This suggestion also adds the required `sys` import to the top of the file.

```python
import os
import sys

from huggingface_hub import snapshot_download


def download_model(model_id, local_dir):
    print(f"downloading model: {model_id}")
    print(f"will save to: {local_dir}")

    try:
        snapshot_download(
            repo_id=model_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print("Download successful!")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
    save_directory = f"./{model_identifier.replace('/', '_')}"

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    download_model(model_id=model_identifier, local_dir=save_directory)
```

"--dataset",
type=str,
choices=["ultrachat", "sharegpt", "opc"],
choices=["ultrachat", "sharegpt", "opc", "synth_summarize", "opc"],

medium

The choice `"opc"` is duplicated in the `choices` list for the `--dataset` argument. The duplicate should be removed for clarity and correctness.

Suggested change:

```diff
-choices=["ultrachat", "sharegpt", "opc", "synth_summarize", "opc"],
+choices=["ultrachat", "sharegpt", "opc", "synth_summarize"],
```


```python
if lora_state_dict:
    # Save LoRA weights
    import safetensors.torch as st
```

medium

Imports should be placed at the top of the file, as per PEP 8 guidelines. The `import safetensors.torch as st` on this line and `import shutil` on line 663 are inside the `main` function. Moving them to the top of the script improves code readability and structure.
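
For illustration, here is a minimal sketch of the PEP 8-style arrangement, with the import hoisted to module level and the collected LoRA weights saved via safetensors. The `output_dir`, `lora_state_dict` contents, and the `adapter_model.safetensors` file name are assumptions for the sketch, not names taken from this PR.

```python
import os

import safetensors.torch as st  # hoisted to module level per PEP 8
import torch

# Hypothetical stand-ins; the PR's training loop produces the real values.
output_dir = "./lora_out"
lora_state_dict = {"lora_A": torch.zeros(8, 16)}

os.makedirs(output_dir, exist_ok=True)
if lora_state_dict:
    # safetensors expects a flat dict mapping names to tensors.
    st.save_file(lora_state_dict, os.path.join(output_dir, "adapter_model.safetensors"))
```
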

Comment on lines +735 to +752
```python
if not hasattr(self, "_lora_logged"):
    print(
        "self.midlayer.self_attn.q_proj type:",
        type(self.midlayer.self_attn.q_proj),
    )
    print(
        "self.midlayer.self_attn.k_proj type:",
        type(self.midlayer.self_attn.k_proj),
    )
    print(
        "self.midlayer.self_attn.v_proj type:",
        type(self.midlayer.self_attn.v_proj),
    )
    print(
        "self.midlayer.self_attn.o_proj type:",
        type(self.midlayer.self_attn.o_proj),
    )
    self._lora_logged = True
```

medium

This block of print statements appears to be debugging code. It should be removed before merging to keep the codebase clean and avoid polluting logs.
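
If a one-time trace of the projection types is still wanted during development, a debug-level logging call is a less intrusive alternative. This is a method-body sketch (it assumes the same `self` as the block above), not part of this PR:

```python
import logging

logger = logging.getLogger(__name__)

# Inside the model method where the prints were: log once, at debug level,
# so production runs stay quiet unless debug logging is enabled.
if not getattr(self, "_lora_logged", False):
    for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        proj = getattr(self.midlayer.self_attn, name)
        logger.debug("self.midlayer.self_attn.%s type: %s", name, type(proj))
    self._lora_logged = True
```
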

@zyksir (Collaborator) commented Oct 20, 2025

cc: @shuaills is still working on this, and this PR is not ready yet. I will leave it open.
