Instruction-tuning Support #196

Open · wants to merge 81 commits into base: main (changes shown from 56 commits)

Commits (81)
59f0191
Add example config for the construction of chat templates
rrutmann Jul 15, 2024
8b60a83
chore: add chat template config based on jinja2
lllAlexanderlll Jul 15, 2024
ba2f65c
chore: update chat template config based on jinja2
lllAlexanderlll Jul 15, 2024
47e71c3
chore: Add apply chat template feature with role mapping
lllAlexanderlll Jul 15, 2024
3303147
chore: extend to multiple chat templates
lllAlexanderlll Jul 16, 2024
0c6bbf5
fix: data driven chat template key retrieval
lllAlexanderlll Jul 16, 2024
32f5756
chore: Add 'index' to output JSONL
lllAlexanderlll Jul 16, 2024
482f7af
fix: Add special token to be kept during training to allow for earl…
lllAlexanderlll Jul 16, 2024
1d72770
chore: Update output file
lllAlexanderlll Jul 16, 2024
0bd9bfa
build: Add jsonlines dependency
rrutmann Jul 16, 2024
ed2f4ce
chore: integration of collator wrapper with loss masking functionalit…
lllAlexanderlll Jul 16, 2024
6e24ea2
chore: Use SFT config replication with uuid as file pair identification.
lllAlexanderlll Jul 18, 2024
6e716b4
chore: Add loss masking test
lllAlexanderlll Jul 18, 2024
a0376a6
chore: Merge branch 'main' into sft_loss_masking
lllAlexanderlll Jul 18, 2024
70dc498
fix: copy raw config file for truly original content
lllAlexanderlll Jul 19, 2024
242e429
chore: add pbin file for testing loss masking
Jul 22, 2024
bddcf8b
chore: add pbin file with more data for testing loss masking
Jul 22, 2024
f86b6ed
chore: use a hash not uuid for showing which config belongs to which …
lllAlexanderlll Jul 22, 2024
7632a02
chore: add pbin file for testing loss masking
Jul 22, 2024
15719a3
chore: Fix loss masking when starting within an assistant answer
lllAlexanderlll Jul 22, 2024
ab0f34c
chore: add loss collator wrapper again
lllAlexanderlll Jul 22, 2024
0a545ca
chore: fix the loss masking test and the implementation. Improve docu…
lllAlexanderlll Jul 22, 2024
c6d0a61
chore: Merge commit '15ed069beaa2c83dcd15b087e4d0864b1aec4caa' into s…
lllAlexanderlll Jul 22, 2024
0c52856
chore: Merge branch 'sft_with_main' into sft_loss_masking
lllAlexanderlll Jul 22, 2024
12c74bc
feat(sft): Do not reuse last targets for Instruction Tuning
rrutmann Jul 23, 2024
fc7bec1
Merge remote-tracking branch 'origin/sft_with_main' into sft_sample_g…
rrutmann Jul 23, 2024
25fdcd7
refactor(sft): Make reuse_last_target optional
rrutmann Jul 23, 2024
01109e2
docs: Correct spelling
rrutmann Jul 23, 2024
7148e1e
Update comment
rrutmann Jul 23, 2024
75611dd
Merge pull request #193 from Modalities/sft_sample_generator
rrutmann Jul 23, 2024
1f36d64
chore: apply review changes: only single chat template, do raise erro…
lllAlexanderlll Jul 23, 2024
9921ed3
chore: Merge branch 'sft_with_main' into sft_loss_masking
lllAlexanderlll Jul 23, 2024
0f49355
chore: run loss masking with padded and truncated pbin file. Refine e…
lllAlexanderlll Jul 23, 2024
bf2f1a3
chore: restore original lorem ipsum config, as we have our own sft co…
lllAlexanderlll Jul 23, 2024
99e7ac3
Merge pull request #192 from Modalities/sft_loss_masking
lllAlexanderlll Jul 24, 2024
76b34ab
feat: added fixed number of elements to ResumableBatchSampler
le1nux Jul 11, 2024
2f4bdb2
feat: added fixed number of batches to dataloader
le1nux Jul 11, 2024
9840e0a
fix: fixed error in fixed_num_batches calculation
le1nux Jul 17, 2024
3906845
feat: implemented test for fixed_num_batches in dataloader
le1nux Jul 17, 2024
cc57125
refactor: removed fixed_num_batches from config_lorem_ipsum.yaml
le1nux Jul 17, 2024
603b367
refactor: moved SequenceDataset to test
le1nux Jul 19, 2024
78f4d03
refactor: added another check that fixed_num_batches > skip_num_batches
le1nux Jul 19, 2024
8372b21
fix: check for fixed_num_batches is not None before comparison
lllAlexanderlll Jul 29, 2024
c8ba69c
chore: add missing reuse_last_target: False in example SFT config
lllAlexanderlll Jul 29, 2024
60e9894
chore: Merge branch 'main' into sft_with_main
le1nux Jul 30, 2024
b0c69a9
fix: removed sequence_length from MemMapDataset
le1nux Jul 31, 2024
d041900
chore: added reraise of exception in PackedMemMapDatasetBase
le1nux Jul 31, 2024
3912e36
chore: fixed typo
le1nux Jul 31, 2024
cc4eef7
refactor: made non-public methods related to apply_chat_template private
le1nux Jul 31, 2024
00a35ee
chore: minor style improvement
le1nux Jul 31, 2024
009bb8a
chore: fix suggestions of PR first round
lllAlexanderlll Aug 5, 2024
a47dc8c
chore: Move collate functions to dataloader package
lllAlexanderlll Aug 5, 2024
705101c
chore: renamed MaskingTokenConfig to LossMaskingTokenConfig
lllAlexanderlll Aug 5, 2024
c399b94
chore: Add explanation to vectorized loss masking
lllAlexanderlll Aug 5, 2024
483ea83
chore: moved SFTConfig and tests
lllAlexanderlll Aug 5, 2024
4ea1d76
chore: update SFT README
lllAlexanderlll Aug 12, 2024
3945426
chore: test for reuse last target, update sft readme, create folder f…
lllAlexanderlll Aug 12, 2024
4d7e53d
chore: Merge branch 'main' into sft_with_main
lllAlexanderlll Aug 12, 2024
eee2bac
chore: fix tokenization tests and renaming of loss masking config field
lllAlexanderlll Aug 12, 2024
4f53f0c
chore: Update SFT_README.md
lllAlexanderlll Aug 13, 2024
ed50d2f
docs: add doc strings
lllAlexanderlll Aug 13, 2024
d5867a4
chore: update instruction tuning e2e test with output artifact check
lllAlexanderlll Aug 13, 2024
94d89cb
chore: Update readme
lllAlexanderlll Aug 13, 2024
42cf6ce
chore: refine names of helper functions and doc strings
lllAlexanderlll Aug 19, 2024
d98a26a
fix: apply renaming
lllAlexanderlll Aug 19, 2024
d95bd46
chore: Update SFT_README
lllAlexanderlll Aug 19, 2024
c6b0e4c
chore(sft): Improve check on correctness of loss masked sequences
rrutmann Aug 20, 2024
b9fbcec
chore(sft): Change special tokens used for instruction tuning
rrutmann Aug 20, 2024
6538072
chore: Add artifacts to .gitignore
rrutmann Aug 20, 2024
2748887
chore(sft): Add splitting functionality and introduce a new entry poi…
lllAlexanderlll Aug 20, 2024
72ed828
fix(sft): do not append hash twice
lllAlexanderlll Aug 20, 2024
6594633
test(sft): Use special tokens already existing in tokenizers vocabulary
rrutmann Sep 9, 2024
66f0bea
test(sft): Add data and config for tests
rrutmann Sep 9, 2024
0daec5b
test(sft): Add documentation for test
rrutmann Sep 9, 2024
125311f
chore: Pass black check
rrutmann Sep 9, 2024
9cbc8f3
chore: Merge remote-tracking branch 'origin/main' into sft_with_main
rrutmann Sep 9, 2024
b121eea
chore: improve error message and readme
lllAlexanderlll Sep 16, 2024
396aba5
chore: Update SFT_README.md
lllAlexanderlll Sep 16, 2024
8416c9d
Update SFT_README.md
lllAlexanderlll Sep 16, 2024
b08f02f
chore: Merge branch 'main' into sft_with_main
rrutmann Sep 23, 2024
eb658c9
test: fix failing sft e2e test
davidkaczer Nov 4, 2024
114 changes: 114 additions & 0 deletions SFT_README.md
@@ -0,0 +1,114 @@
# Supervised Fine-tuning with Modalities

Currently supported are Instruction-tuning and Low-Rank Adaptation (LoRA), as explained in more detail below.

## Instruction-tuning
This section covers:
* the entry point to prepare the data
* Jinja2 chat templates
* the `b_include_to_loss_token` and `e_include_to_loss_token`, which are required to be part of each chat template for proper loss masking
* a config hash to connect files that belong together
* truncation and padding
* re-use of the last target

### Create Prompts from Conversations
To prepare the instruction-tuning data we created a new entry point `apply_chat_template`, which requires a [configuration file](./config_files/data_preparation/apply_chat_template_config.yaml). Within it we define:
* the path to the instruction-tuning dataset as a JSONL file, where each line contains a structured conversation as an array of dictionaries.
* a [jinja2](https://jinja.palletsprojects.com/en/3.1.x/) chat template, which defines how to combine `chat_template_data` and the data within the JSONL into a single `chat` string.

As part of the `chat_template_data`, we require the special tokens `b_include_to_loss_token` and `e_include_to_loss_token`.
> ❗ Choose sequences which are tokenized into a single token and which do not appear in the assistant utterances of the instruction-tuning data!

They mark the beginning and end of the assistant turns, as only the tokens between them are included in the loss computation during instruction-tuning with modalities.
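A quick way to verify this is to tokenize each candidate marker in isolation and check that it maps to exactly one token ID. The sketch below assumes the Hugging Face `transformers` library; `gpt2` is used as a stand-in for the tokenizer referenced in the example configs (`data/tokenizer/hf_gpt2`):

```python
# Sanity check (sketch): each candidate marker should map to exactly one token ID
# and should be unlikely to occur inside assistant answers.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for data/tokenizer/hf_gpt2

for candidate in ["^", "$", "°"]:
    ids = tokenizer(candidate, add_special_tokens=False)["input_ids"]
    status = "ok" if len(ids) == 1 else "splits into multiple tokens -- pick another marker"
    print(f"{candidate!r} -> {ids} ({status})")
```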

```yaml
chat_template_data:
...
special_tokens:
b_include_to_loss_token: ^
e_include_to_loss_token: $
```
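To make the mechanics concrete, the following sketch renders a single conversation with a simplified template. It is not the modalities implementation; the conversation, roles, and the trimmed-down template are made up for illustration, but the `b_include_to_loss_token`/`e_include_to_loss_token` markers play the same role as in the config above:

```python
# Sketch: render one JSONL conversation into a `chat` string with a simplified
# Jinja2 template. Everything between ^ and $ is later included in the loss.
from jinja2 import Template

template = Template(
    "{{ chat_template_data.system_instruction }}\n"
    "{% for turn in conversation %}"
    "{{ turn['from'] }}:"
    "{% if turn['from'] == chat_template_data.assistant_role %}"
    "{{ chat_template_data.special_tokens.b_include_to_loss_token }}"
    "{% else %} {% endif %}"
    "{{ turn['value'] }}\n"
    "{% if turn['from'] == chat_template_data.assistant_role %}"
    "{{ chat_template_data.special_tokens.e_include_to_loss_token }}\n"
    "{% endif %}"
    "{% endfor %}"
)

conversation = [  # one line of the (already role-mapped) instruction-tuning JSONL
    {"from": "User1", "value": "What does instruction-tuning do?"},
    {"from": "Assistant", "value": "It adapts a pretrained model to follow instructions."},
]
chat_template_data = {
    "assistant_role": "Assistant",
    "system_instruction": "You are Mody, a helpful assistant.",
    "special_tokens": {"b_include_to_loss_token": "^", "e_include_to_loss_token": "$"},
}

print(template.render(conversation=conversation, chat_template_data=chat_template_data))
# You are Mody, a helpful assistant.
# User1: What does instruction-tuning do?
# Assistant:^It adapts a pretrained model to follow instructions.
# $
```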

Run the `apply_chat_template` entry point with:
```bash
modalities data apply_chat_template --config_file_path config_files/data_preparation/apply_chat_template_config.yaml
```

This will create two files:
1. The new JSONL file with a new attribute `chat` containing the conversations, e.g. `lorem_ipsum_sft_converted.aadd295.jsonl`
2. The config used to generate the `chat`, e.g. `sft_chat_template_config.aadd295.yaml`

> Both file names contain the first 7 characters of the hash of the config file, to group files which belong together!
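The grouping scheme itself is simple: a short hash of the config file contents is appended to every derived file name. A minimal sketch of the idea (the hash function shown, SHA-256 truncated to 7 hex characters, is an assumption and not necessarily what modalities uses internally):

```python
# Sketch of the file-grouping scheme: derive a short hash suffix from the config
# file and reuse it in all output file names so related artifacts can be matched.
import hashlib
from pathlib import Path

def config_hash_suffix(config_path: str, length: int = 7) -> str:
    digest = hashlib.sha256(Path(config_path).read_bytes()).hexdigest()
    return digest[:length]

suffix = config_hash_suffix("config_files/data_preparation/apply_chat_template_config.yaml")
print(f"data/lorem_ipsum_sft_converted.{suffix}.jsonl")
print(f"sft_chat_template_config.{suffix}.yaml")
```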

### Create idx and pbin files
Before continuing with the instruction-tuning, you need to index the created JSONL file and convert it to a packed data (pbin) file.

> Make sure to use the same hash for correct grouping when defining the output file names!

For example:
```bash
# create idx file
modalities data create_raw_index --index_path data/lorem_ipsum_sft_converted.aadd295.idx data/lorem_ipsum_sft_converted.aadd295.jsonl

# create pbin file
modalities data pack_encoded_data --config_file_path config_files/data_preparation/packed_chat_dataset_config.yaml
```

> The [packed_chat_dataset_config.yaml](config_files/data_preparation/packed_chat_dataset_config.yaml) must use truncation and padding!

### Instruction-Tuning

With your prepared instruction-tuning data as a pbin file, you can now run the instruction-tuning.

Make sure to use the wrapped collate function.

* Look up the `b_include_to_loss_token` and `e_include_to_loss_token` as defined within your `sft_chat_template_config.aadd295.yaml`. If you configured the pbin creation correctly, you only need to check for matching hash suffixes.
* Set the `loss_ignore_index`, which gets ignored by your loss function. In PyTorch this is usually -100 (a sketch of the masking logic follows the config examples below).
* A tokenizer is needed to tokenize the `b_include_to_loss_token` and `e_include_to_loss_token`.
* The last target must not be re-used; set `reuse_last_target: false` in the dataset config.

For example (Copied from [config_files/training/config_lorem_ipsum_sft.yaml](config_files/training/config_lorem_ipsum_sft.yaml)):
```yaml
collate_fn:
  component_key: collate_fn
  variant_key: mask_loss_collator_wrapper
  config:
    wrapped_collate_fn:
      component_key: collate_fn
      variant_key: gpt_2_llm_collator
      config:
        sample_key: ${settings.referencing_keys.sample_key}
        target_key: ${settings.referencing_keys.target_key}
    target_keys_to_mask:
      - ${settings.referencing_keys.target_key}
    loss_ignore_index: -100
    mask_tokens:
      b_include_to_loss_token: ^
      e_include_to_loss_token: $
    tokenizer:
      instance_key: tokenizer
      pass_type: BY_REFERENCE
```
and
```yaml
train_dataset:
  component_key: dataset
  variant_key: packed_mem_map_dataset_continuous
  config:
    raw_data_path: ./data/lorem_ipsum_sft_converted.aadd295.pbin
    sequence_length: ${settings.training.sequence_length}
    sample_key: ${settings.referencing_keys.sample_key}
    reuse_last_target: false
```
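Conceptually, the collator wrapper keeps only the target tokens between the tokenized `b_include_to_loss_token` and `e_include_to_loss_token` and replaces everything else with `loss_ignore_index`. The following is a minimal, per-sequence sketch of that idea (the actual modalities implementation is vectorized and may treat the marker tokens themselves differently):

```python
# Minimal, per-sequence sketch of the masking idea (not the modalities
# implementation): every target token outside a ^ ... $ span is replaced with
# loss_ignore_index so the loss function skips it.
import torch

def mask_targets(
    target_ids: torch.Tensor,   # shape: (seq_len,)
    b_token_id: int,            # token id of b_include_to_loss_token (e.g. "^")
    e_token_id: int,            # token id of e_include_to_loss_token (e.g. "$")
    loss_ignore_index: int = -100,
) -> torch.Tensor:
    masked = torch.full_like(target_ids, loss_ignore_index)
    inside = False
    for i, tok in enumerate(target_ids.tolist()):
        if tok == b_token_id:
            inside = True         # start of an assistant answer
        elif tok == e_token_id:
            inside = False        # end of an assistant answer
        elif inside:
            masked[i] = tok       # keep assistant tokens for the loss
    return masked

# Toy example: ids 10 and 11 stand for "^" and "$"; only 3 and 4 stay unmasked.
print(mask_targets(torch.tensor([1, 2, 10, 3, 4, 11, 5]), 10, 11))
# tensor([-100, -100, -100,    3,    4, -100, -100])
```

With targets prepared this way, a standard cross-entropy loss with `ignore_index=-100` learns only from the assistant answers.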

# TODO
Reuse last token

Finally, run the instruction-tuning with the `run` entry point:
```bash
python -m torch.distributed.run --nnodes 1 --nproc_per_node 2 --rdzv-endpoint=0.0.0.0:29555 src/modalities/__main__.py run --config_file_path config_files/training/config_lorem_ipsum_sft.yaml
```

## Low-Rank Adaptation (LoRA)

TBD
37 changes: 37 additions & 0 deletions config_files/data_preparation/apply_chat_template_config.yaml
@@ -0,0 +1,37 @@
settings:
  src_path: data/lorem_ipsum_sft.jsonl
  dst_path: data/lorem_ipsum_sft_converted.jsonl
  conversations_key: conversations

instruction_data_transformation:
  role_mapping:
    human_1: User1
    human_2: User2
    gpt: Assistant

# The b_include_to_loss_token and e_include_to_loss_token are required to be part of each chat template for proper loss masking!
jinja2_chat_template: |
  {{ chat_template_data.system_instruction + '\n' }}
  {% for turn in conversation %}
  {{ turn.from + ':' }}
  {% if turn.from == chat_template_data.assistant_role %}
  {{ chat_template_data.special_tokens.b_include_to_loss_token}}
  {% else %}
  {{ " " }}
  {% endif %}
  {{ turn.value + '\n'}}
  {% if turn.from == chat_template_data.assistant_role %}
  {{ chat_template_data.special_tokens.e_assistant_token}}
  {{ chat_template_data.special_tokens.e_include_to_loss_token}}
  {% endif %}
  {% endfor %}

# The key-value pairs of chat_template_data are passed to the Jinja2 template and
# are not type checked for full compliance with the chat template!
chat_template_data:
  assistant_role: Assistant
  system_instruction: "You are Mody, a helpful assistant trained by the modalities team. Answer friendly and informatively to the user's messages."
  special_tokens:
    b_include_to_loss_token: ^
    e_include_to_loss_token: $
    e_assistant_token: °
22 changes: 22 additions & 0 deletions config_files/data_preparation/packed_chat_dataset_config.yaml
@@ -0,0 +1,22 @@
settings:
  src_path: data/lorem_ipsum_sft_converted.aadd295.jsonl
  dst_path: data/lorem_ipsum_sft_converted.aadd295.pbin
  index_path: data/lorem_ipsum_sft_converted.aadd295.idx
  jq_pattern: .chat
  num_cpus: 1
  eod_token: <|endoftext|>
  processing_batch_size: 5
  raw_samples_queue_size: 300
  processed_samples_queue_size: 300
  sequence_length: 2048

tokenizer:
  component_key: tokenizer
  variant_key: pretrained_hf_tokenizer
  config:
    pretrained_model_name_or_path: data/tokenizer/hf_gpt2
    padding: max_length
    truncation: true
    max_length: ${settings.sequence_length}
    special_tokens:
      pad_token: ${settings.eod_token}