Draft

74 commits
a05174d
first stab at ulysses that will probably fail
hamishivi Jul 22, 2025
1f5f06d
Revert "Revert "Now, we run individual prompts through the queue. (#7…
finbarrtimbers Jul 21, 2025
d1ebfba
Updated code.
finbarrtimbers Jul 22, 2025
7dd7e32
Have to switch repros
finbarrtimbers Jul 22, 2025
2254e60
Linter passes.
finbarrtimbers Jul 22, 2025
9b07049
fix bug
hamishivi Jul 22, 2025
d5d55d5
Fixed bug in queue expectations.
finbarrtimbers Jul 22, 2025
496d748
fix bug
hamishivi Jul 22, 2025
9dc0191
Ran linter.
finbarrtimbers Jul 22, 2025
26e7397
fix bug
hamishivi Jul 22, 2025
8bb1243
Added tqdm.
finbarrtimbers Jul 22, 2025
1445b0c
Ran linter.
finbarrtimbers Jul 22, 2025
c3d93ac
Merge remote-tracking branch 'origin/revert-804-revert-796-insert-pro…
hamishivi Jul 22, 2025
ee7fe6d
fixes to run
Jul 22, 2025
db914e7
A bunch of minor changes.
finbarrtimbers Jul 23, 2025
75958a1
Undo changes to script.
finbarrtimbers Jul 23, 2025
1453b40
Added some tests which now pass.
finbarrtimbers Jul 23, 2025
37f8496
Fixing indexing issue.
finbarrtimbers Jul 23, 2025
a02ac2e
Claude tried to fix the indexing error we were running into.
finbarrtimbers Jul 23, 2025
70fba6c
Fix indexing issue.
finbarrtimbers Jul 23, 2025
73490ea
Another attempt to fix the index bug.
finbarrtimbers Jul 23, 2025
f9968aa
We now have a failing test.
finbarrtimbers Jul 23, 2025
13bcdcf
Added failing tests.
finbarrtimbers Jul 23, 2025
5de7f6c
Now, all tests pass.
finbarrtimbers Jul 23, 2025
d39b5dc
current state of debugging
Jul 23, 2025
ad99ba7
Tests pass. Launching.
finbarrtimbers Jul 23, 2025
99197d8
Added a bunch of logging.
finbarrtimbers Jul 23, 2025
b8d365a
Removed most of the logging code.
finbarrtimbers Jul 23, 2025
e06fb2c
Ran linter.
finbarrtimbers Jul 23, 2025
b581d79
Created stripped down version of the tests.
finbarrtimbers Jul 23, 2025
7e024da
Ran linter.
finbarrtimbers Jul 23, 2025
9c57ab5
current changes, working-ish but probably incorrect.
Jul 23, 2025
63c5267
merge in fixed branch
hamishivi Jul 24, 2025
8ca9a25
Merge remote-tracking branch 'origin/main' into deepspeed-ulysses
hamishivi Aug 6, 2025
ed81bee
fix
hamishivi Aug 6, 2025
5739fb0
fix merge
hamishivi Aug 6, 2025
b63190d
missed some tings
hamishivi Aug 6, 2025
675a9ff
fixing it up
Aug 7, 2025
382ebd3
pos ids
hamishivi Aug 7, 2025
4fee35e
Merge remote-tracking branch 'origin/main' into deepspeed-ulysses
hamishivi Aug 7, 2025
0a66d96
lint
hamishivi Aug 7, 2025
d828eaf
reqs
hamishivi Aug 7, 2025
5101592
extra lints
hamishivi Aug 7, 2025
5416531
clean
hamishivi Aug 7, 2025
c38e95c
add timedelta
hamishivi Aug 7, 2025
46382f1
fix
Aug 8, 2025
05621ec
Merge remote-tracking branch 'origin/main' into deepspeed-ulysses
hamishivi Aug 8, 2025
98a7f7b
remove timedelta
hamishivi Aug 8, 2025
8c52684
Adding some timedeltas
hamishivi Aug 8, 2025
e986ad6
fixes
Aug 8, 2025
95c9860
lint
hamishivi Aug 8, 2025
09b47c2
lint
hamishivi Aug 8, 2025
e66f434
lint again
hamishivi Aug 8, 2025
e08874a
fix lint i hope
hamishivi Aug 8, 2025
4658dd2
Merge branch 'improve-logging' into deepspeed-ulysses
hamishivi Aug 8, 2025
b346c8e
Merge remote-tracking branch 'origin/main' into deepspeed-ulysses
hamishivi Aug 12, 2025
bd4b839
adding current working debug state
Aug 15, 2025
8b459cd
some fixes
Aug 15, 2025
69c2408
fix some deps
Aug 15, 2025
e22747c
cleaning
hamishivi Aug 17, 2025
e710970
clean up code
Aug 17, 2025
326f0df
lint
Aug 17, 2025
9a9be7a
updated reqs
Aug 17, 2025
9160ae4
Merge branch 'main' into deepspeed-ulysses
Aug 17, 2025
d8836a4
minor fix
Aug 17, 2025
032d859
Merge branch 'main' into deepspeed-ulysses
hamishivi Aug 17, 2025
72a6951
fix uv lock
Aug 17, 2025
690e5c3
lint
Aug 17, 2025
3f21465
lint2
Aug 17, 2025
2860ddf
first pass at fixing checkpoint saving
hamishivi Aug 19, 2025
75314cd
Merge remote-tracking branch 'origin/main' into deepspeed-ulysses
hamishivi Aug 19, 2025
1c7e6bc
loading also needs to remove the mpu
hamishivi Aug 19, 2025
65b7c4a
some merge fixes
hamishivi Aug 19, 2025
d024f86
back to fa2
hamishivi Aug 19, 2025
197 changes: 175 additions & 22 deletions open_instruct/grpo_fast.py
@@ -30,12 +30,16 @@
# isort: off
import os
from concurrent import futures
from datetime import timedelta

# We need to set NCCL_CUMEM_ENABLE=0 for performance reasons; see:
# https://github.com/vllm-project/vllm/issues/5723#issuecomment-2554389656
os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA
try:
import deepspeed
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
from deepspeed.utils import groups
from deepspeed.runtime.utils import move_to_device

# @vwxyzjn: when importing on CPU-only machines, we get the following error:
# RuntimeError: 0 active drivers ([]). There should only be one.
@@ -58,7 +62,6 @@
from argparse import Namespace
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from queue import Empty, Full, Queue
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional

@@ -114,6 +117,7 @@
ArgumentParserPlus,
BeakerRuntimeConfig,
RayProcess,
UlyssesSPSplitter,
_z3_params_to_fetch,
calibrate_checkpoint_state_dir,
clean_last_n_checkpoints_deepspeed,
@@ -328,6 +332,9 @@ class Args:
num_learners_per_node: List[int] = field(default_factory=lambda: [1])
"""number of GPU deepspeed learners per node (e.g., --num_learners_per_node 2 4 means 2 learner processes
on the first node and 4 learner processes on the second node; each process will have 1 GPU)"""
sequence_parallel_size: int = 1
"""sequence parallel size - how many GPUs we will parallelize sequences across during training.
Useful for super-long context lengths."""
vllm_num_engines: int = 1
"""number of vLLM Engines, set to 0 to disable vLLM"""
vllm_tensor_parallel_size: int = 1
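For orientation on the new flag: with sequence parallelism, each learner GPU holds only a contiguous slice of every (padded) training sequence, and DeepSpeed Ulysses exchanges attention inputs across the group with all-to-alls. A minimal sketch of the slicing idea, illustrative only (shard_sequence is a hypothetical helper, not the PR's UlyssesSPSplitter):

import torch

def shard_sequence(x: torch.Tensor, sp_rank: int, sp_world_size: int) -> torch.Tensor:
    # Return this rank's contiguous slice of the sequence dimension.
    # Assumes x is (batch, seq_len) with seq_len already padded to a
    # multiple of sp_world_size, as the training loop below enforces.
    shard_len = x.shape[1] // sp_world_size
    start = sp_rank * shard_len
    return x[:, start : start + shard_len]

# A length-8 sequence split across 4 ranks: rank 0 holds tokens 0-1, rank 3 holds 6-7.
seq = torch.arange(8).unsqueeze(0)
shards = [shard_sequence(seq, r, 4) for r in range(4)]
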
@@ -571,7 +578,13 @@ def load(self, path: str, map_location=None):
self.device = torch.device(self.local_rank)
deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout))

ds_config = get_train_ds_config(offload=False, adam_offload=False, stage=args.deepspeed_stage, bf16=True)
ds_config = get_train_ds_config(
offload=False,
adam_offload=False,
stage=args.deepspeed_stage,
bf16=True,
sequence_parallel_size=args.sequence_parallel_size,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
ds_config["gradient_accumulation_steps"] = 1
# @vwxyzjn: MAGIC: it's actually needed to initialize this `dschf`, so
@@ -584,6 +597,18 @@ def load(self, path: str, map_location=None):
dschf = None
logger.info(f"Deepspeed config: {dschf=}")

# set sequence parallel
self.mpu = None
if args.sequence_parallel_size > 1:
self.mpu = UlyssesSPAttentionHF.register_with_transformers(
model_name_or_path=model_config.model_name_or_path,
core_attn_implementation="flash_attention_2",
sequence_parallel_size=args.sequence_parallel_size,
max_length=args.max_token_length,
micro_batch_size=args.per_device_train_batch_size,
seq_length_is_variable=True,
)

self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
@@ -616,7 +641,8 @@ def load(self, path: str, map_location=None):
optimizer=self.optimizer,
config=ds_config,
lr_scheduler=scheduler,
dist_init_required=True,
dist_init_required=False,
mpu=self.mpu,
)
optimization_steps_done = 0
if args.checkpoint_state_dir:
@@ -626,13 +652,19 @@ def load(self, path: str, map_location=None):
f"Skipping loading checkpoint state from {args.checkpoint_state_dir} because it does not exist!"
)
else:
old_mpu = None
if self.mpu is not None:
old_mpu = self.mpu
self.model.mpu = None
path, states = self.model.load_checkpoint(
args.checkpoint_state_dir,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True,
load_module_only=False,
)
if old_mpu is not None:
self.model.mpu = old_mpu
if path is None:
raise ValueError(f"Failed to load checkpoint from {args.checkpoint_state_dir}")
optimization_steps_done = states["training_step"]
@@ -665,9 +697,24 @@ def load(self, path: str, map_location=None):
use_cache=False,
)
disable_dropout_in_model(self.ref_policy)
self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config)
self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config, mpu=self.mpu)
self.ref_policy.eval()
self.local_metrics = MetricsTracker(max_metrics=32, device=self.device)

if self.mpu is not None:
self.sp_group = groups._get_sequence_parallel_group()
self.sp_world_size = groups._get_sequence_parallel_world_size()
self.sp_rank = groups._get_sequence_parallel_rank()
self.splitter = UlyssesSPSplitter(
sp_rank=self.sp_rank,
sp_group=self.sp_group,
sp_world_size=self.sp_world_size,
device=self.device,
pad_token_id=self.tokenizer.pad_token_id,
)
else:
self.splitter = None

return optimization_steps_done

def forward(
@@ -829,6 +876,47 @@ def train(
# recalculate the "real" number of mini-batches
num_mini_batches = len(collated_query_responses) // accumulation_steps

if self.splitter is not None:
with Timer("✂️ Splitting batch for SP", noop=self.rank != 0):
batch = {
"input_ids": collated_query_responses,
"attention_mask": collated_attention_masks,
"position_ids": collated_position_ids,
"response_masks": collated_response_masks,
"tool_masks": collated_tool_masks,
"advantages": collated_advantages,
}
# Pad each item in the batch so its sequence length is divisible by sp_world_size;
# positions where the attention mask is 0 are never used anyway.
for i in range(len(batch["input_ids"])):
for k in batch.keys():
if torch.is_tensor(batch[k][i]):
seq_length = batch[k][i].shape[1]
if seq_length % self.sp_world_size != 0:
padding_length = self.sp_world_size - (seq_length % self.sp_world_size)
padding = torch.zeros(
(batch[k][i].shape[0], padding_length),
dtype=batch[k][i].dtype,
device=batch[k][i].device,
)
batch[k][i] = torch.cat((batch[k][i], padding), dim=1)
sharded_batches = self.splitter.split_batch(batch)

# Keep only this rank's (sp_rank) shard.
collated_query_responses = sharded_batches[self.sp_rank]["input_ids"]
collated_attention_masks = sharded_batches[self.sp_rank]["attention_mask"]
collated_position_ids = sharded_batches[self.sp_rank]["position_ids"]
collated_response_masks = sharded_batches[self.sp_rank]["response_masks"]
collated_tool_masks = sharded_batches[self.sp_rank]["tool_masks"]
collated_advantages = sharded_batches[self.sp_rank]["advantages"]

collated_query_responses = move_to_device(collated_query_responses, self.ref_policy.device)
collated_attention_masks = move_to_device(collated_attention_masks, self.ref_policy.device)
collated_position_ids = move_to_device(collated_position_ids, self.ref_policy.device)
collated_response_masks = move_to_device(collated_response_masks, self.ref_policy.device)
collated_tool_masks = move_to_device(collated_tool_masks, self.ref_policy.device)
collated_advantages = move_to_device(collated_advantages, self.ref_policy.device)

# Calculate the logprob of the reference policy
collated_ref_logprobs = []
with Timer("Inference Calculation", noop=self.rank != 0):
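The padding loop in the hunk above rounds every tensor in the minibatch up to the next multiple of sp_world_size so the per-rank shards come out equal-sized; padded positions carry a zero attention mask and so never reach the loss. A standalone sketch of that rounding, assuming (batch, seq_len) tensors (pad_to_multiple is a hypothetical helper, not code from this PR):

import torch

def pad_to_multiple(x: torch.Tensor, multiple: int, value: int = 0) -> torch.Tensor:
    # Right-pad the sequence dimension up to the next multiple of `multiple`.
    remainder = x.shape[1] % multiple
    if remainder == 0:
        return x
    pad = torch.full(
        (x.shape[0], multiple - remainder), value, dtype=x.dtype, device=x.device
    )
    return torch.cat((x, pad), dim=1)

x = torch.ones(2, 10, dtype=torch.long)
assert pad_to_multiple(x, 4).shape == (2, 12)  # 10 rounds up to 12
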
@@ -958,19 +1046,87 @@ def train(
elif args.kl_estimator == "kl4":
kl = kl4

# grpo change: directly subtract KL in loss (add)
loss = masked_mean(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis)
if args.sequence_parallel_size == 1:
# grpo change: directly subtract KL in loss (add)
loss = masked_mean(
pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis
)
else:
# SP: gather loss sums from all ranks, divide by total valid tokens
local_loss_sum = ((pg_loss_max + args.beta * kl) * mb_response_masks_bool.float()).sum()
good_tokens = mb_response_masks_bool.sum()
loss_sums_per_rank = torch.distributed.nn.functional.all_gather(
local_loss_sum, group=self.sp_group
)
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(
good_tokens, group=self.sp_group
)
total_loss_sum = sum(loss_sums_per_rank)
total_good_tokens = sum(good_tokens_per_rank)

loss = (
total_loss_sum / total_good_tokens
if total_good_tokens > 0
else torch.tensor(0.0, device=local_loss_sum.device)
)
loss = loss / accumulation_steps
self.model.backward(loss)
if (local_step + 1) % accumulation_steps == 0:
self.model.step()
local_step += 1
with torch.no_grad():
# NOTE: in the packed implementation, kl calculations are averages over response tokens
kl1_stats[i] = masked_mean(kl1, mb_response_masks_bool, args.masked_mean_axis).float()
kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, args.masked_mean_axis).float()
kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, args.masked_mean_axis).float()
kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, args.masked_mean_axis).float()
if args.sequence_parallel_size == 1:
# NOTE: in the packed implementation, kl calculations are averages over response tokens
kl1_stats[i] = masked_mean(kl1, mb_response_masks_bool, args.masked_mean_axis).float()
kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, args.masked_mean_axis).float()
kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, args.masked_mean_axis).float()
kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, args.masked_mean_axis).float()
pg_clipfrac_stats[i] = masked_mean(
(pg_losses2 > pg_losses).float(), mb_response_masks_bool, args.masked_mean_axis
)
pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, args.masked_mean_axis)
loss_stats[i] = loss
ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, args.masked_mean_axis)
if args.record_entropy:
# Calculate entropy statistics
entropy_stats[i] = masked_mean(
mb_entropy, mb_response_masks_bool, args.masked_mean_axis
).float()
else:
# Gather across ranks, as for the main loss: because we pad the whole
# minibatch out to the max length to make Ulysses work, a microbatch
# can end up as all padding on one rank.
def gather_mean_stats(stats_tensor, mask_tensor):
local_stats_sum = (stats_tensor * mask_tensor.float()).sum()
good_tokens = mask_tensor.sum()
loss_sums_per_rank = torch.distributed.nn.functional.all_gather(
local_stats_sum, group=self.sp_group
)
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(
good_tokens, group=self.sp_group
)
total_stats_sum = sum(loss_sums_per_rank)
total_good_tokens = sum(good_tokens_per_rank)
if total_good_tokens > 0:
return total_stats_sum / total_good_tokens
else:
return torch.tensor(0.0, device=local_stats_sum.device)

kl1_stats[i] = gather_mean_stats(kl1, mb_response_masks_bool)
kl2_stats[i] = gather_mean_stats(kl2, mb_response_masks_bool)
kl3_stats[i] = gather_mean_stats(kl3, mb_response_masks_bool)
kl4_stats[i] = gather_mean_stats(kl4, mb_response_masks_bool)
pg_clipfrac_stats[i] = gather_mean_stats(
(pg_losses2 > pg_losses).float(), mb_response_masks_bool
)
pg_loss_stats[i] = gather_mean_stats(pg_loss_max, mb_response_masks_bool)
loss_stats[i] = gather_mean_stats(loss, mb_response_masks_bool)
ratio_stats[i] = gather_mean_stats(ratio, mb_response_masks_bool)
if args.record_entropy:
entropy_stats[i] = gather_mean_stats(mb_entropy, mb_response_masks_bool)
# multiply by beta
if args.kl_estimator == "kl1":
kl_loss_stats[i] = kl1_stats[i] * args.beta
elif args.kl_estimator == "kl2":
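In the sequence-parallel branch above, the local masked_mean is replaced by a cross-rank reduction: each rank gathers every rank's masked loss sum and valid-token count with the autograd-aware all_gather, then divides the totals, so gradients still flow and a rank whose shard is all padding cannot skew the average. Roughly, as a standalone helper (global_masked_mean is a hypothetical name mirroring the in-line gather_mean_stats; it assumes torch.distributed is already initialized):

import torch
import torch.distributed.nn.functional as dist_fn

def global_masked_mean(values: torch.Tensor, mask: torch.Tensor, group=None) -> torch.Tensor:
    # Token-weighted mean across all ranks in `group`. Summing masked sums and
    # token counts before dividing weights every valid token equally, no matter
    # how the tokens are distributed across ranks.
    local_sum = (values * mask.float()).sum()
    local_count = mask.sum()
    sums = dist_fn.all_gather(local_sum, group=group)
    counts = dist_fn.all_gather(local_count, group=group)
    total_sum, total_count = sum(sums), sum(counts)
    if total_count > 0:
        return total_sum / total_count
    return torch.tensor(0.0, device=values.device)
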
@@ -979,17 +1135,6 @@ def train(
kl_loss_stats[i] = kl3_stats[i] * args.beta
elif args.kl_estimator == "kl4":
kl_loss_stats[i] = kl4_stats[i] * args.beta
pg_clipfrac_stats[i] = masked_mean(
(pg_losses2 > pg_losses).float(), mb_response_masks_bool, args.masked_mean_axis
)
pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, args.masked_mean_axis)
loss_stats[i] = loss
ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, args.masked_mean_axis)
if args.record_entropy:
# Calculate entropy statistics
entropy_stats[i] = masked_mean(
mb_entropy, mb_response_masks_bool, args.masked_mean_axis
).float()

with torch.no_grad():
self.local_metrics.add("objective/kl_avg", kl1_stats.mean())
Expand All @@ -1009,6 +1154,11 @@ def train(

def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: Dict[str, str]) -> None:
args = self.args
# The mpu is only used for sequence parallelism, so we remove it before saving and re-add it afterwards.
old_mpu = None
if self.model.mpu is not None:
old_mpu = self.mpu
self.model.mpu = None
self.model.save_checkpoint(checkpoint_state_dir, client_state=client_state)
# `save_checkpoint` needs to be called on all ranks, only rank 0 will have all the states
if self.rank == 0:
@@ -1019,6 +1169,9 @@ def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: Dict[st
ray.remote(sync_gs_bucket).options(num_cpus=1).remote(
checkpoint_state_dir, args.gs_checkpoint_state_dir
)
# add back the mpu
if old_mpu is not None:
self.model.mpu = old_mpu

def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None:
model_to_save = self.model