diff --git a/CHANGELOG.md b/CHANGELOG.md index f19d4e2fe..f21c1a2f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Critic and Reward Model server refactored. Now the reward model will have a flag called `model.forward_micro_batch_size` which determines the micro batch size on which it runs inferences. This can be higher than the training micro batch size since during inference, we have less memory pressure. - In the critic and reward model server, it is now possible to specify `inference_micro_batch_size` as a list. This allows us to provide more information to PyTriton regarding the preferred batch sizes for inference. - It is no longer a requirement to specify `num_rollout_samples` to be a multiple of `inference_micro_batch_size * dp size` in PPO. +- Sequence packing is now supported when running SFT with SFTChatDataset. +- Add online rejection sampling algorithm. ### Breaking Changes - `inference.micro_batch_size` is now renamed to `inference.inference_micro_batch_size` when running reward model inference in `inference_rm.yaml`. This is to stay consistent with the naming scheme of the PPO critic. diff --git a/Dockerfile b/Dockerfile index 2d8f0ca18..e80eb34f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,6 +18,7 @@ ARG MLM_TAG=a3fe0c75df82218901fa2c3a7c9e389aa5f53182 # On: core_r0.8.0 ARG ALIGNER_COMMIT=main ARG TRTLLM_VERSION=v0.10.0 ARG PROTOBUF_VERSION=4.24.4 + ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 FROM ${BASE_IMAGE} AS aligner-bump diff --git a/README.md b/README.md index b689ba961..7c1452ad3 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,9 @@ We provide an official NeMo-Aligner Dockerfile which is based on stable, tested Alternatively, you can build the NeMo Dockerfile here [NeMo Dockerfile](https://github.com/NVIDIA/NeMo/blob/main/Dockerfile) and add `RUN pip install nemo-aligner` at the end. ## Future work -- Add Rejection Sampling support. - We will continue improving the stability of the PPO learning phase. - Improve the performance of RLHF. +- Add TRT-LLM inference support for Rejection Sampling. ## Contribute to NeMo-Aligner We welcome community contributions! Please refer to [CONTRIBUTING.md](https://github.com/NVIDIA/NeMo-Aligner/blob/main/CONTRIBUTING.md) for guidelines. diff --git a/docs/README.md b/docs/README.md index 46480425d..95258ef38 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,12 +1,13 @@ -# Documentation +# ReadMe ## Custom Trainers -NeMo-Aligner uses custom trainers to coordinate all aspects of training. There are currently 3 custom trainers: -1. [SupervisedTrainer](/nemo_aligner/algorithms/supervised.py): for SFT, SteerLM and Reward modeling. +NeMo-Aligner uses custom trainers to coordinate all aspects of training. There are currently three custom trainers: +1. [SupervisedTrainer](/nemo_aligner/algorithms/supervised.py): for SFT, SteerLM, and Reward modeling. 2. [DPOTrainer](/nemo_aligner/algorithms/dpo.py): for DPO training. 3. [CriticServerTrainer](/nemo_aligner/algorithms/critic_server_trainer.py): trains the RL critic via PyTriton requests. It will also run the reward model depending on the configuration. 4. [PPOTrainer](/nemo_aligner/algorithms/ppo.py): performs the RLHF PPO training, since PPO has components such as the Critic, this trainer will send inference and train requests via [PyTriton](https://github.com/triton-inference-server/pytriton) to the CriticServerTrainer to train and run inference on the critic. +5. [RSTrainer](/nemo_aligner/algorithms/rs.py): performs the Rejection Sampling (RS) training. Since RS needs a reward model, this trainer will send inference requests via [PyTriton](https://github.com/triton-inference-server/pytriton) to run inference on the reward model. ## Configuration guide @@ -16,13 +17,14 @@ See the example configurations in the [conf folder](/examples/nlp/gpt/conf/) for ## APIs Our custom trainers will only call predefined APIs on the model passed in. These APIs are defined in [alignable_interface.py](/nemo_aligner/models/alignable_interface.py). -## Launching scripts and their description +## Launching Scripts * Supervised Fine Tuning Training: [train_gpt_sft.py](/examples/nlp/gpt/train_gpt_sft.py) with [gpt_sft.yaml](/examples/nlp/gpt/conf/gpt_sft.yaml). * DPO Training: [train_gpt_dpo.py](/examples/nlp/gpt/train_gpt_dpo.py) with [gpt_dpo.yaml](/examples/nlp/gpt/conf/gpt_dpo.yaml). * Reward Model Training: [train_reward_model.py](/examples/nlp/gpt/train_reward_model.py) with [training_rm.yaml](/examples/nlp/gpt/conf/training_rm.yaml). * Reward Model Inference: [serve_reward_model.py](/examples/nlp/gpt/serve_reward_model.py) with [inference_rm.yaml](/examples/nlp/gpt/conf/inference_rm.yaml). * PPO Critic Server: [serve_ppo_critic.py](/examples/nlp/gpt/serve_ppo_critic.py) with [gpt_ppo_critic.yaml](/examples/nlp/gpt/conf/gpt_ppo_critic.yaml). * PPO Actor Training: [train_gpt_ppo_actor.py](/examples/nlp/gpt/train_gpt_ppo_actor.py) with [gpt_ppo_actor.yaml](/examples/nlp/gpt/conf/gpt_ppo_actor.yaml). +* Rejection Sampling Training: [train_gpt_rs_actor.py](/examples/nlp/gpt/train_gpt_rs_actor.py) with [gpt_rs_actor.yaml](/examples/nlp/gpt/conf/gpt_rs_actor.yaml). To run a full RLHF PPO job, we need to start both the CriticServerTrainer and PPOTrainer. diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index 03c476332..650d67a6e 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -33,6 +33,9 @@ :ref:`Model Alignment by DPO, RPO and IPO ` DPO, RPO, and IPO are simpler alignment methods compared to RLHF. DPO introduces a novel parameterization of the reward model in RLHF, which allows us to extract the corresponding optimal policy. Similarly, RPO and IPO provide alternative parameterizations or optimization strategies, each contributing unique approaches to refining model alignment. +:ref:`Model Alignment by Rejection Sampling (RS) ` + RS is a simple online alignment algorithm. In RS, the policy model generates several responses. These responses are assigned a score by the reward model, and the highest scoring responses are used for SFT. + :ref:`Fine-tuning Stable Diffusion with DRaFT+ ` DRaFT+ is an algorithm for fine-tuning text-to-image generative diffusion models. It achieves this by directly backpropagating through a reward model. This approach addresses the mode collapse issues from the original DRaFT algorithm and improves diversity through regularization. diff --git a/docs/user-guide/rs.rst b/docs/user-guide/rs.rst new file mode 100644 index 000000000..ac7ea30ee --- /dev/null +++ b/docs/user-guide/rs.rst @@ -0,0 +1,230 @@ +.. include:: /content/nemo.rsts + +.. _model-aligner-rs: + +Model Alignment by Rejection Sampling +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + +In this tutorial, we will guide you through the process of aligning a NeMo Framework model using rejection sampling. This method can be applied to various models, including LLaMa2 and Mistral, with our scripts functioning consistently across different models. + +Rejection Sampling is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the rejection sampling algorithm on the `Anthropic-HH-RLHF `__ dataset. + +Rejection Sampling Training +############ + +After you have fine-tuned a GPT model using Supervised Fine-Tuning (SFT), and trained a reward model as explained in the preceding section, you can start aligning the policy using rejection sampling. + +During rejection sampling training, we have two models interacting with each other, which Aligner runs in separate jobs: + +#. The Policy Network: This is the model we are training and it should start from an SFT model. +#. The Reward Model (RM): This model accepts a prompt combined with a response as input and produces a single scalar value, known as the reward. The rejection sampling algorithm aims to maximize this reward. + +The next section discusses how to launch each of these two jobs. + +Launching the Reward Model and Critic Server +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +To launch the server: + +.. code-block:: bash + + #!/bin/bash + CHECKPOINT_NEMO_FILE="/path/to/trained_rm.nemo" + GPFS="/path/to/nemo-aligner-repo" + + RESULTS_DIR="critic_results_dir" + + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/serve_reward_model.py \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + ++model.tensor_model_parallel_size=4 \ + rm_model_file=${RM_NEMO_FILE} + + +The above example launches the reward model server on 8 GPUs and 1 node. Please make sure to change trainer.devices, trainer.num_nodes depending on your model size and scale. Aligner will work on any scale. Also, make sure to tune the trainer.rs.inference_micro_batch_size argument. This argument sets the size of the batch the RS actor is allowed to send to the critic per DP rank. + +Launch the Initial Policy and RS Actor Training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +The RS Actor training job contains the master controller that makes the HTTP calls to all servers when needed. To launch the RS Actor and Initial Policy server: + +.. code-block:: bash + + GPFS="/path/to/nemo-aligner-repo" + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + PRETRAINED_ACTOR_NEMO_FILE="/path/to/sft_checkpoint.nemo" + RESULTS_DIR="/path/to/actor_results_dir" + + ACTOR_LR=1e-6 + NUM_ROLLOUTS=32 + ACTOR_GBS=32 + CRITIC_PORT=5555 + host_critic="$(scontrol show hostnames=$SLURM_JOB_NODELIST_HET_GROUP_0 | head -n1)" + + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/train_gpt_rs_actor.py \ + "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ + pretrained_checkpoint.restore_from_path=\"${PRETRAINED_ACTOR_NEMO_FILE}\" \ + exp_manager.checkpoint_callback_params.save_top_k=1 \ + exp_manager.explicit_log_dir=\"${RESULTS_DIR}\" \ + trainer.rs.max_epochs=1 \ + trainer.rs.max_steps=313 \ + trainer.rs.val_check_interval=4 \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + ++model.tensor_model_parallel_size=4 \ + model.global_batch_size=${ACTOR_GBS} \ + model.micro_batch_size=1 \ + model.optim.lr=\"\\\$\{multiply:${ACTOR_LR},1.001\}\" \ + model.optim.sched.warmup_steps=0 \ + model.optim.sched.constant_steps=312 \ + model.optim.sched.min_lr=${ACTOR_LR} \ + model.optim.weight_decay=0.01 \ + model.rs.num_rollout_samples=${NUM_ROLLOUTS} \ + model.rs.rollout_micro_batch_size=16 \ + model.rs.forward_micro_batch_size=16 \ + model.rs.val_rollout_micro_batch_size=8 \ + model.data.data_impl=jsonl \ + remote_rm.reward_model.ip=${host_critic} \ + remote_rm.reward_model.port=${CRITIC_PORT} \ + model.rs.num_rollouts_per_prompt=8 \ + model.rs.top_n_rollouts=1 + +The above command launches the initial and actor server on 1 node with 8 GPUs. + +Launching Both Servers for Rejection Sampling training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +You can use slurm to launch the 2 jobs and get them to coordinate together in a full Rejection Sampling job via the following: + +.. code-block:: bash + + #!/bin/bash + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + #SBATCH hetjob + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + + NAME="2p_rs" + + # PARAMETERS + RM_NEMO_FILE="/path/to/trained_rm.nemo" + + ACTOR_NEMO_FILE="/path/to/sft_model.nemo" + + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + RESULTS_DIR="/path/to/results_dir" + mkdir -p $RESULTS_DIR + + GPFS="/path/to/nemo-aligner-repo" + MOUNTS="--container-mounts=MOUNTS" # mounts + + CONTAINER=<<>> # use the latest NeMo Training container, Aligner will work there + + PROJECT=rs_run + + CRITIC_LOG_DIR="${RESULTS_DIR}/critic_results" + CRITIC_OUTFILE="${CRITIC_LOG_DIR}/critic_output_%j_%t.log" + CRITIC_ERRFILE="${CRITIC_LOG_DIR}/critic_error_%j_%t.err" + CRITIC_PORT=5567 + CRITIC_CONFIG_PATH="${GPFS}/examples/nlp/gpt/conf" + CRITIC_CONFIG_NAME="inference_rm" + + CONF_DIR="${GPFS}/examples/nlp/gpt/conf" + CONFIG_NAME="gpt_rs_actor" + + mkdir -p $CRITIC_LOG_DIR + + CRITIC_NAME="${NAME}_critic" + + read -r -d '' cmd_critic_inference <`__ script from the NeMo codebase to run more rigorous evaluation of your trained model. diff --git a/examples/nlp/gpt/conf/gpt_rs_actor.yaml b/examples/nlp/gpt/conf/gpt_rs_actor.yaml new file mode 100644 index 000000000..b819ca287 --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_rs_actor.yaml @@ -0,0 +1,180 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + # these args are respected + num_nodes: 8 + devices: 8 + accelerator: gpu + precision: bf16 + + rs: + max_epochs: 1 + max_steps: -1 # max rs steps (-1 to go through the whole train set) + val_check_interval: 10 + save_interval: ${.val_check_interval} + gradient_clip_val: 1.0 + + # pick up from the model + # *do not change this* + model_gbs: ${model.global_batch_size} + model_mbs: ${model.micro_batch_size} + + # the sequence length to pad the rollout batch to + # this reduces fragmentation at the cost of using more + # memory, set to null if we don't want to pad it + # to a constant size + rollout_batch_seq_length: null + + # no need to change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.rs.max_epochs} + max_steps: ${.rs.max_steps} + +remote_rm: + # what to pad the inputs to + # set to None if no padding when sending data for reward model inference + pad_to_length: ${model.encoder_seq_length} + + # reward model server + reward_model: + name: reward_model + ip: localhost + port: 5555 + + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt_rs_actor + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_rs + name: gpt3_rs_2b + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_global_rewards + save_top_k: 1 + mode: max + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_gpt-{step}-{consumed_samples}-{rs_optimization_step}-{epoch}-{val_global_rewards:.3f}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +pretrained_checkpoint: + restore_from_path: null + +model: + + rs: + # training generation mbs + rollout_micro_batch_size: 8 + num_rollout_samples: 512 + + # mbs to do log prob inference, can be set to + # lower than rollout_micro_batch_size to reduce + # memory usage + forward_micro_batch_size: ${.rollout_micro_batch_size} + + num_rollouts_per_prompt: 4 # Number of completions to sample per prompt + top_n_rollouts: 1 # Number of completions to select based on reward and train upon (per prompt) + + # val generation mbs + val_rollout_micro_batch_size: ${.rollout_micro_batch_size} + num_val_samples: ${.num_rollout_samples} + + # to offload during generation or not + offload_adam_states: True + + # params for generation + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 0 + top_p: 1.0 + repetition_penalty: 1.0 + add_BOS: False + all_probs: False + compute_logprob: False + end_strings: ["<|endoftext|>", ""] + + # length argument for autoregressive sampling + # max length means max amount of tokens to generate + length_params: + max_length: ${int_div:${model.encoder_seq_length}, 2} + min_length: 1 + + #peft + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + mcore_gpt: True + # these control the mbs/gbs during RS training + micro_batch_size: 1 + global_batch_size: 64 + megatron_amp_O2: True + + encoder_seq_length: 4096 + max_position_embeddings: ${model.encoder_seq_length} + + ## Sequence Parallelism + sequence_parallel: False + + # miscellaneous + seed: 1234 + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-7 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-8 + + precision: ${trainer.precision} + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: null + + # define fields from the base model's config that should be ignored when merging with this config. + overwrite_base_config: + data: + data_prefix: True \ No newline at end of file diff --git a/examples/nlp/gpt/train_gpt_rs_actor.py b/examples/nlp/gpt/train_gpt_rs_actor.py new file mode 100644 index 000000000..611967a59 --- /dev/null +++ b/examples/nlp/gpt/train_gpt_rs_actor.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.multiprocessing as mp +from megatron.core import parallel_state +from megatron.core.utils import divide +from omegaconf.omegaconf import OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.rs import RSTrainer +from nemo_aligner.data.nlp.builders import ( + build_dataloader, + build_train_valid_test_rlhf_datasets, + collate_with_pad_to_max_batch, +) +from nemo_aligner.models.nlp.gpt.megatron_gpt_rs_actor import MegatronGPTRSModel +from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMClient +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + resolve_and_create_trainer, + retrieve_custom_trainer_state_dict, +) +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo + +"""Script to start RS training""" + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) +OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="gpt_rs_actor") +def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + trainer = resolve_and_create_trainer(cfg, "rs") + + exp_manager(trainer, cfg.exp_manager) + + logger = CustomLoggerWrapper(trainer.loggers) + + ptl_model = load_from_nemo( + MegatronGPTRSModel, cfg.model, trainer, strict=True, restore_path=cfg.pretrained_checkpoint.restore_from_path, + ) + + init_peft(ptl_model, cfg.model) + + # pull values from checkpoint + trainer_restore_path = trainer.ckpt_path + + # TODO: log this restore path + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + consumed_samples = custom_trainer_state_dict["consumed_samples"] + else: + custom_trainer_state_dict = None + consumed_samples = 0 + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the entire dataset + train_valid_test_num_samples = [-1, -1, -1] + train_ds, validation_ds, _ = build_train_valid_test_rlhf_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl=cfg.model.data.data_impl, + splits_string=cfg.model.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.seed, + tokenizer=ptl_model.tokenizer, + ) + + max_seqlen = cfg.model.rs.length_params.max_length + eos_id = ptl_model.tokenizer.eos_id + + # collate fn to pad to the max seq length in the batch + collate_fn = collate_with_pad_to_max_batch(max_seqlen, eos_id, cfg) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=consumed_samples, + mbs=cfg.model.rs.rollout_micro_batch_size, + gbs=cfg.model.rs.num_rollout_samples, + collate_fn=collate_fn, + load_gbs=False, + ) + + val_dataloader = build_dataloader( + cfg=cfg, + dataset=validation_ds, + consumed_samples=0, + mbs=cfg.model.rs.rollout_micro_batch_size, + gbs=cfg.model.rs.num_val_samples, + collate_fn=collate_fn, + load_gbs=False, + use_random_sampler=False, + ) + + # nemo uses the train dataloader to figure out + # max steps to take when max_steps = -1 + # but our train dataloader is for the prompts + # so we instaniate a dummy dataloader + # to get the proper max *optimization* steps + # nemo treats batch size of normal dataloader as GBS/DP + # so we need to offset it by DP + dummy_train_dataloader = torch.utils.data.DataLoader( + dataset=train_ds, batch_size=divide(cfg.model.global_batch_size, parallel_state.get_data_parallel_world_size()) + ) + + init_using_ptl(trainer, ptl_model, dummy_train_dataloader, train_ds) + # make sure the dummy train dataloader is never used + del ptl_model._train_dl + del dummy_train_dataloader + + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + rm = RemoteGPTRMClient(cfg.remote_rm) + timer = Timer(cfg.exp_manager.get("max_time_per_run")) + + rs_trainer = RSTrainer( + cfg=cfg.trainer.rs, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + rm=rm, + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + num_rollouts_per_prompt=cfg.model.rs.num_rollouts_per_prompt, + top_n_rollouts=cfg.model.rs.top_n_rollouts, + ) + + if custom_trainer_state_dict is not None: + rs_trainer.load_state_dict(custom_trainer_state_dict) + + rs_trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/nemo_aligner/algorithms/rs.py b/nemo_aligner/algorithms/rs.py new file mode 100644 index 000000000..493b743d4 --- /dev/null +++ b/nemo_aligner/algorithms/rs.py @@ -0,0 +1,478 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +import pandas as pd +import torch +from megatron.core import parallel_state +from megatron.core.utils import divide +from omegaconf.dictconfig import DictConfig +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm + +from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split +from nemo.utils import logging +from nemo_aligner.utils.distributed import SyncTimer, pad_list, pad_tensors_to_max_global_seq_len +from nemo_aligner.utils.ppo_utils import create_mask, select_topk +from nemo_aligner.utils.train_utils import clip_gradients +from nemo_aligner.utils.trainer_utils import check_progress, compute_num_steps_per_epoch +from nemo_aligner.utils.utils import clear_memory, cpu_dict + + +def compute_num_rollout_microbatches(dataloader): + return divide( + divide(dataloader.batch_sampler.global_batch_size, dataloader.batch_sampler.micro_batch_size), + parallel_state.get_data_parallel_world_size(), + ) + + +class RSTrainer: + """Trainer to coordinate RS training + """ + + def __init__( + self, + cfg: DictConfig, + model, + optimizer, + scheduler, + train_dataloader, + val_dataloader, + logger, + ckpt_callback, + run_timer, + num_rollouts_per_prompt, + top_n_rollouts, + rm, + ): + self.cfg = cfg + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.logger = logger + self.ckpt_callback = ckpt_callback + self.num_rollouts_per_prompt = num_rollouts_per_prompt + self.top_n_rollouts = top_n_rollouts + self.rm = rm + + # this timer checks if we should stop training + self.run_timer = run_timer + + self.consumed_samples = 0 + # the step here is RS step + self.step = 0 + # keep track of how many times we optimized the actor + self.rs_optimization_step = 0 + + # compute `max_steps` + self.num_steps_per_epoch = compute_num_steps_per_epoch(self.train_dataloader.batch_sampler) + self.set_max_steps() + + # size to pad our rollout batch to + self.rollout_batch_seq_length = self.cfg.rollout_batch_seq_length + + # for wandb table + self.train_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + self.val_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + + self.timer = SyncTimer( + reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX + ) + + def generate_rs_data(self, rollout_batches): + """generate rs specific data for training + """ + rs_rollout_data = defaultdict(list) + rs_rollout_metrics = defaultdict(int) + num_samples = 0 + + def post_process_tensor(tensor): + return map(lambda x: x.flatten(), tensor.cpu().split(1, dim=0)) + + for rollout_batch in rollout_batches: + # NOTE: all items in rollout batch or out of this computation + # must have a leading B dimension + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + response_tokens = rollout_batch["response_tokens"] + + prompt_tokens = rollout_batch["prompt_tokens"] + + num_samples += prompt_lengths.size(0) + + # mask will mask out the loss on the prompt tokens + mask = create_mask( + values=torch.zeros([response_tokens.shape[0], response_tokens.shape[1] - 1]), + prompt_lengths=prompt_lengths, + response_lengths=response_lengths, + ) + + # collect everything we need to train actor + rs_rollout_data["mask"].extend(post_process_tensor(mask)) + rs_rollout_data["response_tokens"].extend(post_process_tensor(response_tokens)) + rs_rollout_data["prompt_tokens"].extend(post_process_tensor(prompt_tokens)) + + # average across the samples for the non global metrics + rs_rollout_metrics = {k: v / num_samples for k, v in rs_rollout_metrics.items()} + + for k in rs_rollout_data: + rollout_batch_seq_length = self.rollout_batch_seq_length + pad_value = self.model.tokenizer.eos_id + + # all other tensors in the rollout batch + # will be B x S -1 (because we don't predict anything for the last token) + + rs_rollout_data[k] = pad_tensors_to_max_global_seq_len( + rs_rollout_data[k], + pad_value=pad_value, + group=parallel_state.get_data_parallel_group(), + sequence_length_to_pad_to=rollout_batch_seq_length, + ) + + return rs_rollout_data, cpu_dict(rs_rollout_metrics) + + def _run_inference(self, dataloader_iter, num_microbatches, is_validation): + """this function is run per DP so the metrics need to be computed globally + """ + rollout_batches = [] + if not is_validation: + full_batches = [] # compute metrics over all batches, not just the selected ones + for _, inference_batch in zip(range(num_microbatches), dataloader_iter): + + current_batch = None + prompt_tokens, response_tokens, response_lengths, prompt_lengths, rewards = ( + [], + [], + [], + [], + [], + ) + for _ in range(self.num_rollouts_per_prompt): + rollout_batch = self.model.infer(inference_batch) + reward = self.rm.infer_rm(rollout_batch).result().detach() + + prompt_tokens.append(inference_batch["text"]) + response_tokens.append(rollout_batch["response_tokens"]) + response_lengths.append(rollout_batch["response_lengths"]) + prompt_lengths.append(rollout_batch["prompt_lengths"]) + rewards.append(reward) + + all_rollouts = {} + all_rollouts["response_tokens"] = torch.concatenate( + pad_list(response_tokens, pad_value=self.model.tokenizer.eos_id) + ) + all_rollouts["prompt_tokens"] = torch.concatenate(prompt_tokens) + all_rollouts["response_lengths"] = torch.concatenate(response_lengths) + all_rollouts["prompt_lengths"] = torch.concatenate(prompt_lengths) + all_rollouts["rewards"] = torch.concatenate(rewards) + + rollout_batch = select_topk(all_rollouts, self.top_n_rollouts) + + rollout_batches.append(rollout_batch) + full_batches.append(all_rollouts) + return rollout_batches, cpu_dict(self.compute_global_rollout_metrics(full_batches)) + + else: + for _, inference_batch in zip(range(num_microbatches), dataloader_iter): + rollout_batch = self.model.infer(inference_batch) + + rewards = self.rm.infer_rm(rollout_batch).result().detach() + rollout_batch["rewards"] = rewards + rollout_batches.append(rollout_batch) + + return rollout_batches, cpu_dict(self.compute_global_rollout_metrics(rollout_batches)) + + def compute_global_rollout_metrics(self, rollout_batches): + metrics = defaultdict(lambda: 0) + table = {} + + num_samples = 0 + for i, rollout_batch in enumerate(rollout_batches): + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + + # table logging + if i == 0: + reward = rewards[0] + prompt_length = prompt_lengths[0] + response_length = response_lengths[0] + response_token = response_tokens[0] + + table["reward"] = reward.item() + table["prompt"] = self.model.tokenizer.ids_to_text(response_token[:prompt_length].tolist()) + table["response"] = self.model.tokenizer.ids_to_text( + response_token[prompt_length:response_length].tolist() + ) + + metrics["response_lengths"] += (response_lengths - prompt_lengths).sum() + metrics["prompt_lengths"] += prompt_lengths.sum() + metrics["rewards"] += rewards.sum() + num_samples += prompt_lengths.size(0) + + tensor_to_accumulate = torch.tensor( + [metrics["response_lengths"], metrics["prompt_lengths"], metrics["rewards"], num_samples], + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + torch.distributed.all_reduce(tensor_to_accumulate, group=parallel_state.get_data_parallel_group()) + + ( + global_response_lengths, + global_prompt_lengths, + global_rewards, + global_num_samples, + ) = tensor_to_accumulate.tolist() + + metrics = { + "table": table, + "global_response_lengths_mean": global_response_lengths / global_num_samples, + "global_prompt_lengths": global_prompt_lengths / global_num_samples, + "global_rewards": global_rewards / global_num_samples, + } + return metrics + + @torch.no_grad() + def run_validation(self): + self.model.prepare_for_inference() + + num_val_micro_batches = compute_num_rollout_microbatches(self.val_dataloader) + val_dataloader = iter(self.val_dataloader) + + _, rollout_metrics = self._run_inference(val_dataloader, num_val_micro_batches, is_validation=True) + self.model.finish_inference() + return rollout_metrics + + @torch.no_grad() + def generate_rollouts(self, dataloader_iter, num_microbatches): + + self.model.prepare_for_inference() + rollout_batches, rollout_metrics = self._run_inference(dataloader_iter, num_microbatches, is_validation=False) + + rs_rollout_data, rs_rollout_metrics = map(cpu_dict, self.generate_rs_data(rollout_batches)) + + self.model.finish_inference() + + self.consumed_samples += ( + rs_rollout_data["response_tokens"].size(0) * parallel_state.get_data_parallel_world_size() + ) + return rs_rollout_data, rollout_metrics | rs_rollout_metrics | {"consumed_samples": self.consumed_samples} + + def run_training(self, dataloader_iter): + self.model.prepare_for_training() + + for batch in dataloader_iter: + """ + batch has [mask, advantages, prev_logprobs, response_tokens, rewards, values, returns] + mask: [mbs, seq_len-1] + response_tokens: [mbs, seq_len] + + """ + self.timer.start("train_step_time") + self.optimizer.zero_grad() + + self.model.prepare_for_training_step() + loss_mean, metrics = self.model.get_loss_and_metrics(batch=batch, forward_only=False) + self.model.finish_training_step() + + grad_norm = clip_gradients(self.model, self.cfg.gradient_clip_val) + grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm + lr = self.optimizer.param_groups[0]["lr"] + + self.optimizer.step() + self.scheduler.step() + + if grad_norm is not None: + metrics["grad_norm"] = grad_norm + + metrics.update({"lr": lr, "loss": loss_mean, "optim_step": self.rs_optimization_step}) + + self.timer.stop("train_step_time") + metrics["train_step_time"] = self.timer.get("train_step_time") + + self.logger.log_metrics( + metrics, step=self.step, prefix="train_optim/", + ) + + self.rs_optimization_step += 1 + + self.model.finish_training() + + # zero grad again incase it frees up grad mem + self.optimizer.zero_grad() + return loss_mean, metrics + + def fit(self): + + epoch_iter = range(self.epoch, self.cfg.max_epochs) + if len(epoch_iter) <= 0: + # epoch done + return + + for _ in epoch_iter: + num_steps_in_epoch = min( + self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch + ) + loop_iter = range(num_steps_in_epoch) + + if not loop_iter: + return # training ended + + dataloader_iter = iter(self.train_dataloader) + + global_pbar = tqdm(loop_iter, initial=self.step, total=self.max_steps, leave=True, desc="RS Global Step") + + num_rollout_micro_batches = compute_num_rollout_microbatches(self.train_dataloader) + dp_size = parallel_state.get_data_parallel_world_size() + + num_to_load_on_each_dp = divide(self.cfg.model_gbs, dp_size) + + self.run_timer.start_time() + for _ in global_pbar: + step_metrics = {} + timing_metrics = {} + + self.timer.start("rollout_time") + rs_rollout_data, metrics = self.generate_rollouts(dataloader_iter, num_rollout_micro_batches) + + self.timer.stop("rollout_time") + timing_metrics["rollout_time"] = self.timer.get("rollout_time") + + # logging + table_metrics = metrics.pop("table") + self.train_df.loc[len(self.train_df)] = [ + self.step, + table_metrics["prompt"], + table_metrics["response"], + table_metrics["reward"], + ] + metrics["epoch"] = self.epoch + 1 + self.logger.log_metrics( + metrics, step=self.step, prefix="train_rollouts/", + ) + self.logger.log_table( + key="table/train_rollouts", dataframe=self.train_df, step=self.step, + ) + + rollout_size = rs_rollout_data["response_tokens"].size(0) + rollout_dataloader_iter = get_iterator_k_split( # Does not have prompt info + rs_rollout_data, divide(rollout_size, num_to_load_on_each_dp) + ) + # start training + clear_memory() + self.timer.start("train_time") + self.run_training(rollout_dataloader_iter) + + self.timer.stop("train_time") + timing_metrics["train_time"] = self.timer.get("train_time") + + self.step += 1 + + run_time_exceeded = self.run_timer.is_finished() + run_val, save_model, is_train_end = check_progress( + self.step, + self.max_steps, + self.cfg.val_check_interval, + self.cfg.save_interval, + 1.0, # TODO:(geshen): allow for limit val batches + run_time_exceeded=run_time_exceeded, + ) + + if run_val: + self.timer.start("validation_time") + val_metrics = self.run_validation() + self.timer.stop("validation_time") + timing_metrics["validation_time"] = self.timer.get("validation_time") + + val_table_metrics = val_metrics.pop("table") + + self.val_df.loc[len(self.val_df)] = [ + self.step, + val_table_metrics["prompt"], + val_table_metrics["response"], + val_table_metrics["reward"], + ] + self.logger.log_metrics(val_metrics, step=self.step, prefix="val_rollouts/") + self.logger.log_table("table/val_rollouts", dataframe=self.val_df, step=self.step) + + step_metrics.update({f"val_{k}": v for k, v in val_metrics.items()}) + + self.logger.log_metrics(timing_metrics, step=self.step, prefix="timers/") + + step_metrics.update(timing_metrics) + step_metrics.update({f"train_{k}": v for k, v in metrics.items()}) + global_pbar.set_postfix(step_metrics) + + if save_model: + step_metrics = {k: torch.as_tensor(v) for k, v in step_metrics.items()} + self.save(step_metrics, is_train_end=is_train_end) + + if run_time_exceeded: + logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run") + return + + self.logger.finalize() + + def state_dict(self): + return { + "step": self.step, + "consumed_samples": self.consumed_samples, + "epoch": self.epoch, + "rs_optimization_step": self.rs_optimization_step, + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.consumed_samples = state_dict["consumed_samples"] + self.rs_optimization_step = state_dict["rs_optimization_step"] + + loaded_values = [self.step, self.consumed_samples, self.rs_optimization_step] + + # make sure everyone loaded the same checkpoint as rank 0 + to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(to_broadcast, 0) + + assert loaded_values == to_broadcast.tolist() + # restore max steps we need to run for + self.set_max_steps() + + def save(self, extra_candidates=None, is_train_end=False): + self.model.prepare_for_training() + # load back in the adam states if needed + torch.cuda.synchronize() + torch.distributed.barrier() + + if extra_candidates is None: + extra_candidates = {} + + monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()} + monitor_candidates.update(extra_candidates) + + self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end) + + self.model.finish_training() + + def set_max_steps(self): + self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs + + if (max_steps := self.cfg.get("max_steps", -1)) >= 0: + self.max_steps = min(self.max_steps, max_steps) + + @property + def epoch(self): + return self.step // self.num_steps_per_epoch diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py new file mode 100644 index 000000000..49dd3846c --- /dev/null +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py @@ -0,0 +1,261 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch +from megatron.core import parallel_state +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.utils import divide +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface +from nemo_aligner.utils.distributed import broadcast_2d_tensor_within_pp, from_parallel_logits_to_logprobs +from nemo_aligner.utils.text_generation_utils import TrackLengthGPTModelTextGenerationStrategy +from nemo_aligner.utils.train_utils import ( + grad_reductions, + prepare_for_training_step, + set_eval, + set_sync_funcs, + set_train, +) +from nemo_aligner.utils.utils import ( + adapter_control, + configure_batch_sizes, + cpu_weight_swap, + masked_mean, + offload_distributed_adam, +) + + +class MegatronGPTRSModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface): + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + self.automatic_optimization = False + + self.distributed_adam_offload_manager = None + + # length parameters for generation + self._length_params = OmegaConf.to_container(self.cfg.rs.length_params, resolve=True) + # sampling parameters for generation + self._sampling_params = OmegaConf.to_container(self.cfg.rs.sampling_params, resolve=True) + + self.to_offload_adam_states = self.cfg.rs.offload_adam_states + self.forward_micro_batch_size = self.cfg.rs.forward_micro_batch_size + + def get_actor_forward_output_and_loss_func(self): + def fwd_output_and_loss_func(data_iterator, model): + + batch = next(data_iterator) + response_tokens = batch["response_tokens"] + mask = batch["mask"] + + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + data=response_tokens, + eod_token=self.tokenizer.eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + ) + + batch = { + "tokens": response_tokens, + "attention_mask": attention_mask, + "position_ids": position_ids, + "mask": mask, + } + + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("tokens", "position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("tokens", "advantages", "mask", "prev_log_probs")) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + parallel_logits = model(batch["tokens"], batch["position_ids"], batch["attention_mask"], labels=None,) + + # TODO: This loss depends on the mbs, which is can lead to inconsistencies. See https://github.com/NVIDIA/NeMo/issues/8343. + def loss_func(parallel_logits): + mask = batch["mask"] + tokens = batch["tokens"] + + curr_log_probs = from_parallel_logits_to_logprobs(vocab_parallel_logits=parallel_logits, target=tokens) + loss = -1 * masked_mean(curr_log_probs, mask) # Loss is mean logits on response tokens + + reduced_actor_loss = average_losses_across_data_parallel_group([loss]) + return ( + loss, + {"loss": reduced_actor_loss,}, + ) + + return parallel_logits, loss_func + + return fwd_output_and_loss_func + + def prepare_for_training(self): + configure_batch_sizes( + mbs=self.cfg.micro_batch_size, + gbs=self.cfg.global_batch_size, + dp=parallel_state.get_data_parallel_world_size(), + ) + self.onload_adam_states() + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def get_loss_and_metrics(self, batch, forward_only): + sequence_length = batch["response_tokens"].size(-1) + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_actor_forward_output_and_loss_func(), + data_iterator=data_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=sequence_length, + micro_batch_size=self.cfg.micro_batch_size, + ) + + metrics = {} + + for key in ["loss"]: + if losses_reduced_per_micro_batch: + metric_mean = torch.stack( + [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + ).mean() + else: + metric_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + torch.distributed.broadcast(metric_mean, get_last_rank()) + + metrics[key] = metric_mean.cpu().item() + + return metrics["loss"], metrics + + def finish_training_step(self): + grad_reductions(self) + + def finish_training(self): + """no need to offload adam states here + """ + + def prepare_for_inference(self): + """normally we would configure the micro batch calculator here + but the nemo generation already does the configuration""" + self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() + set_eval(self) + self.offload_adam_states() + + @torch.no_grad() + def infer(self, inference_batch): + prompt_tokens = inference_batch["text"].cuda(non_blocking=True) + prompt_lengths = inference_batch["length"].cuda(non_blocking=True) + inputs = (prompt_tokens, prompt_lengths) + + strategy = TrackLengthGPTModelTextGenerationStrategy( + model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"] + ) + actor_output = self.generate( + inputs=inputs, length_params=self._length_params, sampling_params=self._sampling_params, strategy=strategy + ) + + response_lengths = strategy.get_lengths() + max_response_length = response_lengths.max().item() + + response_tokens = torch.cuda.LongTensor(actor_output["token_ids"]) + + # Sanity check to validate response length. + if max_response_length != response_tokens.size(1): + # This may actually happen because NeMo does not always stop generation after `max_length` in batch mode + # => `response_tokens` may contain up to `max_length + max_context_length` tokens. + # TODO once NeMo fixes this issue we should be able to always raise an exception when the check above fails, + # and remove the `if` below. + if ( + max_response_length >= response_tokens.size(1) + or response_tokens.size(1) != prompt_lengths.max().item() + self._length_params["max_length"] + ): + raise AssertionError( + f"max response length ({max_response_length}) does not match the size of " + f"`response_tokens` ({response_tokens.size(1)})" + ) + + rollout_batch = { + "response_tokens": response_tokens, + "response_lengths": response_lengths, + "prompt_lengths": prompt_lengths, + } + + # return in GPU, trainer needs to move to cpu + return rollout_batch + + def finish_inference(self): + # training will onload the adam states, no need to onload it here + self._restore_activation_checkpointing_args() + self._restore_sequence_parallelism_args() + set_train(self) + + def offload_adam_states(self): + if self.distributed_adam_offload_manager is None: + + self.distributed_adam_offload_manager = ( + offload_distributed_adam(self._optimizer.state_dict(state_dict_format=1, gather_on_root=False)) + if self.to_offload_adam_states and self.with_distributed_adam + else nullcontext() + ) + + # offload onto cpu + self.distributed_adam_offload_manager.__enter__() + + def onload_adam_states(self): + if self.distributed_adam_offload_manager is not None: + # load back onto GPU + self.distributed_adam_offload_manager.__exit__(None, None, None) + + self.distributed_adam_offload_manager = None + + def get_ltor_masks_and_position_ids(self, tokens): + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=self.tokenizer.eos_id, + reset_position_ids=self.cfg.data.get("reset_position_ids", False), + reset_attention_mask=self.cfg.data.get("reset_attention_mask", False), + eod_mask_loss=False, # since we ignore the loss mask here + ) + attention_mask = attention_mask.expand(tokens.size(0), -1, -1, -1) + + return attention_mask, loss_mask, position_ids diff --git a/nemo_aligner/models/nlp/gpt/reward_critic_clients.py b/nemo_aligner/models/nlp/gpt/reward_critic_clients.py index 5aa4f6afe..94990ba7e 100644 --- a/nemo_aligner/models/nlp/gpt/reward_critic_clients.py +++ b/nemo_aligner/models/nlp/gpt/reward_critic_clients.py @@ -75,6 +75,16 @@ def result(self): return rewards.flatten(), values +class RMFutureResult(FutureResult): + def __init__(self, rm_future): + self.rm_future = rm_future + + def result(self): + rewards = get_future_result(self.rm_future, "rewards") + + return rewards + + class SaveFuture(FutureResult): def __init__(self, pytriton_save_future): self.pytriton_save_future = pytriton_save_future @@ -170,3 +180,40 @@ def save(self): ) return SaveFuture(save_future) + + +@dataclass +class RemoteGPTRMClient: + cfg: DictConfig + + def __post_init__(self): + cfg = self.cfg + + server_dict = {cfg.reward_model.name: (cfg.reward_model.ip, cfg.reward_model.port)} + + self.communicator = HTTPCommunicator.create_http_communicator_from_dict(server_dict) + self.communicator.print_server_dict() + self.pad_to_length = self.cfg.pad_to_length + + def infer_rm(self, rollout_batch): + response_tokens = rollout_batch["response_tokens"].cpu() + og_seq_length = response_tokens.size(-1) + + if self.pad_to_length is not None: + assert ( + og_seq_length <= self.pad_to_length + ), f"original shape before padding {og_seq_length} is higher than {self.pad_to_length}" + response_tokens = torch.nn.functional.pad( + response_tokens, (0, self.pad_to_length - response_tokens.size(-1)), value=0 + ) + + send_data = { + "tokens": response_tokens.numpy(), + "sequence_lengths": rollout_batch["response_lengths"].unsqueeze(1).cpu().numpy(), + } + + rm_future = run_if_model_parallel_src( + self.communicator.send_data_to_server, server_name=self.cfg.reward_model.name, data=send_data + ) + + return RMFutureResult(rm_future) diff --git a/nemo_aligner/utils/distributed.py b/nemo_aligner/utils/distributed.py index 923ba8ed3..3b93ea8a6 100644 --- a/nemo_aligner/utils/distributed.py +++ b/nemo_aligner/utils/distributed.py @@ -386,6 +386,16 @@ def is_finished(self): return is_finished_tensor.item() +def pad_list(tensor_list, pad_value): + """ + Pad list of tensors to max seq len + """ + max_N = max(tensor.size(1) for tensor in tensor_list) + padded_tensors = [torch.nn.functional.pad(t, (0, max_N - t.size(1))) for t in tensor_list] + + return padded_tensors + + def run_distributed_inference(inputs=None, infer_fn=None): tokens, lengths = None, None dp_rank = parallel_state.get_data_parallel_rank() diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py index cbe074e5c..1d1f5cf67 100644 --- a/nemo_aligner/utils/ppo_utils.py +++ b/nemo_aligner/utils/ppo_utils.py @@ -14,7 +14,10 @@ """helper functions for PPO training""" +import operator + import torch + from nemo_aligner.utils.utils import masked_mean @@ -89,3 +92,23 @@ def create_mask(values, prompt_lengths, response_lengths): # as it is because we want to include one EOS token. mask[i, prompt_lengths[i] - 1 : response_lengths[i] - 1] = 1.0 return mask + + +def select_topk(batch, num_select=1): + """ + Function to select the topk responses for each unique prompt in a batch. + Please note that this function samples the same top response for each identical prompt. + Duplicate prompts in the same batch may cause unexpected behavior. + """ + unique_prompts = torch.unique(batch["prompt_tokens"], dim=0) + selected_idx = [] + + for i in range(len(unique_prompts)): + is_matching_prompt = (batch["prompt_tokens"] == unique_prompts[i]).all(1) + prompt_idx = torch.arange(len(batch["prompt_tokens"]))[is_matching_prompt] + sorted_idx = zip(prompt_idx, batch["rewards"][is_matching_prompt]) + sorted_idx = sorted(sorted_idx, key=operator.itemgetter(1)) + selected_idx += [x[0].item() for x in sorted_idx[-1 * num_select :]] + + selected_batch = {k: batch[k][selected_idx] for k in batch.keys()} + return selected_batch