feat: Long Text Fine-Tuning Support #5532

Status: Open. Wants to merge 6 commits into base: main.
1 change: 1 addition & 0 deletions README.md
@@ -172,6 +172,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [LongWriter-GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
1 change: 1 addition & 0 deletions README_zh.md
@@ -173,6 +173,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [LongWriter-GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
8 changes: 8 additions & 0 deletions data/dataset_info.json
@@ -622,5 +622,13 @@
"prompt": "content"
},
"folder": "python"
},
"nlp_paper_inst": {
"file_name": "nlp_paper_inst.json",
"columns": {
"prompt": "prompt",
"response": "response",
"system": "system"
}
}
}
4,389 changes: 4,389 additions & 0 deletions data/nlp_paper_inst.json

Large diffs are not rendered by default.
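
For reference, the column mapping registered above implies that each record in `data/nlp_paper_inst.json` carries `system`, `prompt`, and `response` fields. A minimal sketch of what such a record might look like (hypothetical content, not taken from the actual file):

```python
import json

# Hypothetical record shape implied by the `columns` mapping in dataset_info.json;
# the real entries of data/nlp_paper_inst.json are not rendered in this diff.
sample_record = {
    "system": "You are an assistant for long-form academic writing.",
    "prompt": "Summarize the key contributions of the following NLP paper ...",
    "response": "The paper proposes ... (a long, multi-thousand-token answer)",
}

# The keys must match the registered column mapping.
assert set(sample_record) == {"prompt", "response", "system"}
print(json.dumps(sample_record, ensure_ascii=False)[:80])
```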

38 changes: 38 additions & 0 deletions examples/train_lora/start_glm_train.sh
@@ -0,0 +1,38 @@


CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 WANDB_API_KEY=<your_wandb_api_key> llamafactory-cli train \
--stage sft \
--do_train True \
--model_name_or_path /mnt/ceph/develop/jiawei/model_checkpoint/LongWriter-glm4-9b-base \
--preprocessing_num_workers 1 \
--finetuning_type lora \
--template glm4 \
--flash_attn auto \
--dataset_dir data \
--dataset nlp_paper_inst \
--cutoff_len 21000 \
--learning_rate 2e-5 \
--num_train_epochs 1.0 \
--max_samples 100000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 1 \
--save_steps 1 \
--warmup_steps 0 \
--warmup_ratio 0.03 \
--weight_decay 0.1 \
--optim adamw_torch \
--packing True \
--report_to wandb \
--run_name LongWriter-test4 \
--output_dir saves/LongWriter-glm4-9b-base/lora/train_2024-11-01-test4 \
--bf16 True \
--plot_loss True \
--ddp_timeout 180000000 \
--include_num_input_tokens_seen True \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout 0.05 \
--lora_target all
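
As a quick sanity check on the command above, the implied global batch size (assuming all eight visible GPUs take part in data-parallel training) works out as follows:

```python
# Values taken from the launch command above; the GPU count is inferred from
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 and is an assumption about the actual run.
num_gpus = 8
per_device_train_batch_size = 1
gradient_accumulation_steps = 2

effective_batch_size = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 16 packed sequences (each up to cutoff_len = 21000 tokens) per optimizer step
```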
23 changes: 18 additions & 5 deletions src/llamafactory/data/processors/supervised.py
@@ -44,22 +44,34 @@ def _encode_supervised_example(
     cutoff_len: int,
     train_on_prompt: bool,
     mask_history: bool,
+    pack_data_preprocess: bool = False,
 ) -> Tuple[List[int], List[int]]:
     messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
     input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
-    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
+    eos_indice = len(input_ids) + (1 if template.efficient_eos else 0)
     if mask_history:
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
-            break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+        if pack_data_preprocess and len(source_ids) + len(target_ids) >= cutoff_len:
+            raise ValueError(
+                "Packing the dataset needs a larger cutoff_len: "
+                f"len(source_ids) + len(target_ids) = {len(source_ids) + len(target_ids)} >= {cutoff_len}"
+            )
+        else:
+            if eos_indice >= cutoff_len:
+                logger.warning(
+                    f"cutoff_len {cutoff_len} is too small for the input at turn_idx {turn_idx}, dropping it "
+                    "(e.g. eos_indice falls just short of the packed-block length, so the last turn is discarded)."
+                )
+                break

Contributor:

L59 raises an exception; curious why L66 just breaks?

@glide-the (Author), Sep 27, 2024:

When `pack_data_preprocess` is true, `cutoff_len` is not used for truncating the input.

The check `pack_data_preprocess and len(source_ids) + len(target_ids) >= cutoff_len` verifies the maximum packing length of long texts. For example, when the message length is >= 21, it should raise an error instead of discarding the data when it does not form a complete training pack.

Ideal situation: [image]

Cases where an error should be reported: [image]

`preprocess_packed_supervised_dataset` receives batched data from `dataset.map`. When the number of preprocessing workers is 1, only one process handles the data. The diagram is an abstraction; normally the data would be divided into `batch_size` pieces for all processes to handle.


    dataset = dataset.map(
        preprocess_func,
        batched=True,
        batch_size=data_args.preprocessing_batch_size,
        remove_columns=column_names,
        **kwargs,
    )

@hiyouga

+        if pack_data_preprocess:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), len(source_ids) + len(target_ids))
+        else:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - eos_indice)

         source_ids = source_ids[:source_len]
         target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        eos_indice += source_len + target_len

         if train_on_prompt:
             source_label = source_ids
@@ -157,6 +169,7 @@ def preprocess_packed_supervised_dataset(
             cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
+            pack_data_preprocess=data_args.pack_data_preprocess,
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
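
To make the intent of the new `pack_data_preprocess` check concrete, here is a minimal, self-contained sketch (not the project's actual packing code) of greedy packing into `cutoff_len`-sized blocks: a sample longer than the block size can never form a complete pack, so it is better to fail fast than to truncate or drop it silently.

```python
from typing import List


def pack_examples(lengths: List[int], cutoff_len: int) -> List[List[int]]:
    """Greedily pack example lengths into blocks of at most `cutoff_len` tokens.

    Illustrative sketch only; the real logic lives in `preprocess_packed_supervised_dataset`.
    """
    blocks: List[List[int]] = []
    current: List[int] = []
    used = 0
    for length in lengths:
        if length >= cutoff_len:
            # Mirrors the new behavior: raise instead of silently discarding a
            # sample that can never fit into a single packed block.
            raise ValueError(f"Example of length {length} needs a larger cutoff_len ({cutoff_len}).")
        if used + length > cutoff_len:
            blocks.append(current)
            current, used = [], 0
        current.append(length)
        used += length
    if current:
        blocks.append(current)
    return blocks


print(pack_examples([8000, 9000, 5000, 7000], cutoff_len=21000))  # [[8000, 9000], [5000, 7000]]
```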
10 changes: 10 additions & 0 deletions src/llamafactory/extras/constants.py
@@ -583,6 +583,16 @@ def register_model_group(
)


register_model_group(
    models={
        "LongWriter-glm4-9b": {
            DownloadSource.DEFAULT: "THUDM/LongWriter-glm4-9b",
        }
    },
    template="glm4_long",
)


register_model_group(
    models={
        "InternLM-7B": {
6 changes: 6 additions & 0 deletions src/llamafactory/hparams/data_args.py
@@ -45,6 +45,12 @@ class DataArguments:
        default=1024,
        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
    )
    pack_data_preprocess: bool = field(
        default=False,
        metadata={
            "help": (
                "Default `False`. When `pack_data_preprocess` is true, `cutoff_len` is not used to truncate the input; "
                "instead, each training example is checked against this value and an error is raised if it exceeds it."
            )
        },
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable the mask on the prompt."},
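
A minimal usage sketch for the new field (assuming the import path used in the project's tests; with the default `False`, existing behavior is unchanged):

```python
from llamafactory.hparams import DataArguments

# With packing enabled, pack_data_preprocess=True turns oversized samples into a
# hard error instead of letting cutoff_len silently truncate or drop them.
args = DataArguments(cutoff_len=21000, pack_data_preprocess=True)
print(args.cutoff_len, args.pack_data_preprocess)  # 21000 True
```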
3 changes: 3 additions & 0 deletions src/llamafactory/webui/components/train.py
@@ -97,6 +97,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column():
    packing = gr.Checkbox()
    neat_packing = gr.Checkbox()
    pack_data_preprocess = gr.Checkbox()

with gr.Column():
    train_on_prompt = gr.Checkbox()
@@ -119,6 +120,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
optim,
packing,
neat_packing,
pack_data_preprocess,
train_on_prompt,
mask_history,
resize_vocab,
@@ -137,6 +139,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
optim=optim,
packing=packing,
neat_packing=neat_packing,
pack_data_preprocess=pack_data_preprocess,
train_on_prompt=train_on_prompt,
mask_history=mask_history,
resize_vocab=resize_vocab,
24 changes: 24 additions & 0 deletions src/llamafactory/webui/locales.py
@@ -651,6 +651,30 @@
"info": "패킹된 시퀀스 간의 크로스 어텐션을 피합니다.",
},
},
"pack_data_preprocess": {
"en": {
"label": "Preprocess pack data",
"info": """cutoff_len is not used for truncating the input;
instead, it checks whether the training input parameter exceeds this value.
If it does, an error is raised.""",
},
"ru": {
"label": "Предобработка упакованных данных",
"info": "избегайте перекрестного внимания между упакованными последовательностями."
},
"zh": {
"label": "提前打包数据",
"info": """cutoff_len 不用于截断输入;
相反,它会检查训练输入参数是否超出此值。
如果超出,则会引发错误。""",
},
"ko": {
"label": "데이터 전처리",
"info": """cutoff_len은 입력을 자르는 데 사용되지 않으며;
대신, 훈련 입력 매개변수가 이 값을 초과하는지 확인합니다.
초과할 경우 오류가 발생합니다."""
},
},
"train_on_prompt": {
"en": {
"label": "Train on prompt",
24 changes: 24 additions & 0 deletions tests/data/test_template.py
@@ -168,3 +168,27 @@ def test_yi_template():
    )
    answer_str = "很高兴认识你!<|im_end|>"
    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)



@pytest.mark.xfail(reason="The fast tokenizer of the glm4 LongWriter model is corrupted.")
def test_glm_long_template():
    prompt_str = (
        "<|user|>\nHow are you"
        "<|assistant|>\nI am fine!"
        "<|user|>\n你好"
        "<|assistant|>\n"
    )
    answer_str = "很高兴认识你!"
    model_id = "THUDM/LongWriter-glm4-9b"

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, token=HF_TOKEN, trust_remote_code=True)
    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="glm4_long"))
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
    assert content_str == prompt_str + answer_str
    input_ids = content_ids[0]
    assert input_ids == prompt_ids + answer_ids
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))