Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Self-Rewarding Algorithm with TRT Support #321

Open
wants to merge 271 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
271 commits
Select commit Hold shift + click to select a range
e7ea083
add logging
gshennvm Mar 21, 2024
6cd8923
fix critic client
gshennvm Mar 21, 2024
887874d
fix
gshennvm Mar 21, 2024
6e2ca2d
fix
gshennvm Mar 21, 2024
bacd786
fix
gshennvm Mar 21, 2024
17b04ca
fix
gshennvm Mar 21, 2024
9fee58b
fix typo
gshennvm Mar 21, 2024
61ba204
better train timing
gshennvm Mar 21, 2024
c9dad2a
remove prints
gshennvm Mar 21, 2024
e5f68e2
with timing
gshennvm Mar 21, 2024
74791ae
delete unused func
gshennvm Mar 21, 2024
e40ebd6
add critic logging
gshennvm Mar 21, 2024
72ba6c6
add
gshennvm Apr 1, 2024
bfb61e4
cleanup
gshennvm Apr 6, 2024
21206c5
update
gshennvm Apr 6, 2024
148acf4
fix
gshennvm Apr 6, 2024
d7c9990
fix bug
gshennvm Apr 6, 2024
537d6e5
fix bug
gshennvm Apr 6, 2024
47400ba
test
gshennvm Apr 6, 2024
d7b2b23
fix bug
gshennvm Apr 6, 2024
ce76226
fix
gshennvm Apr 7, 2024
8edf534
add
gshennvm Apr 7, 2024
6379a2e
fix
gshennvm Apr 8, 2024
eadae31
fix again
gshennvm Apr 8, 2024
e2b97d9
fix
gshennvm Apr 8, 2024
d9bdf7c
fix mean
gshennvm Apr 8, 2024
1c7d215
fix
gshennvm Apr 8, 2024
3638301
add debug
gshennvm Apr 8, 2024
4cca85f
fix
gshennvm Apr 8, 2024
1b19bdd
add data iter for VP
gshennvm Apr 8, 2024
3f045ae
move
gshennvm Apr 8, 2024
3c9fe3d
fixing
gshennvm Apr 8, 2024
f36f394
add
gshennvm Apr 8, 2024
5211bc2
chunking needs to be moved out
gshennvm Apr 8, 2024
0f59edf
fix
gshennvm Apr 8, 2024
c3fe2f7
fix metrics
gshennvm Apr 9, 2024
5d3e07d
fix dtype
gshennvm Apr 9, 2024
15887e5
merge
gshennvm Apr 9, 2024
2ad76ba
fix
gshennvm Apr 9, 2024
9d9a6b6
make the global id management into a class
gshennvm Apr 10, 2024
d6fb55d
fix
gshennvm Apr 11, 2024
0983164
trtllm0.9 changes (#149)
jiemingz Apr 17, 2024
fe765cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
dfac922
trtllm patch file
jiemingz Apr 17, 2024
c159aa3
dockerfile
jiemingz Apr 17, 2024
d81caef
fix build
gshennvm Apr 18, 2024
92c19f6
fix bug
gshennvm Apr 18, 2024
7088f54
add groupnorm build
jiemingz Apr 19, 2024
472a56c
upgrade to latest te and mcore
gshennvm Apr 18, 2024
032bf35
Merge remote-tracking branch 'origin/dev' into aligner_trt_build
gshennvm Apr 19, 2024
d5f55f5
fix
gshennvm Apr 19, 2024
c7cdca1
specify max token
Apr 20, 2024
04d02c8
fix
gshennvm Apr 20, 2024
56ccacf
Merge remote-tracking branch 'origin/geshen/main_trt' into aligner_tr…
gshennvm Apr 20, 2024
d23865f
fix critic checkpoint loading
gshennvm Apr 22, 2024
3c21c81
add assert
gshennvm Apr 22, 2024
2c99dcb
fix bug
gshennvm Apr 22, 2024
9e8526d
fix
gshennvm Apr 22, 2024
c1daeb9
fix
gshennvm Apr 23, 2024
e16c357
update dockerfile
gshennvm Apr 23, 2024
410eaf5
update to 24.03.01 deps
gshennvm Apr 23, 2024
e405432
fix
gshennvm Apr 24, 2024
07cfa67
update dockerfile
gshennvm Apr 24, 2024
b2dfee0
add dockerfileg
gshennvm Apr 26, 2024
63cd6b3
fix trtllm patch
Apr 29, 2024
6901348
clamp output with warning
Apr 29, 2024
74a0bb1
fix
gshennvm Apr 29, 2024
b6a05fd
remove debug statements
gshennvm Apr 30, 2024
db2701b
Merge remote-tracking branch 'origin/main' into aligner_trt_build
gshennvm Apr 30, 2024
8dd5c59
add debug info
gshennvm May 6, 2024
b5d6f88
bump pytrition version
gshennvm May 6, 2024
5464827
add critic speed
gshennvm May 7, 2024
00e4298
critic speedup
gshennvm May 7, 2024
fe6864b
Merge remote-tracking branch 'origin/geshen/critic_refactor' into ges…
gshennvm May 7, 2024
80579ec
fix
gshennvm May 13, 2024
f81f55a
add pad sequence length
gshennvm May 14, 2024
4a034f4
dockerfile
gshennvm May 15, 2024
66b5a54
higher stability
gshennvm May 16, 2024
7841381
Merge remote-tracking branch 'origin/main' into geshen/debug_critic
gshennvm May 16, 2024
1779c51
add
gshennvm May 16, 2024
e357ef9
add hack for ckpt
gshennvm May 24, 2024
6b606e8
fix conf
gshennvm May 24, 2024
a669837
no import
gshennvm May 24, 2024
ada5f45
add
gshennvm May 24, 2024
393acc6
fix
gshennvm May 25, 2024
b6a4d59
run through
gshennvm May 25, 2024
977e6e7
fix
gshennvm May 25, 2024
621718d
adaptive
gshennvm May 25, 2024
6109b8b
output tensor
gshennvm May 25, 2024
866c22b
add logging
gshennvm May 26, 2024
02aa2b8
fix for llama3
gshennvm May 28, 2024
e6f27c5
disable last checkpoint
gshennvm May 31, 2024
c689d2a
fix padding bug
gshennvm Jun 1, 2024
cd4aaa5
add critic warmup
gshennvm Jun 1, 2024
993e358
Revert "add"
gshennvm Jun 8, 2024
28fcaf3
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jun 8, 2024
ef347e5
fix module missing bug
gshennvm Jun 11, 2024
752d0bd
Ensure critic server does not squeeze out a singleton batch dim (#199)
terrykong Jun 11, 2024
78e6536
Merge branch 'geshen/llama3_rlhf' into geshen/trt_llm_to_main
gshennvm Jun 12, 2024
4de3eeb
Merge branch 'geshen/trt_llm_to_main' of github.com:NVIDIA/NeMo-Align…
gshennvm Jun 12, 2024
8a39881
TRTLLM PP wrong format WAR
jiemingz May 17, 2024
666e969
docker file branch
gshennvm Jun 12, 2024
3bec1bc
fix config
gshennvm Jun 12, 2024
3e7ca5f
remove prints
gshennvm Jun 12, 2024
12a0aae
remove print
gshennvm Jun 12, 2024
3956b6d
remove unneeded statement
gshennvm Jun 12, 2024
e090663
no save topk
gshennvm Jun 12, 2024
af83947
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jun 24, 2024
cc03b76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
605bda1
critic speedup
gshennvm Jun 24, 2024
b3dedfd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
9fb90ff
fix
gshennvm Jun 24, 2024
bf62bcc
Merge branch 'geshen/critic_speedup' of github.com:NVIDIA/NeMo-Aligne…
gshennvm Jun 24, 2024
aea50ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
0a9416e
fix
gshennvm Jun 24, 2024
41ffeb5
Merge branch 'geshen/critic_speedup' of github.com:NVIDIA/NeMo-Aligne…
gshennvm Jun 24, 2024
e3d85bf
add parallel_state
gshennvm Jun 24, 2024
1225041
fix
gshennvm Jun 24, 2024
5315fc5
rename
gshennvm Jun 24, 2024
9c6141b
fix
gshennvm Jun 24, 2024
2cef93e
fix
gshennvm Jun 24, 2024
f5bd8c5
pull changes from degert/spin-trt-beta (#220)
gshennvm Jun 25, 2024
cbd7095
add text input support
gshennvm Jun 25, 2024
c72e4c1
clean up location of tokenization
gshennvm Jun 25, 2024
220cf17
remove useless imports
gshennvm Jun 25, 2024
1ed3c46
refactor rm server
gshennvm Jun 26, 2024
aa4249d
remove run rm file
gshennvm Jun 26, 2024
22615b7
remove preferred batch size logic
gshennvm Jun 26, 2024
95a9507
add comment
gshennvm Jun 26, 2024
0fdb124
allow users to specify their own preferred batch size
gshennvm Jun 26, 2024
031287c
add comments and add changelog
gshennvm Jun 26, 2024
fa37930
update changelog
gshennvm Jun 26, 2024
c83a98d
remove old reward model callable
gshennvm Jun 26, 2024
a9f722a
inference should be done with collect loss data =True
gshennvm Jun 26, 2024
15460d2
Merge branch 'geshen/critic_speedup' into geshen/trt_llm_to_main
gshennvm Jun 27, 2024
f6c09ea
fix issues with merge
gshennvm Jun 27, 2024
1d85152
cleanup configs
gshennvm Jun 27, 2024
b3078c5
add strip to sequence length
gshennvm Jun 27, 2024
1eb0823
Merge remote-tracking branch 'origin/geshen/critic_speedup' into gesh…
gshennvm Jun 27, 2024
6203e90
change
gshennvm Jun 27, 2024
e51c45f
fix
gshennvm Jun 27, 2024
1b220f1
clean actor
gshennvm Jun 28, 2024
0815b56
backwards compatibility in actor
gshennvm Jun 28, 2024
8d75cbf
Apply suggestions from code review
gshennvm Jun 28, 2024
ab4c549
modify changelog
gshennvm Jun 28, 2024
82fb3a1
fixup! modify changelog
gshennvm Jun 28, 2024
12f85d2
add comments to ppo_critic config
gshennvm Jun 28, 2024
f90f4d6
add note on breaking change in inference rm
gshennvm Jun 28, 2024
527557d
change inference mbs to 4
gshennvm Jun 28, 2024
fe9d288
add comments for inference rm config
gshennvm Jun 28, 2024
f3124d3
revert gbs flag back to previous in ppo critic
gshennvm Jun 28, 2024
efadcae
delete unused variable
gshennvm Jun 28, 2024
35a2895
Update nemo_aligner/algorithms/critic_server_trainer.py
gshennvm Jun 28, 2024
4487932
remove add_eos arg, and update attribute annotate script
gshennvm Jun 29, 2024
7e7f27b
Merge branch 'geshen/critic_speedup' of github.com:NVIDIA/NeMo-Aligne…
gshennvm Jun 29, 2024
e9c7b39
no mutation on inputs when processing them for inference
gshennvm Jun 29, 2024
c6f6da4
fix bug when padding
gshennvm Jun 29, 2024
ebb69f4
add comment for forward_micro_batch_size in training_rm.yaml
gshennvm Jun 29, 2024
2775e81
change non_blocking to use sync
gshennvm Jun 29, 2024
fe0399f
Merge branch 'geshen/critic_speedup' into geshen/trt_llm_to_main
gshennvm Jun 30, 2024
ffa253f
nemo export api changes
jiemingz Jul 1, 2024
7ca9e34
upgrade to newer nemo export
gshennvm Jul 1, 2024
8181168
fix imports
gshennvm Jul 1, 2024
4d0853d
Communicator hang fix in the actor loop (#200)
terrykong Jul 1, 2024
ec548b8
add nemo guard for when things don't stop properly
gshennvm Jul 3, 2024
ce7a07f
cleanup communicator clean
gshennvm Jul 3, 2024
bb2fc48
fix
gshennvm Jul 3, 2024
606f690
critic speedup
gshennvm Jun 24, 2024
f48dc29
middle of PP should be broadcasted as well
gshennvm Jul 11, 2024
708bc24
update with critic changes
gshennvm Jul 11, 2024
48ad685
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jul 11, 2024
d053475
general cleanup
gshennvm Jul 11, 2024
0b4a92d
add checker for trt
gshennvm Jul 11, 2024
b72a5ec
remove comments
gshennvm Jul 11, 2024
984acaa
fix
gshennvm Jul 12, 2024
c11e1d7
fix
gshennvm Jul 12, 2024
14a9926
another fix
gshennvm Jul 12, 2024
7c2fc3e
add typing
gshennvm Jul 12, 2024
fe02867
cleanup
gshennvm Jul 12, 2024
02ad2fa
ppo trainer should use stop and get time
gshennvm Jul 12, 2024
24c53be
add some comments
gshennvm Jul 12, 2024
8a25e5e
critic warmup should have good default
gshennvm Jul 12, 2024
24f138a
added ppo in changelog
gshennvm Jul 12, 2024
9c72c53
add comments
gshennvm Jul 12, 2024
5ed9cd8
Avoids crash in PPOTrainer if using adafactor w/o learning rate (#234)
terrykong Jul 12, 2024
8b6627a
rename
gshennvm Jul 12, 2024
56032c8
Merge branch 'geshen/trt_llm_to_main' of github.com:NVIDIA/NeMo-Align…
gshennvm Jul 12, 2024
1e17f8b
Raise exceptions if using trtllm and use_Greedy in sampling params is…
terrykong Jul 12, 2024
e0a94d0
fix bugs
gshennvm Jul 12, 2024
280ad36
Initial commit of TRT version of self-rewarding
trias702 Jul 13, 2024
3af712d
Fixed bug again for limit_train_batches
trias702 Jul 13, 2024
835b3b3
cleanup pad id handling when PP > 1
gshennvm Jul 13, 2024
2b95331
fix issue with PP > 1 check
gshennvm Jul 14, 2024
261269a
add is_end logic
gshennvm Jul 14, 2024
83ba660
add is end logic
gshennvm Jul 14, 2024
f3912e7
add is end logic
gshennvm Jul 14, 2024
09d2783
fix
gshennvm Jul 14, 2024
5105ed9
fix
gshennvm Jul 14, 2024
a2bf8a0
fix another bug
gshennvm Jul 15, 2024
d9d45d6
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jul 15, 2024
f00d09e
update changelog
gshennvm Jul 15, 2024
830e599
update dockerfile
gshennvm Jul 15, 2024
09c357c
Update the hash of the conversion script to include TE fix for mcore …
terrykong Jul 16, 2024
c41cc08
update docs
gshennvm Jul 16, 2024
ab8c97b
Lots of bugfixes
trias702 Jul 17, 2024
92fec51
Exposed repetition_penalty to TRT and added in intra-mb randomness
trias702 Jul 17, 2024
7f7d4f9
removed assert statements in trt_llm
trias702 Jul 18, 2024
395cf04
add repetition penalty
gshennvm Jul 18, 2024
3a8da14
random seed and better clamping
gshennvm Jul 18, 2024
4174248
add memory logging
gshennvm Jul 18, 2024
36d8ab4
make it clear seed can be None
gshennvm Jul 18, 2024
82e8793
fix seed args
gshennvm Jul 18, 2024
7506be0
clear memory
gshennvm Jul 19, 2024
c8b88c3
Fixed the PP bug I was getting
trias702 Jul 19, 2024
1d1f051
add more clear memory
gshennvm Jul 20, 2024
aae1cd3
remove assert
gshennvm Jul 21, 2024
b4afcf6
RC-1 status reached
trias702 Jul 25, 2024
dab8c61
fix unloading
gshennvm Jul 25, 2024
5651e20
Added new generation algo
trias702 Jul 27, 2024
ff3bc1c
optimisations for generation code
trias702 Jul 30, 2024
aadc662
merged and made more bugfixes
trias702 Aug 1, 2024
314c217
Fixed final trt decoding bug for pp > 1 and llama3 tokenizer
trias702 Aug 3, 2024
2fdb5b0
RC-2 status reached
trias702 Aug 5, 2024
31f4ba3
Integrated incorrect ref log probs fixes from PR 228
trias702 Aug 5, 2024
7dd511c
Added length control methodology to self rewarding
trias702 Aug 9, 2024
11d22e8
Fixed bug with length control
trias702 Aug 12, 2024
de8dda8
Fixed LocalNonpersistentObject error and added meta judge logic
trias702 Aug 17, 2024
4f81524
Attempts to fix the oscillation issue
trias702 Aug 27, 2024
723f55f
Stable fix for meta judge oscillation
trias702 Sep 5, 2024
acd4d07
Moved the templates to the conf file instead
trias702 Sep 24, 2024
a249a44
Merged to latest aligner main
trias702 Sep 24, 2024
ce687ed
Merge remote-tracking branch 'origin' into degert/self-rewarding-trt
trias702 Sep 25, 2024
005702b
Fixes to yaml tab issues
trias702 Sep 25, 2024
a35e359
Far enough along that I can file the PR
trias702 Sep 26, 2024
fc1def7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
eab62af
Added self-rewarding rst doc
trias702 Sep 26, 2024
8805db5
Merge branch 'degert/self-rewarding-trt' of https://github.com/NVIDIA…
trias702 Sep 26, 2024
fb61b86
Fixed bad merge in utils.py
trias702 Sep 26, 2024
a7817b3
Removed trt patch file
trias702 Sep 26, 2024
71029d9
Fixed ordering of clamp bug
trias702 Oct 1, 2024
9051695
Merge remote-tracking branch 'origin' into degert/self-rewarding-trt
trias702 Oct 1, 2024
35f8ee4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
ab2d3ce
Merge branch 'degert/self-rewarding-trt' of https://github.com/NVIDIA…
trias702 Oct 1, 2024
294afcd
Merge remote-tracking branch 'origin' into degert/self-rewarding-trt
trias702 Oct 24, 2024
c465f30
Added CI test for SPIN
trias702 Oct 26, 2024
2092084
Added CI tests for self-rewarding
trias702 Oct 29, 2024
99c7bcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
4e435f7
Added CI tests for generation
trias702 Oct 29, 2024
3d429fb
Merged in v13 TRT changes
trias702 Nov 1, 2024
b460a14
Removed max_input_tokens from TRT algos
trias702 Nov 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ jobs:
test_case:
- ppo-llama3-pp2-reshard
- dpo-llama3
- spin-llama3
- self_rewarding-llama3
- generation-llama3

with:
RUNNER: self-hosted-azure
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### New Features and Optimizations
- Implement Kahneman-Tversky Optimization (KTO).
- Sequence packing is now supported when running SFT with SFTChatDataset.
- Implemented the self rewarding and meta self rewarding algorithms.

### Breaking Changes

Expand Down
214 changes: 214 additions & 0 deletions docs/user-guide/self_rewarding.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
.. include:: /content/nemo.rsts

Model Alignment by Self-Rewarding Language Models
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

Original paper: https://arxiv.org/abs/2401.10020
Meta Self-Rewarding paper: https://arxiv.org/abs/2407.19594

The NeMo framework supports efficient model alignment via the NeMo Aligner codebase.

All algorithms in NeMo Aligner will work with any GPT based model that is from mcore(i.e in the config it has ``mcore_gpt=True``). For the purposes of this tutorial, we will go through the entire Self-Rewarding pipeline using the newly released `2B GPT model with 4096 sequence length <https://huggingface.co/nvidia/GPT-2B-001>`__. This same tutorial also works for other GPT models(such as LLaMa2) of any size.

Obtaining a pretrained model
############################
To start, we must first get a pretrained model to align. There are 2 models we recommend to get started. The rest of the tutorial will work with either model, but for demonstration purposes we will use the smaller 2B model.

.. tab-set::

.. tab-item:: 2B GPT
:sync: key1

#. Get the 2B checkpoint via ``wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo``
#. Extract the NeMo File to a folder with ``mkdir model_checkpoint && tar -xvf GPT-2B-001_bf16_tp1.nemo -C model_checkpoint``
#. And then run the script to convert from old NeMo checkpoint to Megatron-Core checkpoint. The script is located `here <https://github.com/NVIDIA/NeMo/blob/86b198ff93438d454f9c7f3550bcfb7d4e59feab/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py>`__.
.. code-block:: bash

python convert_nemo_gpt_to_mcore.py \
--in-folder ./model_checkpoint \
--out-file ./mcore_gpt.nemo

.. tab-item:: LLaMa2 7B
:sync: key2

#. Download the `Llama 2 7B LLM model and tokenizer <https://huggingface.co/meta-llama/Llama-2-7b>`__ into the models folder.
#. Convert the LLaMa2 LLM into ``.nemo`` format
.. code-block:: bash

python /opt/NeMo/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \
--input_name_or_path /path/to/llama --output_path /output_path/mcore_gpt.nemo

After these steps you should have a file ``mcore_gpt.nemo`` to use in NeMo-Aligner.

.. note::
Mcore models use TransformerEngine as a backend, and it tries to find efficient kernels. But depending on the GPU you have it may not find them. If you ever face errors that relate to kernel finding set these variables on top of your script.

.. code-block:: bash

export NVTE_MASKED_SOFTMAX_FUSION=0
export NVTE_FLASH_ATTN=0
export NVTE_FUSED_ATTN=0

Additionally, TransformerEngine is non-deterministic by default, meaning subsequent runs of SPIN using identical parameters will produce different results, which is not ideal for parameter perturbation.
Helpfully, TransformerEngine exposes a flag to set if you want to guarantee deterministic training runs:

.. code-block:: bash

export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
export NVTE_MASKED_SOFTMAX_FUSION=0

SFT vs Foundational (base) model for Self-Rewarding Training
############################################################
Self-Rewarding can be run on either base/foundational models, that is, models which have only been trained on autoregressive language prediction tasks and not on instruction following tasks,
or, you can also run Self-Rewarding on models which have been SFTed on instruction-based datasets as well, similar to DPO/PPO. Either type of model will work well with Self-Rewarding. If you would like to start with a supervised fine tuned model instead of a base model, please see our full guide on how to perform SFT on a Megatron GPT model :ref:`SFT guide <sft>`.

Self-Rewarding Model Training
#############################

Self-Rewarding training uses the exact same dataset formatting and files as the NeMo-Aligner SFT trainer. Please see the data formatting section of SFT to understand the data format necessary for SPIN :ref:`SFT guide <sft>`

Once your data is processed into the correct format you are ready to begin Self-Rewarding training. You must start with a pretrained or SFT trained model. For this section we will use the SFT model trained in the previous step to train the Self-Rewarding model.
For the purposes of the following sections, we'll assume your training jsonl file is located in ``/path/to/train_sft_format.jsonl`` and your validation jsonl file is located in ``/path/to/valid_sft_format.jsonl``.

Due to some limitations of the Nemo Aligner system and reusing code files, the parameters for Self-Rewarding share the same parameter namespace as SPIN, so these parameters are labelled as ``spin``, but they apply to the self-rewarding algorithm.

For the parameters below, the ``model.spin.ref_policy_kl_penalty`` corresponds to the beta parameter in the Self-Rewarding paper, and ``trainer.self_rewarding.max_iterations`` corresponds to number of iterations.

Self-Rewarding is a very generation-heavy algorithm, with N*k generations per sample in the training data. As such, it is highly advisable to enable TRTLLM in order to vastly speedup training generation times (5-7X speedup).
You can enable TRT by setting ``trainer.self_rewarding.trt_llm.enable=true`` along with ``trainer.self_rewarding.trt_llm.model_type`` (set this to ``gptnext`` for Nemotron models, and ``llama`` for the llama family of models).
If you want to train using Meta-Self-Rewarding instead of the original Self-Rewarding, you need to set ``model.spin.use_meta_judge=true``. When using meta mode, you need to also set ``model.spin.meta_judge_pcnt`` which controls the maximum percent of any GBS which can be populated by meta-judge training samples.
If you want to use Length Control (Meta-Self-Rewarding paper, section 2.1, last paragraph), you can set that with ``model.spin.length_control``. This parameter accepts either a scalar, or a list of size == number of iterations, where
each iteration will apply its corresponding length control value. This allows you to create a schedule of different length control values for each iteration. This logic will work for both Self-Rewarding and Meta Self-Rewarding.
You can also control which variant of DPO loss is used for training using the ``model.spin.preference_loss`` parameter. Valid entries are: dpo, scale, rpo_bwd_kl, rpo_fwd_kl, ipo, and rpo_sq. Default is dpo.


.. tab-set::

.. tab-item:: Terminal
:sync: key3

To run Self-Rewarding model training on the terminal directly

.. code-block:: bash

export GPFS="/path/to/nemo-aligner-repo"
export TRAIN_DATA_PATH="/path/to/train_sft_format.jsonl"
export VALID_DATA_PATH="/path/to/valid_sft_format.jsonl"

python -u ${GPFS}/examples/nlp/gpt/train_gpt_self_rewarding.py \
trainer.num_nodes=1 \
trainer.devices=8 \
model.micro_batch_size=1 \
model.global_batch_size=64 \
pretrained_checkpoint.restore_from_path=/path/to/megatron_gpt_sft.nemo \
"model.data.train_ds.file_path=${TRAIN_DATA_PATH}" \
"model.data.validation_ds.file_path=${VALID_DATA_PATH}" \
exp_manager.create_wandb_logger=false \
exp_manager.wandb_logger_kwargs.project=spin_training \
exp_manager.wandb_logger_kwargs.name=spin_training \
exp_manager.explicit_log_dir=/results \
++model.sequence_parallel=false \
++model.apply_rope_fusion=false \
trainer.self_rewarding.max_iterations=3 \
trainer.self_rewarding.max_epochs=1 \
model.spin.ref_policy_kl_penalty=0.1 \
model.spin.use_meta_judge=false \
model.spin.length_params.max_length=2048 \
model.data.train_ds.max_seq_length=4096

.. tab-item:: Slurm
:sync: key4

To run SPIN model training using Slurm. The script below uses 4 nodes, but you can change the node count to something different.

.. code-block:: bash

#!/bin/bash
#SBATCH -A <<ACCOUNT NAME>>
#SBATCH -p <<PARTITION NAME>>
#SBATCH -N 4
#SBATCH -t 4:00:00
#SBATCH -J <<JOB NAME>>
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node 8
#SBATCH --exclusive
#SBATCH --overcommit

GPFS="/path/to/nemo-aligner-repo"
PRETRAINED_CHECKPOINT_NEMO_FILE="/path/to/megatron_gpt_sft.nemo"

TRAIN_DATA_PATH="/path/to/train_sft_format.jsonl"
VALID_DATA_PATH="/path/to/valid_sft_format.jsonl"

PROJECT="<<WANDB PROJECT>>"

CONTAINER=<<<CONTAINER>>> # use the latest NeMo Training container, Aligner will work there
MOUNTS="--container-mounts=${GPFS}:${GPFS},${TRAIN_DATA_PATH}:${TRAIN_DATA_PATH},${VALID_DATA_PATH}:${VALID_DATA_PATH},${PRETRAINED_CHECKPOINT_NEMO_FILE}:${PRETRAINED_CHECKPOINT_NEMO_FILE}"

RESULTS_DIR="/path/to/result_dir"

OUTFILE="${RESULTS_DIR}/rm-%j_%t.out"
ERRFILE="${RESULTS_DIR}/rm-%j_%t.err"
mkdir -p ${RESULTS_DIR}

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& echo "---------------" \
&& echo "Starting training" \
&& cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 \
&& export NVTE_MASKED_SOFTMAX_FUSION=0 \
&& export HYDRA_FULL_ERROR=1 \
&& python -u ${GPFS}/examples/nlp/gpt/train_gpt_self_rewarding.py \
trainer.num_nodes=${SLURM_JOB_NUM_NODES} \
trainer.devices=8 \
pretrained_checkpoint.restore_from_path='${PRETRAINED_CHECKPOINT_NEMO_FILE}' \
"model.data.train_ds.file_path=${TRAIN_DATA_PATH}" \
"model.data.validation_ds.file_path=${VALID_DATA_PATH}" \
model.micro_batch_size=1 \
model.global_batch_size=64 \
exp_manager.explicit_log_dir=${RESULTS_DIR} \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name=${NAME} \
exp_manager.wandb_logger_kwargs.project=${PROJECT} \
trainer.self_rewarding.max_iterations=3 \
trainer.self_rewarding.max_epochs=1 \
model.spin.ref_policy_kl_penalty=0.1 \
model.spin.use_meta_judge=false \
model.spin.length_params.max_length=2048 \
model.data.train_ds.max_seq_length=4096
EOF

srun -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}"
set +x

During Self-Rewarding training, there will be several metrics recorded to WandB which you can monitor, the following of which are specific to Self-Rewarding:

- chosen_lengths: average token length of chosen responses (average taken across GBS)
- reject_lengths: as above but for rejected responses
- chosen_generated_rewards: the average reward (across GBS) generated by the LLM-as-a-judge for chosen responses
- rejected_generated_rewards: as above but for rejected responses
- rewards_chosen_mean: see below for a definition of what reward means in this context
- rewards_rejected_mean: as above but for rejected responses
- bad_samples_per_GBS: the percentage of samples in a GBS which are excluded from training because of bad output from the LLM-as-a-judge (could be caused by parse errors, or all responses being judge with the same score, etc)
- bad_ends_per_GBS: only valid if using TRT, this tracks the percentage of each GBS where TRT generates incorrect stop tokens (should be really low, < 1%)
- preference_loss: the raw DPO variant loss
- sft_loss: if adding an SFT loss (categorical cross-entropy loss) for the chosen response, then you can see that raw loss here

The ``reward`` in this case is calculated as the difference between model log probs and the reference log probs, multiplied by the KL penalty (beta in the original paper), for the ground truth and generated responses.
During training, the acc should generally be increasing, but don't worry if its absolute value remains low, as it doesn't correlate to finalised MTBench or MMLU scores. It should just be generally increasing.
All metrics will be grouped by either ``train/`` or ``val/`` in WandB, representing whether that metric is from the training or validation set, respectively.
You can also see a table which will print out the prompt, chosen response, and rejected response for each validation step. This allows you to keep track of response quality and hallucinations.

When it comes to ideal hyperparameters for Self-Rewarding training, much will depend on the characteristics of your SFT (or base/foundational) model and your training data, so there is no one-size-fits-all parameter set which will work in all cases.
Additionally, Self-Rewarding (with or without meta) is a complex algorithm with a lot of moving pieces and a lot of parameters, so finding what works well for your model and data can be difficult.
Below are some of observations from the Nvidia Alignment team as to what parameters we have seen work well:

* global_batch_size: we recommend using 64, and going up to 128 only for large models (70B+) that are also training with large datasets
* iterations/epochs: the original paper uses 3 iterations with 1 epoch per iteration, and we find this to be sufficient for most use cases
* learning rate: for SFT/aligned models, we recommend a smaller LR, between 3e-7 and 1e-7. If training a foundational model, then something between 3e-6 to 9e-7.
* ref_policy_kl_penalty: we did not see large changes from perturbations to this value; we recommend 0.1 - 0.001
* length_control: depends very much on model size and data, but we found good results with [0,0,0.1]
* use_meta_judge: we have found stronger results when settings this to true, which is in line with the paper's results
* meta_judge_pcnt: we recommend you do not set this higher than 0.15 (15%). Any higher, and we have observed that the llm-as-a-judge model starts to output identical scores for every response (always a 5)
Loading
Loading