feat: support for mcore optimizer (to enable MoE) #380

Draft: wants to merge 4 commits into base: main
2 changes: 2 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -93,6 +93,8 @@ jobs:
- dpo-llama3
- sft-llama3
- rm-llama3
- dpo-mixtral-ep
- dpo-mixtral-peft-tp-sp
with:
RUNNER: self-hosted-azure
# Fairly aggressive timeout that all functional tests should try to adhere to
6 changes: 5 additions & 1 deletion Dockerfile
@@ -130,16 +130,20 @@ git fetch -a
# 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652
# 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863
# (superseded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654
# ba8edbd2063f3349c40c9c73e5bae46abbe65f94: fix: regular torch optims (e.g., sgd) no longer error with closure spec NeMo#11189
# 35a7f718237cf011215db9e92273ed7236d0e8b1: Fix for crash with LoRA + tp_overlap_comm=false + sequence_parallel=true NeMo#10920
for pr_and_commit in \
"10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \
"10652 60e677423667c029dd05875da72bf0719774f844" \
"10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \
"11189 ba8edbd2063f3349c40c9c73e5bae46abbe65f94" \
"10920 53cf6527571b29379188c8bb0dba8e507db3cca1" \
; do
pr=$(cut -f1 -d' ' <<<"$pr_and_commit")
head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit")
git fetch origin $head_pr_commit:PR-${pr}
# cherry-picks all commits between main and the top of the PR
git cherry-pick --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
git cherry-pick -m 1 --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
# Tag cherry-picks to help
git tag cherry-pick-PR-${pr}
done
4 changes: 4 additions & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# dpo specific args
dpo:
@@ -17,6 +18,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
@@ -57,6 +59,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
variable_seq_lengths: ${not:model.data.pad_length_to_multiple_of}

dpo:
# This default value ensures there are no numeric differences between trained and reference policies when computing log probs.
@@ -114,6 +117,7 @@ model:
data_impl: jsonl
splits_string: null
seq_length: ${model.encoder_seq_length}
pad_length_to_multiple_of: null # Use if sequence_parallel is enabled to ensure seq_length is divisible by the ...
skip_warmup: True
num_workers: 0
reset_position_ids: False # Reset position ids after end-of-document token
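Note on the two config additions above: `variable_seq_lengths` is derived from `model.data.pad_length_to_multiple_of` through the custom `not` resolver this PR registers in `train_gpt_dpo.py`, so requesting padding to a fixed multiple automatically turns variable sequence lengths off. A minimal sketch of the resolver's truthiness semantics, with the argument inlined as a literal (in the real config the interpolation points at the `pad_length_to_multiple_of` entry):

```python
from omegaconf import OmegaConf

# Resolver registered in train_gpt_dpo.py: plain truthiness negation.
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

# pad_length_to_multiple_of left as null -> sequences keep their natural lengths
cfg = OmegaConf.create({"variable_seq_lengths": "${not:null}"})
print(cfg.variable_seq_lengths)  # True

# pad_length_to_multiple_of set (e.g. 64) -> fixed, padded sequence lengths
cfg = OmegaConf.create({"variable_seq_lengths": "${not:64}"})
print(cfg.variable_seq_lengths)  # False
```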
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_kto.yaml
@@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# kto specific args
kto:
@@ -17,6 +18,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -7,6 +7,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

ppo:
# How many steps we train warmup the critic for (without training the policy)
@@ -21,6 +22,7 @@
max_steps: -1 # max PPO steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# PPO args to generate the data for training
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_critic.yaml
@@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

ppo:
port: 5556
@@ -15,6 +16,7 @@

# used to set the learning rate scheduler
max_steps: 10000
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# a PyTriton parameter to specify
4 changes: 3 additions & 1 deletion examples/nlp/gpt/conf/gpt_rs_actor.yaml
@@ -7,12 +7,14 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

rs:
max_epochs: 1
max_steps: -1 # max rs steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# pick up from the model
@@ -177,4 +179,4 @@ model:
# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
data_prefix: True
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_sft.yaml
@@ -5,6 +5,7 @@ trainer:
devices: 1
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

sft:
max_epochs: 1
@@ -15,6 +16,7 @@
limit_train_batches: 1.0

limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# can be used to register any custom metrics that require token-by-token generation
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_spin.yaml
@@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16-mixed
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# spin specific args
spin:
@@ -18,6 +19,7 @@

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/training_rm.yaml
@@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# rm specific args
rm:
@@ -20,6 +21,7 @@
# set to float for a percentage
# of the validation dataset
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
29 changes: 13 additions & 16 deletions examples/nlp/gpt/train_gpt_dpo.py
@@ -20,7 +20,7 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
@@ -37,6 +37,7 @@

OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

mp.set_start_method("spawn", force=True)

Expand Down Expand Up @@ -85,7 +86,7 @@ def main(cfg) -> None:
# use the entire dataset
train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3

train_ds, validation_ds, test_ds = build_train_valid_test_dpo_datasets(
train_ds, validation_ds, _ = build_train_valid_test_dpo_datasets(
cfg=cfg.model,
data_prefix=cfg.model.data.data_prefix,
data_impl=cfg.model.data.data_impl,
@@ -104,13 +105,7 @@ def main(cfg) -> None:
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
collate_fn=partial(
dpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
),
collate_fn=identity_collate,
)

val_dataloader = build_dataloader(
@@ -121,13 +116,7 @@ def main(cfg) -> None:
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
collate_fn=partial(
dpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
),
collate_fn=identity_collate,
use_random_sampler=False,
)

@@ -147,6 +136,14 @@ def main(cfg) -> None:
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=None,
collate_fn=partial(
dpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
pad_length_to_multiple_of=cfg.model.data.get("pad_length_to_multiple_of", None),
),
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
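Note on the collate changes in this file: the torch `DataLoader` now hands back the raw list of samples via `identity_collate`, and `dpo_custom_collate` (which may issue a data-parallel all_reduce when `pad_length_to_multiple_of` is set) runs later on the main process, inside `DPOTrainer.augment_dataloader`. A minimal sketch of that split, assuming `identity_collate` simply returns the batch unchanged (its body is not shown in this diff) and using a simplified stand-in for the real collate:

```python
import torch
from torch.utils.data import DataLoader


def identity_collate(batch):
    # Assumed behaviour: return the list of per-sample dicts untouched so the
    # DataLoader (and its worker processes) never has to run collectives.
    return batch


def dpo_like_collate(batch, eos_id=0):
    # Simplified stand-in for dpo_custom_collate: pad and stack on the main
    # process, where torch.distributed collectives are safe to call.
    tokens = [item["chosen"] for item in batch]
    return {"chosen": torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=eos_id)}


dataset = [{"chosen": torch.arange(n)} for n in (3, 5, 2, 4)]
loader = DataLoader(dataset, batch_size=4, collate_fn=identity_collate)

for raw_batch in loader:                 # list[dict], exactly as the dataset yields them
    batch = dpo_like_collate(raw_batch)  # mirrors `batch = self.collate_fn(batch)` in the trainer
    print(batch["chosen"].shape)         # torch.Size([4, 5])
```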
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/critic_server_trainer.py
@@ -322,7 +322,7 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None:
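Note on the `step(closure=None)` change above (and the matching ones in `dpo.py` and `ppo.py` below): spelling out the closure keyword keeps one call form for both optimizer paths; the cherry-picked NeMo#11189 fix is what lets regular torch optimizers accept it without erroring, and the Megatron Core optimizer is assumed here to take the same keyword. A quick sanity sketch against plain PyTorch only:

```python
import torch

# Stock torch.optim optimizers already define step(closure=None), so passing
# the keyword explicitly is a no-op relative to step().
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).sum()
loss.backward()

opt.step(closure=None)  # equivalent to opt.step() here
opt.zero_grad()
```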
59 changes: 54 additions & 5 deletions nemo_aligner/algorithms/dpo.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import defaultdict
from statistics import mean
from typing import Any, Protocol

import torch
import torch.distributed
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm

@@ -24,13 +27,34 @@
)
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.utils import logging
from nemo_aligner.utils import parallel_state
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory


def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
class DistributedCollateFunction(Protocol):
def __call__(self, batch: list[dict], **kwargs: Any) -> dict[str, torch.Tensor]:
...


def dpo_custom_collate(
batch: list[dict],
eos_id: int,
reset_position_ids: bool = False,
reset_attention_mask: bool = False,
eod_mask_loss: bool = False,
pad_length_to_multiple_of: int | None = None,
) -> dict[str, torch.Tensor]:
"""
Transposes minibatch from list[dict] -> dict[Tensor] and also pads

This collate happens outside of the torch data loader and is not compatible with the multiprocessing
logic due to requiring communication collectives.
"""
if pad_length_to_multiple_of is not None and pad_length_to_multiple_of < 0:
raise ValueError(f"{pad_length_to_multiple_of=} must be >= 0")
chosen_tokens = [item["chosen"] for item in batch]
rejected_tokens = [item["rejected"] for item in batch]
chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch])
@@ -44,9 +68,32 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_
rejected_tokens = torch.nn.utils.rnn.pad_sequence(rejected_tokens, batch_first=True, padding_value=eos_id)
chosen_labels = torch.nn.utils.rnn.pad_sequence(chosen_labels, batch_first=True, padding_value=-100)
rejected_labels = torch.nn.utils.rnn.pad_sequence(rejected_labels, batch_first=True, padding_value=-100)
assert chosen_tokens.shape == rejected_tokens.shape
assert chosen_labels.shape == rejected_labels.shape

if pad_length_to_multiple_of:
# Assumes both chosen and rejected match
max_seq_len = torch.tensor(chosen_tokens.shape[1], device=torch.cuda.current_device())
torch.distributed.all_reduce(
max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group()
)

padded_max_len = math.ceil(max_seq_len / pad_length_to_multiple_of) * pad_length_to_multiple_of
chosen_tokens = torch.nn.functional.pad(
chosen_tokens, (0, padded_max_len - chosen_tokens.shape[1]), mode="constant", value=eos_id
)
rejected_tokens = torch.nn.functional.pad(
rejected_tokens, (0, padded_max_len - rejected_tokens.shape[1]), mode="constant", value=eos_id
)
chosen_labels = torch.nn.functional.pad(
chosen_labels, (0, padded_max_len - chosen_labels.shape[1]), mode="constant", value=-100
)
rejected_labels = torch.nn.functional.pad(
rejected_labels, (0, padded_max_len - rejected_labels.shape[1]), mode="constant", value=-100
)

attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
chosen_tokens, eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
chosen_tokens.cuda(), eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
)
assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate"
if attention_mask.shape[0] == 1:
@@ -70,8 +117,7 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_


class DPOTrainer:
"""Trainer to coordinate DPO training
"""
"""Trainer to coordinate DPO training"""

def __init__(
self,
@@ -82,6 +128,7 @@ def __init__(
train_dataloader,
val_dataloader,
test_dataloader,
collate_fn: DistributedCollateFunction,
logger,
ckpt_callback,
run_timer,
@@ -90,6 +137,7 @@ def __init__(
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.test_dataloader = test_dataloader
self.collate_fn = collate_fn
self.logger = logger
self.cfg = cfg
self.optimizer = optimizer
@@ -172,7 +220,7 @@ def train_single_step(self, global_batch):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

trainer_metrics = {}
@@ -317,6 +365,7 @@ def augment_dataloader(self, dataloader):
while True:
try:
batch = next(iter_dataloader)
batch = self.collate_fn(batch)
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
batch["ref_policy_log_probs_chosen"] = chosen_logps
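Note on the `pad_length_to_multiple_of` block added to `dpo_custom_collate` above: every data-parallel rank first agrees on the global maximum sequence length via an all_reduce, then rounds it up to the requested multiple so all ranks emit identically shaped batches. A single-process sketch of just the rounding-and-padding arithmetic, with the collective stubbed out:

```python
import math

import torch
import torch.nn.functional as F


def pad_to_multiple(tokens: torch.Tensor, multiple: int, pad_value: int) -> torch.Tensor:
    # In the real collate, max_seq_len is all_reduce(MAX)'d across the
    # data-parallel group before rounding; here it is just the local length.
    max_seq_len = tokens.shape[1]
    padded_max_len = math.ceil(max_seq_len / multiple) * multiple
    return F.pad(tokens, (0, padded_max_len - tokens.shape[1]), mode="constant", value=pad_value)


chosen = torch.full((2, 37), 7)                            # batch of 2, length 37
padded = pad_to_multiple(chosen, multiple=16, pad_value=0)
print(padded.shape)                                        # torch.Size([2, 48]); 37 rounded up to 48
```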
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/ppo.py
@@ -440,7 +440,7 @@ def run_training(self, dataloader_iter):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None: