Skip to content

Commit

Permalink
feat: enable trl's autounwrap (#1060)
Browse files Browse the repository at this point in the history
* feat: test trl's autounwrap

* fix: add check for adapter

* feat: add config to disable autounwrap

* chore: fix lint
  • Loading branch information
NanoCode012 authored Jan 11, 2024
1 parent 54fe07a commit b432889
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
// The flags below simplify debugging by overriding the axolotl config
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
Expand Down
2 changes: 1 addition & 1 deletion devtools/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
8 changes: 4 additions & 4 deletions docs/debugging.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ While debugging it's helpful to simplify your test scenario as much as possible.
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
- `micro_batch_size: 1`
- `max_steps: 1`
- `max_steps: 1`
- `val_set_size: 0`
5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.
- Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config. If you didn't set this value, the default is `last_run_prepared`.
- HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).
- **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.**


## Debugging with VSCode

Expand Down Expand Up @@ -74,7 +74,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
// The flags below simplify debugging by overriding the axolotl config
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
Expand All @@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler

- The argument `justMyCode` is set to `true` such that you step through only the axolotl code. If you want to step into dependencies, set this to `false`.
- The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:
- `./devtools/temp_debug/axolotl_outputs`
- `./devtools/temp_debug/axolotl_outputs`
- `./devtools/temp_debug/.hf-cache/datasets`

>[!Tip]
Expand Down
9 changes: 9 additions & 0 deletions docs/rlhf.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,12 @@ datasets:
```yaml
rl: ipo
```
#### Trl autounwrap for peft
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```
13 changes: 9 additions & 4 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,15 @@ def train(
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model_ref = None
if cfg.rl:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)

safe_serialization = cfg.save_safetensors is True

Expand Down

0 comments on commit b432889

Please sign in to comment.