From 7475455bbaec5870f60fdb25b99e04a0c05dcff0 Mon Sep 17 00:00:00 2001
From: Karthik Ganesan
Date: Thu, 14 Sep 2023 09:07:36 +0530
Subject: [PATCH] Update train.py

Make the llama-2 template the default to better support fine-tuning
llama-2 models.

---
 fastchat/train/train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/fastchat/train/train.py b/fastchat/train/train.py
index 89dff81dd..ea833239e 100644
--- a/fastchat/train/train.py
+++ b/fastchat/train/train.py
@@ -84,7 +84,7 @@ def preprocess(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
 ) -> Dict:
-    conv = get_conversation_template("vicuna")
+    conv = get_conversation_template("llama-2")
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
     # Apply prompt templates
@@ -111,10 +111,10 @@ def preprocess(
     ).input_ids
     targets = input_ids.clone()
 
-    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
+    assert conv.sep_style == SeparatorStyle.LLAMA2
 
     # Mask targets. Only compute loss on the assistant outputs.
-    sep = conv.sep + conv.roles[1] + ": "
+    sep = conv.sep
     for conversation, target in zip(conversations, targets):
         total_len = int(target.ne(tokenizer.pad_token_id).sum())
 
@@ -131,19 +131,20 @@ def preprocess(
                 break
             parts[0] += sep
             # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
-            instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+            instruction_len = len(tokenizer(parts[0]).input_ids) - 1
 
             # Ignore the user instructions
             target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
             cur_len += turn_len
 
-        target[cur_len:] = IGNORE_TOKEN_ID
+        target[cur_len+1:] = IGNORE_TOKEN_ID
 
         if False:  # Inspect and check the correctness of masking
             z = target.clone()
             z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
             rank0_print(tokenizer.decode(z))
 
+        cur_len += 2  # hackish fix to match the total length TODO(Karthik)
         if cur_len < tokenizer.model_max_length:
             if cur_len != total_len:
                 target[:] = IGNORE_TOKEN_ID
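
Note: below is a minimal sketch (not part of the patch) for inspecting the
llama-2 template that the new masking logic assumes. It only uses the
conversation helpers train.py already imports (get_conversation_template,
SeparatorStyle); the exact separator strings and role tags can differ across
FastChat versions, so treat the printed values as illustrative.

    from fastchat.conversation import SeparatorStyle
    from fastchat.model.model_adapter import get_conversation_template

    conv = get_conversation_template("llama-2")
    # The patched preprocess() only handles the LLAMA2 separator style.
    assert conv.sep_style == SeparatorStyle.LLAMA2

    # Build a tiny two-turn conversation and render the prompt that
    # preprocess() later splits on conv.sep2 (turn boundary) and
    # conv.sep (instruction/response boundary) when masking targets.
    conv.append_message(conv.roles[0], "Hello!")      # user turn
    conv.append_message(conv.roles[1], "Hi there.")   # assistant turn
    print(repr(conv.get_prompt()))                    # rendered [INST] ... [/INST] prompt
    print(repr(conv.sep), repr(conv.sep2))            # separators used by the masking loop

Running this before and after changing the template name makes it easy to see
why the vicuna-specific "sep + role + ': '" split and the "-2" token offset no
longer line up for llama-2 prompts.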