@xq25478 commented Sep 1, 2025

Support training eagle3 with DeepSpeed for large models like 72B/235B.

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @xq25478, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds a dedicated Python script for training the Eagle-3 model. The primary purpose of this change is to enable scalable and memory-efficient training by integrating DeepSpeed, a deep learning optimization library. This significantly improves the ability to train large models, particularly in online training scenarios.

Highlights

  • DeepSpeed Integration: Introduces a new script to train the Eagle-3 model using DeepSpeed, enabling efficient distributed training capabilities.
  • Online Training Support: The script is designed for online training of the Eagle-3 model, facilitating continuous learning from data streams.
  • Flexible ZeRO Optimization: Provides configurable support for DeepSpeed's ZeRO optimization stages (0, 1, 2, and 3) to manage memory usage during training; a minimal configuration sketch follows this list.
  • Robust Checkpointing: Implements comprehensive checkpointing, allowing training to be resumed from the last saved state and saving draft models in safetensors format.
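
For context, here is a minimal sketch of how a configurable ZeRO stage is typically passed to DeepSpeed. The ds_config values, the toy model, and the optimizer settings below are illustrative assumptions rather than the PR's actual code, and running it requires a GPU node with a distributed launcher such as deepspeed:

import deepspeed
import torch

# Illustrative config; the PR exposes the ZeRO stage (0/1/2/3) as an option.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "bf16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 2,                    # partitions optimizer states and gradients
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

model = torch.nn.Linear(1024, 1024)    # stand-in for the Eagle-3 draft model

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
# the engine wraps forward/backward/step with ZeRO partitioning applied.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

Stage 3 additionally partitions the model parameters themselves, which is the main lever for fitting very large models into memory.
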
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review: /gemini review performs a code review of the pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, and Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new script for training Eagle-3 models using DeepSpeed. The script is well-structured and covers the main aspects of a distributed training pipeline, including argument parsing, distributed setup, data loading, model initialization, training loop, and checkpointing. My review focuses on improving the script's robustness, configurability, and performance. I've provided suggestions to address hardcoded values, inconsistent variable usage, potential bugs with environment variables, use of bare exceptions, and inefficient metric accumulation and reduction. These changes should make the script more reliable and easier to use in different environments.


deepspeed.init_distributed()

local_rank = int(os.getenv("LOCAL_RANK",0))

Severity: medium

The local_rank is parsed from command-line arguments into args.local_rank but it's not used. Instead, local_rank is fetched from the environment variable. For consistency and robustness, it's better to use the parsed argument. This improves clarity and makes the script more robust if used with different launchers that might not set the LOCAL_RANK environment variable.

You should also use args.local_rank on lines 205 and 362.

Suggested change
local_rank = int(os.getenv("LOCAL_RANK",0))
local_rank = args.local_rank
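
For reference, a minimal sketch of how --local_rank is usually wired up when launching with DeepSpeed, so that args.local_rank is populated by the launcher; the parser setup here is an illustrative assumption, not code copied from the PR:

import argparse
import deepspeed

parser = argparse.ArgumentParser()
# The deepspeed launcher passes --local_rank to every worker process.
parser.add_argument("--local_rank", type=int, default=-1,
                    help="Local rank supplied by the distributed launcher")
# Adds the standard --deepspeed / --deepspeed_config flags.
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()

deepspeed.init_distributed()
local_rank = args.local_rank  # use the parsed argument instead of os.getenv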

local_rank = int(os.getenv("LOCAL_RANK",0))
rank = dist.get_rank()
world_size = dist.get_world_size()
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", 8))

Severity: medium

Hardcoding a default value of 8 for LOCAL_WORLD_SIZE is risky. It can lead to incorrect behavior or hide configuration issues if the script is run on a machine with a different number of GPUs per node and the LOCAL_WORLD_SIZE environment variable is not set. It's better to rely on the environment variable being set and fail explicitly if it's missing.

Suggested change
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", 8))
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

train_eagle3_dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=8,

Severity: medium

The number of workers for the DataLoader is hardcoded to 8. This value might not be optimal for all systems and could lead to resource exhaustion or suboptimal performance. It would be better to make this a command-line argument to allow for easier configuration depending on the execution environment.
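
A minimal sketch of exposing the worker count as a flag; the flag name --num-workers, its default, and the dummy dataset are illustrative assumptions, not taken from the PR:

import argparse
import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=2)
    # Hypothetical flag replacing the hardcoded num_workers=8.
    parser.add_argument("--num-workers", type=int, default=4,
                        help="DataLoader worker processes per rank")
    args = parser.parse_args([])  # empty argv so the sketch runs standalone

    dataset = TensorDataset(torch.arange(16).float().unsqueeze(1))  # stand-in dataset
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,  # configurable instead of hardcoded
    )
    print(sum(1 for _ in loader))  # 8 batches of size 2

if __name__ == "__main__":
    main()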

try:
model_engine.optimizer.load_state_dict(state["optimizer_state_dict"])
print_on_rank0("Successfully loaded optimizer state_dict.")
except:

Severity: medium

Using a bare except: is generally discouraged as it catches all exceptions, including system-exiting ones like SystemExit or KeyboardInterrupt. It's better to catch a more specific exception, like Exception, to avoid unintended behavior.

Suggested change
except:
except Exception:

try:
scheduler.load_state_dict(state["scheduler_state_dict"])
print_on_rank0("Successfully loaded scheduler state_dict.")
except:

Severity: medium

Using a bare except: is generally discouraged as it catches all exceptions, including system-exiting ones like SystemExit or KeyboardInterrupt. It's better to catch a more specific exception, like Exception, to avoid unintended behavior.

Suggested change
except:
except Exception:

dist.barrier()
print_on_rank0(f"Starting training from epoch {start_epoch}")

global_step = 0

Severity: medium

The global_step variable is initialized here and incremented in the training loop (line 383), but it's never actually used for logging, scheduling, or checkpointing. This appears to be dead code. Consider removing it if it's not needed.

loss_mask=loss_mask,
)
# Weighted loss
ploss_weight = [0.8 ** i for i in range(len(plosses))]

Severity: medium

The decay factor 0.8 for weighting losses is hardcoded. This makes it difficult to experiment with different weighting schemes without modifying the code. It would be better to define this as a configurable parameter, for example, as a command-line argument.
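
A minimal sketch of making the decay configurable; the flag name --ploss-decay and the dummy loss values below are illustrative assumptions, not taken from the PR:

import argparse
import torch

parser = argparse.ArgumentParser()
# Hypothetical flag; 0.8 stays as the default so current behavior is unchanged.
parser.add_argument("--ploss-decay", type=float, default=0.8,
                    help="Geometric decay applied to per-position draft losses")
args = parser.parse_args([])  # empty argv so the sketch runs standalone

plosses = [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.6)]  # dummy per-position losses
ploss_weight = [args.ploss_decay ** i for i in range(len(plosses))]
ploss = sum(w * l for w, l in zip(ploss_weight, plosses))
print(round(ploss.item(), 4))  # 0.9 + 0.8*0.7 + 0.64*0.6 = 1.844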

Comment on lines +386 to +389
epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))]
epoch_plosses = [
epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses))
]

Severity: medium

Re-creating lists with concatenation inside the training loop is inefficient, especially for large datasets, as it creates a new list and copies elements in every iteration. It's more performant to append items to the existing lists in-place.

Suggested change
epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))]
epoch_plosses = [
epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses))
]
for i, acc in enumerate(acces):
epoch_acces[i].append(acc)
for i, ploss_val in enumerate(plosses):
epoch_plosses[i].append(ploss_val.item())

Comment on lines +395 to +411
for i in range(len(epoch_acces)):
acc_i = torch.tensor(epoch_acces[i]).cuda().mean()
dist.all_reduce(acc_i)
acc_i = acc_i / dist.get_world_size()
acc_i = acc_i.item()
print_on_rank0(
f"Train Epoch [{epoch + 1}/{args.num_epochs}], position {i}, Acc: {acc_i:.4f}"
)

for i in range(len(epoch_plosses)):
loss_i = torch.tensor(epoch_plosses[i]).cuda().mean()
dist.all_reduce(loss_i)
loss_i = loss_i / dist.get_world_size()
loss_i = loss_i.item()
print_on_rank0(
f"Train Epoch [{epoch + 1}/{args.num_epochs}], position {i}, pLoss: {loss_i:.4f}"
)

Severity: medium

In the epoch-level logging, you are converting lists of Python floats to CPU tensors and then moving them to the GPU within a loop (.cuda()). This is done for both accuracies (line 396) and losses (line 405). This pattern is inefficient due to repeated CPU-to-GPU data transfers.

A more performant approach would be to accumulate metrics directly as tensors on the correct device during the training loop. For example, you could append torch.tensor(acc, device=model_engine.device) inside the loop, and then use torch.stack(epoch_acces[i]).mean() at the end of the epoch. This would avoid the expensive transfers.
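
A runnable sketch of that pattern, using a single-process gloo group to stand in for the real multi-GPU setup; the device, the dummy loop values, and num_positions are illustrative assumptions, and in the actual script the device would come from model_engine.device and the process group from deepspeed.init_distributed():

import os
import torch
import torch.distributed as dist

# Single-process init so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

device = "cpu"      # model_engine.device in the real script
num_positions = 3   # illustrative number of draft positions
epoch_acces = [[] for _ in range(num_positions)]

# Inside the training loop: accumulate metrics as tensors on the device,
# so no per-element CPU->GPU copies are needed at epoch end.
for step in range(4):
    acces = [0.70, 0.80, 0.90]  # placeholder per-position accuracies
    for i, acc in enumerate(acces):
        epoch_acces[i].append(torch.as_tensor(acc, device=device))

# Epoch end: one stack/mean and one all_reduce per position.
for i, accs in enumerate(epoch_acces):
    acc_i = torch.stack(accs).mean()
    dist.all_reduce(acc_i)
    acc_i = acc_i / dist.get_world_size()
    print(f"position {i}, Acc: {acc_i.item():.4f}")

dist.destroy_process_group()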

@sleepcoo (Collaborator) commented Sep 8, 2025

Is the training speed improved compared to the original implementation?
