diff --git a/CHANGELOG.md b/CHANGELOG.md index 564f0593d..97105e325 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### New Features and Optimizations - Implement Kahneman-Tversky Optimization (KTO). - Sequence packing is now supported when running SFT with SFTChatDataset. +- Implement REINFORCE algorithm. ### Breaking Changes diff --git a/README.md b/README.md index 7c1452ad3..e993be066 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ The toolkit is currently in it's early stages. We are committed to improving the * **Reward Model Training** * **Reinforcement Learning from Human Feedback using the [PPO](https://arxiv.org/pdf/1707.06347.pdf) Algorithm** * [Llama3-70B-PPO-Chat](https://huggingface.co/nvidia/Llama3-70B-PPO-Chat) aligned with NeMo-Aligner using TRT-LLM. +* **Reinforcement Learning from Human Feedback using the REINFORCE Algorithm** + * [Llama-3.1-Nemotron-70B-Instruct](https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct) aligned with NeMo-Aligner using TRT-LLM. * **Direct Preference Optimization** as described in [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290) * [Llama3-70B-DPO-Chat](https://huggingface.co/nvidia/Llama3-70B-DPO-Chat) aligned with NeMo Aligner. * **Self-Play Fine-Tuning (SPIN)** as described in [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/pdf/2401.01335) diff --git a/docs/user-guide/reinforce.rst b/docs/user-guide/reinforce.rst new file mode 100644 index 000000000..ff0fde22d --- /dev/null +++ b/docs/user-guide/reinforce.rst @@ -0,0 +1,255 @@ +.. include:: /content/nemo.rsts + +.. _model-aligner-reinforce: + +Model Alignment by REINFORCE +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + +In this tutorial, we will guide you through the process of aligning a NeMo Framework model using REINFORCE. This method can be applied to various models, including LLaMa2 and Mistral, with our scripts functioning consistently across different models. + +REINFORCE is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the REINFORCE algorithm on the `Anthropic-HH-RLHF `__ dataset. + +REINFORCE Training +############ + +After you have fine-tuned a GPT model using Supervised Fine-Tuning (SFT), and trained a reward model as explained in the preceding section, you can start aligning the policy using REINFORCE. + +During REINFORCE training, we have three models interacting with each other, which Aligner runs in two separate jobs: + +#. The Policy Network: This is the model we are training and it should start from an SFT model. +#. The Reward Model (RM): This model accepts a prompt combined with a response as input and produces a single scalar value, known as the reward. The REINFORCE algorithm aims to maximize this reward. +#. The Initial Policy Network (also known as the Reference Model): We use this model to compute a KL Divergence penalty term that ensures that the PPO Actor does not diverge too much from the Initial Policy. This way, we prevent the REINFORCE Actor from overfitting to the rewards given by the RM, and ensure it does not forget the knowledge it acquired during pretraining and SFT. This model should be the one used to initialize the REINFORCE Actor Network. + +The next section discusses how to launch each of these two jobs. + +Launching the Reward Model and Critic Server +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +To launch the server: + +.. code-block:: bash + + #!/bin/bash + CHECKPOINT_NEMO_FILE="/path/to/trained_rm.nemo" + GPFS="/path/to/nemo-aligner-repo" + + RESULTS_DIR="critic_results_dir" + + cd ${GPFS} + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/serve_reward_model.py \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + ++model.tensor_model_parallel_size=4 \ + rm_model_file=${RM_NEMO_FILE} + + +The above example launches the reward model server on 8 GPUs and 1 node. Please make sure to change trainer.devices, trainer.num_nodes depending on your model size and scale. Aligner will work on any scale. Also, make sure to tune the trainer.reinforce.inference_micro_batch_size argument. This argument sets the size of the batch the REINFORCE actor is allowed to send to the reward per DP rank. + +Launch the Initial Policy and REINFORCE Actor Training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +The REINFORCE Actor training job contains the master controller that makes the HTTP calls to all servers when needed. To launch the REINFORCE Actor and Initial Policy server: + +.. code-block:: bash + + GPFS="/path/to/nemo-aligner-repo" + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + PRETRAINED_ACTOR_NEMO_FILE="/path/to/sft_checkpoint.nemo" + RESULTS_DIR="/path/to/actor_results_dir" + + USE_FLASK=False + ACTOR_LR=1e-6 + KL=0.01 + NUM_ROLLOUTS=32 + ACTOR_GBS=32 + REWARD_PORT=5555 + host_reward="$(scontrol show hostnames=$SLURM_JOB_NODELIST_HET_GROUP_0 | head -n1)" + + cd ${GPFS} + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/train_gpt_reinforce_actor.py \ + "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ + pretrained_checkpoint.restore_from_path=\"${ACTOR_NEMO_FILE}\" \ + exp_manager.checkpoint_callback_params.save_top_k=1 \ + exp_manager.explicit_log_dir=\"${RESULTS_DIR}\" \ + trainer.reinforce.max_epochs=1 \ + trainer.reinforce.max_steps=313 \ + trainer.reinforce.val_check_interval=4 \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + trainer.reinforce.trt_llm.enable=True \ + trainer.reinforce.trt_llm.reshard=True \ + trainer.reinforce.trt_llm.unload_engine_train=False \ + ++model.tensor_model_parallel_size=4 \ + ++model.reinforce.num_rollout_samples=${NUM_ROLLOUTS} \ + model.global_batch_size=${ACTOR_GBS} \ + model.micro_batch_size=1 \ + model.optim.lr=\"\\\$\{multiply:${ACTOR_LR},1.001\}\" \ + model.optim.sched.warmup_steps=0 \ + model.optim.sched.constant_steps=312 \ + model.optim.sched.min_lr=${ACTOR_LR} \ + model.optim.weight_decay=0.01 \ + model.reinforce.rollout_micro_batch_size=16 \ + model.reinforce.forward_micro_batch_size=16 \ + model.reinforce.val_rollout_micro_batch_size=8 \ + model.data.data_impl=jsonl \ + remote_rm.reward_model.ip=${host_reward} \ + remote_rm.reward_model.port=${REWARD_PORT} \ + ++model.reinforce.length_params.max_length=2048 \ + trainer.reinforce.initial_policy_kl_penalty="${KL}" \ + ++model.optim.bucket_cap_mb=200 \ + ++model.dist_ckpt_format=zarr \ + ++model.optim.overlap_grad_sync=False \ + ++model.optim.contiguous_grad_buffer=True \ + ++model.enable_nge=True \ + trainer.reinforce.batch_iterator.use_flask=${USE_FLASK} \ + trainer.reinforce.rollout_batch_seq_length=4096 + +The above command launches the initial and actor server on 1 node with 8 GPUs. + +Launching Both Servers for REINFORCE training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +You can use slurm to launch the 2 jobs and get them to coordinate together in a full REINFORCE job via the following: + +.. code-block:: bash + + #!/bin/bash + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + #SBATCH hetjob + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + + NAME="2p_reinforce" + + # PARAMETERS + RM_NEMO_FILE="/path/to/trained_rm.nemo" + + ACTOR_NEMO_FILE="/path/to/sft_model.nemo" + + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + RESULTS_DIR="/path/to/results_dir" + mkdir -p $RESULTS_DIR + + GPFS="/path/to/nemo-aligner-repo" + MOUNTS="--container-mounts=MOUNTS" # mounts + + CONTAINER=<<>> # use the latest NeMo Training container, Aligner will work there + + PROJECT=reinforce_run + + CRITIC_LOG_DIR="${RESULTS_DIR}/critic_results" + CRITIC_OUTFILE="${CRITIC_LOG_DIR}/critic_output_%j_%t.log" + CRITIC_ERRFILE="${CRITIC_LOG_DIR}/critic_error_%j_%t.err" + REWARD_PORT=5567 + CRITIC_CONFIG_PATH="${GPFS}/examples/nlp/gpt/conf" + CRITIC_CONFIG_NAME="inference_rm" + + CONF_DIR="${GPFS}/examples/nlp/gpt/conf" + CONFIG_NAME="gpt_reinforce_actor" + + mkdir -p $CRITIC_LOG_DIR + + CRITIC_NAME="${NAME}_critic" + + read -r -d '' cmd_critic_inference <`__ script from the NeMo codebase to run more rigorous evaluation of your trained model. diff --git a/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml new file mode 100644 index 000000000..b2b946f1d --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml @@ -0,0 +1,218 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + # these args are respected + num_nodes: 8 + devices: 8 + accelerator: gpu + precision: bf16 + + reinforce: + + max_epochs: 1 + max_steps: -1 # max REINFORCE steps (-1 to go through the whole train set) + val_check_interval: 10 + save_interval: ${.val_check_interval} + gradient_clip_val: 1.0 + + # REINFORCE args to generate the data for training + initial_policy_kl_penalty: 0.01 + use_absolute_kl: True + num_rollouts_per_prompt: 4 + + + # the sequence length to pad the rollout batch for training to + # reduce fragmentation at the cost of using more + # memory, set to null if we don't want to pad it + # to a constant size + # if actual seq length is higher than this a warning will be raised + # but will not crash and training will still proceed on the larger + # sequence length + rollout_batch_seq_length: null + + # Speed-up training by accelerating inference stage using TRTLLM + trt_llm: + enable: False + reshard: False # if True then reshard the model into TP only for inference + + # TRTLLM preallocates activation memory according to the number of input tokens + # By default, assume the max input length is the difference between the model sequence length and the max number of tokens to generate + max_input_len: ${subtract:${model.encoder_seq_length}, ${model.reinforce.length_params.max_length}} + max_input_tokens: ${multiply:${.max_input_len}, ${model.reinforce.rollout_micro_batch_size}} + + # the seed to use for trt-llm generation + seed: ${model.seed} + + # for supported values see: https://github.com/NVIDIA/NeMo/blob/db6244857af3b012f645c7f4672522978bb608b1/nemo/export/trt_llm/converter/utils.py#L26 + model_type: llama # can be gptj, gptnext, llama, gemma, falcon + + # Save GPU memory by unloading and reloading the TRTLLM engine before and after the training stage + # Reloading the engine incurs a constant time overhead + unload_engine_train: False + + batch_iterator: + # When use_flask is True, we will spawn a flask server on rank 0 to balance the work of policy rollouts. + # This option is useful in cases where the generation length varies greatly across DP ranks since + # the flask server will allow DP ranks with shorter responses to process more samples and DP ranks + # with longer responses to process less samples. Thereby lowering the DP wait time. + use_flask: False + port: 5557 + + # pick up from the model + # *do not change this* + model_gbs: ${model.global_batch_size} + model_mbs: ${model.micro_batch_size} + + # no need to change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.reinforce.max_epochs} + max_steps: ${.reinforce.max_steps} + +remote_rm: + # what to batch the inputs to + # set to None if no batching when sending inference to the reward model + pad_to_length: ${model.encoder_seq_length} + + # reward model server + reward_model: + name: reward_model + ip: localhost + port: 5555 + + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt_reinforce_actor + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_reinforce + name: gpt3_reinforce_2b + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_rewards + save_top_k: 1 + mode: max + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_gpt-{step}-{consumed_samples}-{reinforce_optimization_step}-{epoch}-{val_rewards:.3f}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +pretrained_checkpoint: + restore_from_path: null + +model: + + reinforce: + # training generation mbs + rollout_micro_batch_size: 8 + num_rollout_samples: 512 + + # mbs to do log prob inference, can be set to + # lower than rollout_micro_batch_size to reduce + # memory usage + forward_micro_batch_size: ${.rollout_micro_batch_size} + + # val generation mbs + val_rollout_micro_batch_size: ${.rollout_micro_batch_size} + num_val_samples: ${.num_rollout_samples} + + # to offload during generation or not + offload_adam_states: True + + # params for generation + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 0 + top_p: 1.0 + repetition_penalty: 1.0 + add_BOS: False + all_probs: False + compute_logprob: False + # will be used in NeMo version > 1.20.0 + # keeping it for now + end_strings: ["<|endoftext|>", ""] + + # length argument for autoregressive sampling + # max length means max amount of tokens to generate + length_params: + max_length: ${int_div:${model.encoder_seq_length}, 2} + min_length: 1 + + trt_llm: ${trainer.reinforce.trt_llm} + + #peft + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + mcore_gpt: True + # these control the mbs/gbs during REINFORCE training + micro_batch_size: 1 + global_batch_size: 64 + megatron_amp_O2: True + + encoder_seq_length: 4096 + max_position_embeddings: ${model.encoder_seq_length} + + ## Sequence Parallelism + sequence_parallel: False + + # miscellaneous + seed: 1234 + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-7 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-8 + + precision: ${trainer.precision} + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: null + + # define fields from the base model's config that should be ignored when merging with this config. + overwrite_base_config: + data: + data_prefix: True diff --git a/examples/nlp/gpt/train_gpt_reinforce_actor.py b/examples/nlp/gpt/train_gpt_reinforce_actor.py new file mode 100644 index 000000000..589dfca1f --- /dev/null +++ b/examples/nlp/gpt/train_gpt_reinforce_actor.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch +import torch.multiprocessing as mp +from megatron.core.utils import divide +from omegaconf.omegaconf import OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.reinforce import ReinforceTrainer +from nemo_aligner.data.nlp.builders import ( + build_dataloader, + build_train_valid_test_rlhf_datasets, + collate_with_pad_to_max_batch, +) +from nemo_aligner.models.nlp.gpt.megatron_gpt_reinforce_actor import MegatronGPTReinforceActorModel +from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMClient +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.batch_iterators import get_batch_iterator_cls +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + resolve_and_create_trainer, + retrieve_custom_trainer_state_dict, +) +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu + +"""Script to start REINFORCE training""" + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) +OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) +OmegaConf.register_new_resolver("subtract", lambda x, y: x - y, replace=True) + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="gpt_reinforce_actor") +def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + trainer = resolve_and_create_trainer(cfg, "reinforce") + + exp_manager(trainer, cfg.exp_manager) + + logger = CustomLoggerWrapper(trainer.loggers) + + ptl_model = load_from_nemo( + MegatronGPTReinforceActorModel, + cfg.model, + trainer, + strict=True, + restore_path=cfg.pretrained_checkpoint.restore_from_path, + ) + + init_peft(ptl_model, cfg.model) + + init_policy_state_dict = None + + # only need this if we are running with inital kl penalty & full-parameter tuning + if cfg.trainer.reinforce.initial_policy_kl_penalty > 0 and cfg.model.peft.peft_scheme == "none": + init_policy_state_dict = retrieve_model_state_dict_in_cpu( + ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False) + ) + + ptl_model.init_policy_state_dict = init_policy_state_dict + + # pull values from checkpoint + trainer_restore_path = trainer.ckpt_path + + # TODO: log this restore path + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + else: + custom_trainer_state_dict = None + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the entire dataset + train_valid_test_num_samples = [-1, -1, -1] + train_ds, validation_ds, _ = build_train_valid_test_rlhf_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl=cfg.model.data.data_impl, + splits_string=cfg.model.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.seed, + tokenizer=ptl_model.tokenizer, + ) + + max_seqlen = cfg.model.reinforce.length_params.max_length + eos_id = ptl_model.tokenizer.eos_id + + # collate fn to pad to the max seq length in the batch + collate_fn = collate_with_pad_to_max_batch(max_seqlen, eos_id, cfg, generate_masks_and_position_ids=False) + + train_dataloader_builder = partial( + build_dataloader, + cfg=cfg, + dataset=train_ds, + mbs=cfg.model.reinforce.rollout_micro_batch_size, + gbs=cfg.model.reinforce.num_rollout_samples, + collate_fn=collate_fn, + load_gbs=False, + ) + + val_dataloader_builder = partial( + build_dataloader, + cfg=cfg, + dataset=validation_ds, + mbs=cfg.model.reinforce.val_rollout_micro_batch_size, + gbs=cfg.model.reinforce.num_val_samples, + collate_fn=collate_fn, + load_gbs=False, + use_random_sampler=False, + ) + + # nemo uses the train dataloader to figure out + # max steps to take when max_steps = -1 + # but our train dataloader is for the prompts + # so we instaniate a dummy dataloader + # to get the proper max *optimization* steps + # nemo treats batch size of normal dataloader as GBS/DP + # so we need to offset it by DP + dummy_train_dataloader = torch.utils.data.DataLoader( + dataset=train_ds, batch_size=divide(cfg.model.global_batch_size, parallel_state.get_data_parallel_world_size()) + ) + + init_using_ptl(trainer, ptl_model, dummy_train_dataloader, train_ds) + # make sure the dummy train dataloader is never used + del ptl_model._train_dl + del dummy_train_dataloader + + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + rm = RemoteGPTRMClient(cfg.remote_rm) + timer = Timer(cfg.exp_manager.get("max_time_per_run")) + + batch_iterator_cfg = cfg.trainer.reinforce.get("batch_iterator", {}) + batch_iterator_cls = get_batch_iterator_cls(batch_iterator_cfg) + + reinforce_trainer = ReinforceTrainer( + cfg=cfg.trainer.reinforce, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader_builder=train_dataloader_builder, + val_dataloader_builder=val_dataloader_builder, + collate_fn=collate_fn, + rm=rm, + batch_iterator_cls=batch_iterator_cls, + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + ) + + if custom_trainer_state_dict is not None: + reinforce_trainer.load_state_dict(custom_trainer_state_dict) + + reinforce_trainer.fit() + + # Note: The main loop creates multiple HTTPCommunicators which own a + # pytriton.client.FuturesModelClient. At the end of the loop, we manually + # close all FuturesModelClients since we do not use the context manager + # syntax. This guarantees all dangling threads are no longer blocking. + # `atexit` does not suffice since the registered cleanup function can be + # queued behind another blocking atexit registered function. + # TODO: utilize context managers to avoid manual cleanup + rm.communicator.close() + + +if __name__ == "__main__": + main() diff --git a/nemo_aligner/algorithms/reinforce.py b/nemo_aligner/algorithms/reinforce.py new file mode 100644 index 000000000..3e4b8e2a8 --- /dev/null +++ b/nemo_aligner/algorithms/reinforce.py @@ -0,0 +1,612 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from collections import UserDict +from contextlib import nullcontext +from typing import Dict, List, Optional, Union + +import pandas as pd +import torch +from megatron.core import parallel_state as mcore_parallel_state +from megatron.core.utils import divide +from omegaconf.dictconfig import DictConfig +from tqdm import tqdm +from typing_extensions import Self + +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingRandomSampler +from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split +from nemo.utils import logging +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.distributed import ( + SyncTimer, + all_reduce_dict, + masked_global_mean_var, + normalize_tensor, + rebalance_nd_tensor, +) +from nemo_aligner.utils.parallel_state import is_trt_llm_reshard, trt_llm_reshard_region +from nemo_aligner.utils.ppo_utils import calculate_kl_penalty, calculate_rloo_baseline, create_mask +from nemo_aligner.utils.server_utils import FutureResult +from nemo_aligner.utils.train_utils import clip_gradients +from nemo_aligner.utils.trainer_utils import check_progress, compute_num_steps_per_epoch +from nemo_aligner.utils.utils import clear_memory, cpu_dict, masked_mean + + +class ReinforceRolloutBatch(UserDict): + @classmethod + def from_rollout_batches( + cls: Self, rollout_batches: List[Dict], eos_id: int, rollout_batch_seq_length: Optional[int] + ) -> Self: + """Given a list of rollout batches, stack the tensors within and put them in a single dictionary + """ + stacked_dict = cls() + + for k in sorted(rollout_batches[0]): + + list_of_tensors = [item[k] for item in rollout_batches] + + if all(x.ndim == 1 for x in list_of_tensors): + tensor = torch.cat(list_of_tensors) + else: + pad_value = eos_id if k == "response_tokens" else 0 + + list_of_tensors = [row.flatten() for tensor in list_of_tensors for row in tensor] + # TODO: can we avoid padding locally then padding globally? + tensor = torch.nn.utils.rnn.pad_sequence(list_of_tensors, batch_first=True, padding_value=pad_value) + + # find the max sequence length globally + max_seqlen = torch.tensor([tensor.size(-1)], dtype=torch.long, device=torch.cuda.current_device()) + torch.distributed.all_reduce(max_seqlen, op=torch.distributed.ReduceOp.MAX) + + if rollout_batch_seq_length is None or max_seqlen >= rollout_batch_seq_length: + pad_seq_len = max_seqlen.item() + else: + # response tokens must be B x S because computing log probs requires us to offset by 1 + pad_seq_len = rollout_batch_seq_length if k == "response_tokens" else rollout_batch_seq_length - 1 + + tensor = torch.nn.functional.pad(tensor, (0, pad_seq_len - tensor.size(-1)), value=pad_value) + + stacked_dict[k] = tensor + + return stacked_dict + + def gather_and_balance_globally(self): + global_rollout_batch = type(self)() + + for k, tensor in self.data.items(): + # with reshard enabled, PP groups turn into DP groups. So need to balance them first and then + # balance by all the original DP groups + # NOTE: this logic needs to use the pure parallel state, that is one without sharding but needs + # to ping the is_trt_llm_reshard variable + if is_trt_llm_reshard(): + tensor = rebalance_nd_tensor(tensor, group=mcore_parallel_state.get_pipeline_model_parallel_group()) + + tensor = rebalance_nd_tensor(tensor, group=mcore_parallel_state.get_data_parallel_group()) + global_rollout_batch[k] = tensor + + return global_rollout_batch + + def chunk(self, rank, split_size, seed): + chunked_rollout_batch = type(self)() + + batch_set = set(tensor.size(0) for tensor in self.data.values()) + assert len(batch_set) == 1, "batch sizes are not the same across the rollout batch" + B = batch_set.pop() + + g_cpu = torch.Generator() + g_cpu.manual_seed(seed) + indices = torch.arange(B) + + for k in self.data: + chunked_rollout_batch[k] = self.data[k][indices].clone() + + return chunked_rollout_batch + + +def compute_num_rollout_microbatches(dataloader): + return divide( + divide(dataloader.batch_sampler.global_batch_size, dataloader.batch_sampler.micro_batch_size), + parallel_state.get_data_parallel_world_size(), + ) + + +class ReinforceTrainer: + """Trainer to coordinate REINFORCE training + """ + + def __init__( + self, + cfg: DictConfig, + model, + optimizer, + scheduler, + train_dataloader_builder, + val_dataloader_builder, + collate_fn, + rm, + batch_iterator_cls, + logger, + ckpt_callback, + run_timer, + ): + self.cfg = cfg + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.train_dataloader_builder = train_dataloader_builder + self.val_dataloader_builder = val_dataloader_builder + self.collate_fn = collate_fn + self.rm = rm + self.batch_iterator_cls = batch_iterator_cls + self.logger = logger + self.ckpt_callback = ckpt_callback + + # this timer checks if we should stop training + self.run_timer = run_timer + + self.trtllm_reshard = "trt_llm" in cfg and cfg.trt_llm.enable and cfg.trt_llm.reshard + + self.consumed_samples = 0 + # the step here is REINFORCE step + self.step = 0 + # keep track of how many times we optimized the actor + self.reinforce_optimization_step = 0 + + # compute `max_steps` + train_dataloader = self.train_dataloader_builder(consumed_samples=0) + if (not isinstance(train_dataloader.batch_sampler, MegatronPretrainingRandomSampler)) and ( + self.cfg.max_epochs is not None and self.cfg.max_epochs > 1 + ): + # if you use MegatronPretrainingBatchSampler as the batch_sampler passed to your train dataloader (in builders.py) + # then each epoch will repeat all your samples in the same order as the previous epoch, there is no shuffling + # to fix this, you should use MegatronPretrainingRandomSampler instead, which alleviates this issue and allows + # random shuffling for each epoch. + raise ValueError( + "max_epochs > 1 is not supported unless using `MegatronPretrainingRandomSampler` as the batch_sampler for your train dataloader" + ) + + self.num_steps_per_epoch = compute_num_steps_per_epoch(train_dataloader.batch_sampler) + self.set_max_steps() + + self.compute_init_policy_kl = self.cfg.initial_policy_kl_penalty > 0 + # size to pad our rollout batch to + self.rollout_batch_seq_length = self.cfg.rollout_batch_seq_length + + # for wandb table + self.train_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + self.val_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + + self.timer = SyncTimer( + reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX + ) + + def generate_reinforce_data(self, rollout_batch): + """generate reinforce specific data for training + """ + reinforce_rollout_data = {} + reinforce_rollout_metrics = {} + + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + prompt_tokens = rollout_batch["prompt_tokens"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + logprobs = rollout_batch["logprobs"] + is_end = rollout_batch["is_end"] + + if self.compute_init_policy_kl: + init_policy_kl = calculate_kl_penalty( + log_probs_a=rollout_batch["logprobs"], + log_probs_b=rollout_batch["init_logprobs"], + use_absolute_kl=self.cfg.use_absolute_kl, + ) + else: + init_policy_kl = torch.tensor(0, dtype=logprobs.dtype, device=logprobs.device) + + mask = create_mask(values=logprobs, prompt_lengths=prompt_lengths, response_lengths=response_lengths) + + init_policy_kl = masked_mean(init_policy_kl, mask, dim=-1) + rewards_with_kl = rewards - self.cfg.initial_policy_kl_penalty * init_policy_kl + + baseline = calculate_rloo_baseline(prompts=prompt_tokens, reward=rewards_with_kl, mask=is_end.float()) + + # collect everything we need to train REINFORCE + reinforce_rollout_data["mask"] = mask + reinforce_rollout_data["rewards_with_kl"] = rewards_with_kl + reinforce_rollout_data["baseline"] = baseline + reinforce_rollout_data["response_tokens"] = response_tokens + reinforce_rollout_data["is_end"] = is_end + + # compute metrics + # these are not global yet + reinforce_rollout_metrics["init_policy_kl"] = init_policy_kl.sum().item() if self.compute_init_policy_kl else 0 + reinforce_rollout_metrics["rewards_with_kl"] = rewards_with_kl.sum().item() + reinforce_rollout_metrics["num_samples"] = prompt_lengths.size(0) + + # now the metrics are global + reinforce_rollout_metrics = all_reduce_dict( + reinforce_rollout_metrics, + group=parallel_state.get_data_parallel_group(), + op=torch.distributed.ReduceOp.SUM, + ) + num_samples = reinforce_rollout_metrics.pop("num_samples") + reinforce_rollout_metrics = {k: v / num_samples for k, v in reinforce_rollout_metrics.items()} + + return reinforce_rollout_data, cpu_dict(reinforce_rollout_metrics) + + def _run_inference(self, dataloader_builder, consumed_samples, is_validation): + """this function is run per DP so the metrics need to be computed globally + assumes that the dataloader is built with the proper consumed samples value + """ + reshard_context = trt_llm_reshard_region if self.trtllm_reshard else nullcontext + + rollout_batches, futures = [], [] + timer_metrics = {} + + with reshard_context(): + # dataloader must be built within the reshard context because it uses DP rank and size + dataloader = dataloader_builder(consumed_samples=consumed_samples) + sampler_iter = iter(dataloader.batch_sampler) + + # must compute the number of microbatches in the reshard context + # so the DP groups are correct + num_microbatches = compute_num_rollout_microbatches(dataloader) + + self.timer.start("batch_iterator_init") + batch_iterator = self.batch_iterator_cls( + sampler_iter, num_microbatches, dataloader.dataset, self.collate_fn + ) + timer_metrics["batch_iterator_init"] = self.timer.stop_and_get_time("batch_iterator_init") + + self.timer.start("generate") + for batch in batch_iterator: + if not is_validation: + for _ in range(self.cfg.num_rollouts_per_prompt): + rollout_batch = self.model.infer(batch) + rollout_batch["prompt_tokens"] = batch["text"] + rollout_batches.append(rollout_batch) + futures.append(self.rm.infer_rm(rollout_batch)) + else: + rollout_batch = self.model.infer(batch) + rollout_batches.append(rollout_batch) + futures.append(self.rm.infer_rm(rollout_batch)) + + timer_metrics["generate"] = self.timer.stop_and_get_time("generate") + + unbalanced_local_batch = ReinforceRolloutBatch.from_rollout_batches( + rollout_batches, + eos_id=self.model.tokenizer.eos_id, + rollout_batch_seq_length=self.cfg.rollout_batch_seq_length, + ) + global_rollout_batch = unbalanced_local_batch.gather_and_balance_globally() + + padded_rollout_sequence_length = global_rollout_batch["response_tokens"].size(-1) + + # the chunking must be outside of the TRT-LLM context because we do logprob calculation in nemo + balanced_local_batch = global_rollout_batch.chunk( + rank=parallel_state.get_data_parallel_rank(), + split_size=parallel_state.get_data_parallel_world_size(), + seed=self.step, + ) + # since we compute the logprobs in nemo we need to disable the resharding + batched_response_tokens = balanced_local_batch["response_tokens"] + + self.timer.start("logprobs") + rollout_logprobs = self.model.get_inference_log_probs(batched_response_tokens) + balanced_local_batch["logprobs"] = rollout_logprobs + timer_metrics["logprobs"] = self.timer.stop_and_get_time("logprobs") + + compute_init_policy_kl = not is_validation and self.compute_init_policy_kl + if compute_init_policy_kl: + self.timer.start("init_logprobs") + rollout_init_logprobs = self.model.get_init_policy_logprobs(batched_response_tokens) + balanced_local_batch["init_logprobs"] = rollout_init_logprobs + timer_metrics["init_logprobs"] = self.timer.stop_and_get_time("init_logprobs") + + # we send the request in sharded context, so we need to keep this sharding and then undo it + with reshard_context(): + self.timer.start("rm_wait") + rm_rollout_batches = [] + for future in futures: + rewards = future.result().squeeze(1) + rm_rollout_batches.append({"rewards": rewards}) + timer_metrics["rm_wait"] = self.timer.stop_and_get_time("rm_wait") + + unbalanced_rm_batch = ReinforceRolloutBatch.from_rollout_batches( + rm_rollout_batches, + eos_id=self.model.tokenizer.eos_id, + rollout_batch_seq_length=padded_rollout_sequence_length, + ) + global_rm_batch = unbalanced_rm_batch.gather_and_balance_globally() + + # chunking needs to be outside of reshard region + # NOTE: the seed here must be the same as the chunk before since we need to shuffle + # these values the same way as the other values + balanced_rm_batch = global_rm_batch.chunk( + rank=parallel_state.get_data_parallel_rank(), + split_size=parallel_state.get_data_parallel_world_size(), + seed=self.step, + ) + balanced_local_batch.update(balanced_rm_batch) + + global_rollout_batch.update(global_rm_batch) + + return balanced_local_batch, cpu_dict(self.compute_rollout_metrics(global_rollout_batch)), timer_metrics + + def compute_rollout_metrics(self, rollout_batch): + table = {} + + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + is_end = rollout_batch["is_end"] + + # take the first sample for logging + reward = rewards[0] + prompt_length = prompt_lengths[0] + response_length = response_lengths[0] + response_token = response_tokens[0] + + table["reward"] = reward.item() + table["prompt"] = self.model.tokenizer.ids_to_text(response_token[:prompt_length].tolist()) + table["response"] = self.model.tokenizer.ids_to_text(response_token[prompt_length:response_length].tolist()) + + metrics = { + "table": table, + "rollout_size": prompt_lengths.size(0), + "response_lengths": response_lengths.float().mean().item(), + "prompt_lengths": prompt_lengths.float().mean().item(), + "generation_length": (response_lengths - prompt_lengths).float().mean().item(), + "rewards": rewards.mean().item(), + "fraction_of_samples_properly_ended": is_end.float().mean().item(), + } + + return metrics + + @torch.no_grad() + def run_validation(self): + self.model.prepare_for_inference() + + _, rollout_metrics, _ = self._run_inference( + self.val_dataloader_builder, consumed_samples=0, is_validation=True + ) + + self.model.finish_inference() + return rollout_metrics + + @torch.no_grad() + def generate_rollouts(self): + timing_metrics = {} + + self.timer.start("prepare_for_inference") + self.model.prepare_for_inference() + timing_metrics["prepare_for_inference"] = self.timer.stop_and_get_time("prepare_for_inference") + + rollout_batch, rollout_metrics, timer_metrics = self._run_inference( + self.train_dataloader_builder, consumed_samples=self.consumed_samples, is_validation=False + ) + + self.consumed_samples += rollout_metrics["rollout_size"] + + reinforce_rollout_data, reinforce_rollout_metrics = self.generate_reinforce_data(rollout_batch) + + self.timer.start("finish_inference") + self.model.finish_inference() + timing_metrics["finish_inference"] = self.timer.stop_and_get_time("finish_inference") + + timing_metrics.update(timer_metrics) + + return ( + reinforce_rollout_data, + rollout_metrics | reinforce_rollout_metrics | {"consumed_samples": self.consumed_samples}, + timing_metrics, + ) + + def run_training(self, dataloader_iter): + self.model.prepare_for_training() + + for batch in dataloader_iter: + self.timer.start("train_step_time") + self.optimizer.zero_grad() + + self.model.prepare_for_training_step() + loss_mean, metrics = self.model.get_loss_and_metrics(batch=batch, forward_only=False) + self.model.finish_training_step() + + grad_norm = clip_gradients(self.model, self.cfg.gradient_clip_val) + grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm + lr = self.optimizer.param_groups[0]["lr"] + + self.optimizer.step() + self.scheduler.step() + + if grad_norm is not None: + metrics["grad_norm"] = grad_norm + if lr is not None: + # Some optimizers like adafactor do not require a LR in their initializer + metrics["lr"] = lr + + metrics.update({"loss": loss_mean, "optim_step": self.reinforce_optimization_step}) + metrics["train_step_time"] = self.timer.stop_and_get_time("train_step_time") + + self.logger.log_metrics( + metrics, step=self.step, prefix="train_optim/", + ) + + self.reinforce_optimization_step += 1 + + self.model.finish_training() + + # zero grad again incase it frees up grad mem + self.optimizer.zero_grad() + return loss_mean, metrics + + def fit(self): + epoch_iter = range(self.epoch, self.cfg.max_epochs) + if len(epoch_iter) <= 0: + # epoch done + return + + for _ in epoch_iter: + num_steps_in_epoch = min( + self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch + ) + loop_iter = range(num_steps_in_epoch) + + if not loop_iter: + return # training ended + + global_pbar = tqdm( + loop_iter, initial=self.step, total=self.max_steps, leave=True, desc="REINFORCE Global Step" + ) + + dp_size = parallel_state.get_data_parallel_world_size() + + num_to_load_on_each_dp = divide(self.cfg.model_gbs, dp_size) + + self.run_timer.start_time() + for _ in global_pbar: + step_metrics = {} + timing_metrics = {} + + self.timer.start("rollout_time") + clear_memory() + reinforce_rollout_data, metrics, timer_metrics = self.generate_rollouts() + timing_metrics["rollout_time"] = self.timer.stop_and_get_time("rollout_time") + + timer_metrics = all_reduce_dict(timer_metrics, op=torch.distributed.ReduceOp.MAX) + timing_metrics.update(timer_metrics) + + # logging + table_metrics = metrics.pop("table") + self.train_df.loc[len(self.train_df)] = [ + self.step, + table_metrics["prompt"], + table_metrics["response"], + table_metrics["reward"], + ] + metrics["epoch"] = self.epoch + 1 + self.logger.log_metrics( + metrics, step=self.step, prefix="train_rollouts/", + ) + self.logger.log_table( + key="table/train_rollouts", dataframe=self.train_df, step=self.step, + ) + + rollout_size = reinforce_rollout_data["response_tokens"].size(0) + rollout_dataloader_iter = get_iterator_k_split( + reinforce_rollout_data, divide(rollout_size, num_to_load_on_each_dp) + ) + # start training + clear_memory() + self.timer.start("train_time") + self.run_training(rollout_dataloader_iter) + timing_metrics["train_time"] = self.timer.stop_and_get_time("train_time") + + self.logger.log_metrics(timing_metrics, step=self.step, prefix="timers/") + + self.step += 1 + + run_time_exceeded = self.run_timer.is_finished() + run_val, save_model, is_train_end = check_progress( + self.step, + self.max_steps, + self.cfg.val_check_interval, + self.cfg.save_interval, + 1.0, # TODO:(geshen): allow for limit val batches + run_time_exceeded=run_time_exceeded, + ) + + if run_val: + self.timer.start("validation_time") + val_metrics = self.run_validation() + timing_metrics["validation_time"] = self.timer.stop_and_get_time("validation_time") + + val_table_metrics = val_metrics.pop("table") + + self.val_df.loc[len(self.val_df)] = [ + self.step, + val_table_metrics["prompt"], + val_table_metrics["response"], + val_table_metrics["reward"], + ] + self.logger.log_metrics(val_metrics, step=self.step, prefix="val_rollouts/") + self.logger.log_table("table/val_rollouts", dataframe=self.val_df, step=self.step) + + step_metrics.update({f"val_{k}": v for k, v in val_metrics.items()}) + + step_metrics.update(timing_metrics) + step_metrics.update({f"train_{k}": v for k, v in metrics.items()}) + global_pbar.set_postfix(step_metrics) + + if save_model: + step_metrics = {k: torch.as_tensor(v) for k, v in step_metrics.items()} + self.save(step_metrics, is_train_end=is_train_end) + + if run_time_exceeded: + logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run") + return + + self.logger.finalize() + + def state_dict(self): + return { + "step": self.step, + "consumed_samples": self.consumed_samples, + "epoch": self.epoch, + "reinforce_optimization_step": self.reinforce_optimization_step, + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.consumed_samples = state_dict["consumed_samples"] + self.reinforce_optimization_step = state_dict["reinforce_optimization_step"] + + loaded_values = [self.step, self.consumed_samples, self.reinforce_optimization_step] + + # make sure everyone loaded the same checkpoint as rank 0 + to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(to_broadcast, 0) + + assert loaded_values == to_broadcast.tolist() + # restore max steps we need to run for + self.set_max_steps() + + def save(self, extra_candidates=None, is_train_end=False): + self.model.prepare_for_training() + # load back in the adam states if needed + torch.cuda.synchronize() + torch.distributed.barrier() + + if extra_candidates is None: + extra_candidates = {} + + monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()} + monitor_candidates.update(extra_candidates) + + self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end) + + self.model.finish_training() + + def set_max_steps(self): + self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs + + if (max_steps := self.cfg.get("max_steps", -1)) >= 0: + self.max_steps = min(self.max_steps, max_steps) + + @property + def epoch(self): + return self.step // self.num_steps_per_epoch diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py new file mode 100644 index 000000000..511d25c3f --- /dev/null +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py @@ -0,0 +1,394 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch +import torch.distributed +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.utils import divide +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.utils import logging +from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.distributed import ( + broadcast_2d_tensor_within_pp, + calculate_distributed_entropy, + from_parallel_logits_to_logprobs, +) +from nemo_aligner.utils.text_generation_utils import ( + TrackLengthGPTModelTextGenerationStrategy, + verify_is_valid_and_clamp_range_, +) +from nemo_aligner.utils.train_utils import ( + grad_reductions, + prepare_for_training_step, + set_eval, + set_sync_funcs, + set_train, +) +from nemo_aligner.utils.trt_llm import GPTGenerateTRTLLM +from nemo_aligner.utils.utils import ( + adapter_control, + clear_memory, + configure_batch_sizes, + cpu_weight_swap, + masked_mean, + offload_distributed_adam, +) + + +class MegatronGPTReinforceActorModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface): + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + self.automatic_optimization = False + + self.init_policy_state_dict = None + self.distributed_adam_offload_manager = None + + # length parameters for generation + self._length_params = OmegaConf.to_container(self.cfg.reinforce.length_params, resolve=True) + # sampling parameters for generation + self._sampling_params = OmegaConf.to_container(self.cfg.reinforce.sampling_params, resolve=True) + + self.to_offload_adam_states = self.cfg.reinforce.offload_adam_states and self.with_distributed_adam + self.forward_micro_batch_size = self.cfg.reinforce.forward_micro_batch_size + + self.use_trtllm_generation = "trt_llm" in self.cfg.reinforce and self.cfg.reinforce.trt_llm.enable + if self.use_trtllm_generation: + self.trtllm_generate = GPTGenerateTRTLLM( + model_cfg=self.cfg, + max_generation_length=self.cfg.reinforce.length_params.get("max_length", 1024), + max_input_len=self.cfg.reinforce.trt_llm.get("max_input_len", 1024), + max_input_tokens=self.cfg.reinforce.trt_llm.get("max_input_tokens", 4096), + generation_batch_size=self.cfg.reinforce.get("rollout_micro_batch_size", 4), + unload_engine_train=self.cfg.reinforce.trt_llm.get("unload_engine_train", False), + trt_model_type=self.cfg.reinforce.trt_llm.get("model_type", "llama"), + end_strings=self.cfg.reinforce.sampling_params["end_strings"], + reshard_model=self.cfg.reinforce.trt_llm.get("reshard", False), + sample_temperature=self.cfg.reinforce.sampling_params["temperature"], + sample_top_k=self.cfg.reinforce.sampling_params["top_k"], + sample_top_p=self.cfg.reinforce.sampling_params["top_p"], + repetition_penalty=self.cfg.reinforce.sampling_params["repetition_penalty"], + use_greedy=self.cfg.reinforce.sampling_params.get("use_greedy", False), + tokenizer=self.tokenizer, + seed=self.cfg.reinforce.trt_llm.get("seed", self.cfg.seed), + ) + + # training calls + def get_actor_forward_output_and_loss_func(self): + def fwd_output_and_loss_func(data_iterator, model): + batch = next(data_iterator) + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("response_tokens", "position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("response_tokens", "baseline", "mask", "rewards_with_kl", "is_end")) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + parallel_logits = model( + batch["response_tokens"], batch["position_ids"], batch["attention_mask"], labels=None, + ) + + def loss_func(parallel_logits): + mask = batch["mask"] + rewards_with_kl = batch["rewards_with_kl"] + baseline = batch["baseline"] + tokens = batch["response_tokens"] + is_end = batch["is_end"] + + is_end_mask = mask * is_end.view(-1, 1) + + curr_log_probs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=parallel_logits, target=tokens, higher_stability=True + ) + + reinforce_loss = -1 * curr_log_probs * (rewards_with_kl - baseline) + + if is_end_mask.sum() > 0: + loss = masked_mean(reinforce_loss, mask) + else: + # hack to disable this update since there are no valid tokens + loss = reinforce_loss.view(-1)[0] * 0 + + reduced_actor_loss = average_losses_across_data_parallel_group([loss]) + return ( + loss, + {"loss": reduced_actor_loss,}, + ) + + return parallel_logits, loss_func + + return fwd_output_and_loss_func + + def prepare_for_training(self): + configure_batch_sizes( + mbs=self.cfg.micro_batch_size, + gbs=self.cfg.global_batch_size, + dp=parallel_state.get_data_parallel_world_size(), + ) + self.onload_adam_states() + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def get_loss_and_metrics(self, batch, forward_only): + sequence_length = batch["response_tokens"].size(1) + + attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(tokens=batch["response_tokens"]) + batch["attention_mask"] = attention_mask + batch["position_ids"] = position_ids + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_actor_forward_output_and_loss_func(), + data_iterator=self._make_data_iterator_list(data_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=sequence_length, + micro_batch_size=self.cfg.micro_batch_size, + ) + + metrics = {} + + for key in ["loss"]: + if losses_reduced_per_micro_batch: + metric_mean = torch.stack( + [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + ).mean() + else: + metric_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + torch.distributed.broadcast(metric_mean, get_last_rank()) + + metrics[key] = metric_mean.cpu().item() + + return metrics["loss"], metrics + + def finish_training_step(self): + grad_reductions(self) + + def finish_training(self): + """no need to offload adam states here + """ + + # inference calls + def get_logprob_output_only_func(self, inference_only=True): + fwd_output_only_func = self.get_forward_output_only_func() + + def log_prob_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + + output_tensor, _ = fwd_output_only_func(iter([batch,]), model) + + def id_func(output_tensor, non_loss_data=True): + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, + target=batch[0], + inference_only=inference_only, + higher_stability=True, + ) + return logprobs + + return output_tensor, id_func + + return log_prob_output_only_func + + @torch.no_grad() + def get_inference_log_probs(self, response_tokens, forward_micro_batch_size=None): + if forward_micro_batch_size is None: + forward_micro_batch_size = self.forward_micro_batch_size + + set_sync_funcs(self, forward_only=True) + + mbs, seq_length = response_tokens.size() + num_microbatches = divide(mbs, forward_micro_batch_size) + attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens) + + batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches) + + fwd_bwd_function = get_forward_backward_func() + logprobs_list = fwd_bwd_function( + forward_step_func=self.get_logprob_output_only_func(inference_only=True), + data_iterator=self._make_data_iterator_list(batch_iter), + model=self.model, + num_microbatches=num_microbatches, + forward_only=True, + seq_length=seq_length, + micro_batch_size=forward_micro_batch_size, + collect_non_loss_data=True, + ) + + logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None + + # Broadcast it from last PP stage to everything else. + logprobs = broadcast_2d_tensor_within_pp(logprobs) + + return logprobs + + def prepare_for_inference(self): + """normally we would configure the micro batch calculator here + but the nemo generation already does the configuration""" + self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() + set_eval(self) + self.offload_adam_states() + + if self.use_trtllm_generation: + # TODO this might be optimized to avoid calling `refit()` twice in a row after a validation step + self.trtllm_generate.refit(self.model) + clear_memory() + + @torch.no_grad() + def infer(self, inference_batch): + prompt_tokens = inference_batch["text"].cuda(non_blocking=True) + prompt_lengths = inference_batch["length"].cuda(non_blocking=True) + inputs = (prompt_tokens, prompt_lengths) + + strategy = TrackLengthGPTModelTextGenerationStrategy( + model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"] + ) + + if self.use_trtllm_generation: + actor_output = self.trtllm_generate.generate(inputs) + response_tokens = actor_output["response_tokens"] + response_lengths = actor_output["response_lengths"] + else: + actor_output = self.generate( + inputs=inputs, + length_params=self._length_params, + sampling_params=self._sampling_params, + strategy=strategy, + ) + response_tokens = torch.cuda.LongTensor(actor_output["token_ids"]) if actor_output else None + response_tokens = broadcast_2d_tensor_within_pp(response_tokens, dtype=torch.long) + response_lengths = strategy.get_lengths() + + max_response_length = response_lengths.max().item() + + # Sanity check to validate response length. + if max_response_length != response_tokens.size(1): + # This may actually happen because NeMo does not always stop generation after `max_length` in batch mode + # => `response_tokens` may contain up to `max_length + max_context_length` tokens. + # TODO once NeMo fixes this issue we should be able to always raise an exception when the check above fails, + # and remove the `if` below. + if ( + max_response_length >= response_tokens.size(1) + or response_tokens.size(1) != prompt_lengths.max().item() + self._length_params["max_length"] + ): + raise AssertionError( + f"max response length ({max_response_length}) does not match the size of " + f"`response_tokens` ({response_tokens.size(1)})" + ) + + # sometimes backends like TRT-LLM will generate invalid tokens + # so we need to also inplace mutate the response_tokens to be within the tokenizer range + is_valid = verify_is_valid_and_clamp_range_( + response_tokens, + response_lengths, + strategy, + self.tokenizer, + self.cfg.reinforce.sampling_params["end_strings"], + ) + + rollout_batch = { + "response_tokens": response_tokens, + "response_lengths": response_lengths, + "prompt_lengths": prompt_lengths, + "is_end": is_valid, + } + + # return in GPU, trainer needs to move to cpu + + return rollout_batch + + def get_init_policy_logprobs(self, response_tokens): + use_peft_init_policy = self.use_peft and self.init_policy_state_dict is None + + context_mgr = ( + adapter_control(self) + if use_peft_init_policy + else cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2) + ) + + with context_mgr: + return self.get_inference_log_probs(response_tokens) + + def finish_inference(self): + # training will onload the adam states, no need to onload it here + self._restore_activation_checkpointing_args() + self._restore_sequence_parallelism_args() + + if self.use_trtllm_generation: + self.trtllm_generate.free() + + set_train(self) + + def offload_adam_states(self): + if self.distributed_adam_offload_manager is None: + + self.distributed_adam_offload_manager = ( + offload_distributed_adam( + self._optimizer.state_dict(state_dict_format=1, gather_on_root=False), force_clear_memory=True + ) + if self.to_offload_adam_states + else nullcontext() + ) + + # offload onto cpu + self.distributed_adam_offload_manager.__enter__() + + def onload_adam_states(self): + if self.distributed_adam_offload_manager is not None: + # load back onto GPU + self.distributed_adam_offload_manager.__exit__(None, None, None) + + self.distributed_adam_offload_manager = None + + def get_ltor_masks_and_position_ids(self, tokens): + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=self.tokenizer.eos_id, + reset_position_ids=self.cfg.data.get("reset_position_ids", False), + reset_attention_mask=self.cfg.data.get("reset_attention_mask", False), + eod_mask_loss=False, # since we ignore the loss mask here + ) + attention_mask = attention_mask.expand(tokens.size(0), -1, -1, -1) + position_ids = position_ids.expand(tokens.size(0), -1) + + return attention_mask, loss_mask, position_ids diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py index 1d1f5cf67..e447c6dee 100644 --- a/nemo_aligner/utils/ppo_utils.py +++ b/nemo_aligner/utils/ppo_utils.py @@ -112,3 +112,28 @@ def select_topk(batch, num_select=1): selected_batch = {k: batch[k][selected_idx] for k in batch.keys()} return selected_batch + + +def calculate_rloo_baseline(prompts, reward, mask): + """ + Function to select the RLOO baseline for each (prompt, response) pair in the batch. + The same baseline is calculated for each prompt. Masked samples are not included + in the baseline calculation. + """ + unique_prompts = torch.unique(prompts, dim=0) + + baseline = torch.zeros_like(reward) + reward_device = reward.get_device() + for i in range(len(unique_prompts)): + is_matching_prompt = (prompts == unique_prompts[i]).all(1) + prompt_idx = torch.arange(len(prompts), device=reward_device)[is_matching_prompt] + rloo_mat = (1 - torch.eye(len(prompt_idx))).to(reward_device) + + if mask[prompt_idx].sum() <= 1: + # Ignore sample: set baseline equal to reward + baseline[prompt_idx] = reward[prompt_idx] + else: + rloo = torch.matmul(rloo_mat, reward[prompt_idx] * mask[prompt_idx]) / (mask[prompt_idx].sum() - 1) + baseline[prompt_idx] = rloo + + return baseline