
Commit dfd9ab3

1. flatting_packing doesn't need to reserve a token for padding
2. Fix the mistral assistant message format
AlongWY committed Sep 17, 2024
1 parent c7a8590 commit dfd9ab3
Showing 2 changed files with 13 additions and 7 deletions.
src/llamafactory/data/processors/supervised.py (19 changes: 12 additions & 7 deletions)
@@ -127,10 +127,13 @@ def preprocess_packed_supervised_dataset(
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
+    invalid_num = 0
     batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
-    count_drop = 0
+
+    # reserved for the padding token / flatting_packing don't need
+    num_reserved = 0 if data_args.flatting_packing else 1
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -146,13 +149,13 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - num_reserved,
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
         length = len(input_ids)
-        if length > data_args.cutoff_len - 1:  # reserved for the padding token
-            count_drop += 1
+        if length > data_args.cutoff_len - num_reserved:
+            invalid_num += 1
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -162,11 +165,13 @@ def preprocess_packed_supervised_dataset(
             batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1

-    if count_drop > 0:
-        logger.warning("Dropped lengthy {} example with length > {}.".format(count_drop, data_args.cutoff_len - 1))
+    if invalid_num > 0:
+        logger.warning(
+            "Dropped lengthy {} example with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
+        )

     model_inputs = defaultdict(list)
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         packed_images, packed_videos = [], []
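The key change above is num_reserved: the regular packing path still reserves one position of cutoff_len for a padding token, while flatting_packing skips the reservation, presumably because flattened packs are not padded out to cutoff_len. The sketch below shows how the reserved slot feeds into the knapsack capacity; greedy_knapsack_sketch is a simplified stand-in for the repository's greedy_knapsack helper, and cutoff_len and the lengths are made-up values for illustration.

```python
from typing import List


def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    """Simplified stand-in for the repository's greedy_knapsack: greedily groups
    sequence lengths into packs whose total length stays within `capacity`."""
    packs: List[List[int]] = []
    for length in sorted(lengths, reverse=True):
        for pack in packs:
            if sum(pack) + length <= capacity:
                pack.append(length)
                break
        else:
            packs.append([length])
    return packs


# Made-up settings for illustration only.
cutoff_len = 16
flatting_packing = False                      # True -> flattened packs, no padding slot needed
num_reserved = 0 if flatting_packing else 1   # keep one position free for the pad token

lengths = [5, 9, 3, 7, 2]                     # example sequence lengths
knapsacks = greedy_knapsack_sketch(lengths, cutoff_len - num_reserved)
print(knapsacks)                              # [[9, 5], [7, 3, 2]] -- each pack sums to <= 15
```

With flatting_packing = True, num_reserved drops to 0 and each pack may use the full cutoff_len.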
src/llamafactory/data/template.py (1 change: 1 addition & 0 deletions)
@@ -724,6 +724,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_assistant=StringFormatter(slots=[" {{content}}"]),  # mistral add space here
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     format_function=MistralFunctionFormatter(slots=[], tool_format="mistral"),
     format_observation=MistralObservationFormatter(tool_format="mistral"),
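As the added comment notes, Mistral expects a single space between [/INST] and the assistant reply, so the template now registers format_assistant with the slot " {{content}}" instead of relying on the default assistant formatter. A rough sketch of the intended rendering follows; render_mistral is a hypothetical helper written only for illustration, not LLaMA-Factory's actual formatter logic.

```python
# Hypothetical rendering helper for illustration only; LLaMA-Factory's real
# logic lives in its StringFormatter / Template classes.
def render_mistral(turns, bos="<s>", eos="</s>"):
    text = bos  # format_prefix: bos_token
    for user_msg, assistant_msg in turns:
        text += "[INST] {} [/INST]".format(user_msg)  # format_user
        text += " {}".format(assistant_msg)           # format_assistant: note the leading space
        text += eos                                   # eos closes each assistant turn
    return text

print(render_mistral([("Hello", "Hi there!")]))
# <s>[INST] Hello [/INST] Hi there!</s>
```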
