Flatting Packing / maybe fix #5443 and #5426 #5458

Open. Wants to merge 2 commits into base: main.

3 changes: 2 additions & 1 deletion src/llamafactory/data/__init__.py
@@ -17,17 +17,18 @@
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
SFTDataCollatorWithFlattingPacking,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"SFTDataCollatorWithFlattingPacking",
"Role",
"split_dataset",
"get_dataset",
39 changes: 37 additions & 2 deletions src/llamafactory/data/collator.py
@@ -19,8 +19,7 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
from transformers import DataCollatorForSeq2Seq

from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase

if TYPE_CHECKING:
from transformers import ProcessorMixin
@@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
return features


@dataclass
class SFTDataCollatorWithFlattingPacking(DefaultDataCollator):
r"""
Data collator for flatting packing.
"""

tokenizer: PreTrainedTokenizerBase = None
label_pad_token_id: int = -100
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
return_position_ids: bool = True

def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]:
# TODO: multi-modal inputs are not supported yet
if return_tensors is None:
return_tensors = self.return_tensors
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for instances in features:
for input_ids, labels in zip(instances["input_ids"], instances["labels"] if is_labels_provided else instances["input_ids"]):
ret["input_ids"] += input_ids
if is_labels_provided:
ret["labels"] += [self.label_pad_token_id] + labels[1:]
else:
ret["labels"] += [self.label_pad_token_id] + input_ids[1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(input_ids)))

assert len(ret["input_ids"]) == len(ret["labels"])

features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors)
return features


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
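For orientation, here is a minimal sketch of what the new SFTDataCollatorWithFlattingPacking produces for a toy packed example. The two-sequence batch below is invented for illustration; the class and import path come from the diff above, and the sketch assumes the PR branch is installed.

from llamafactory.data import SFTDataCollatorWithFlattingPacking

# One packed example ("knapsack") holding two short sequences (toy token ids).
features = [{
    "input_ids": [[1, 100, 101, 2], [1, 200, 2]],
    "labels": [[-100, 100, 101, 2], [-100, 200, 2]],
}]

# The tokenizer field is optional for this sketch since __call__ does not use it.
collator = SFTDataCollatorWithFlattingPacking()
batch = collator(features)

# Sequences are concatenated into a single row (tensors of shape [1, 7]):
# input_ids    -> [[1, 100, 101, 2, 1, 200, 2]]
# labels       -> [[-100, 100, 101, 2, -100, 200, 2]]
# position_ids -> [[0, 1, 2, 3, 0, 1, 2]]  (restart at 0 for each packed sequence)
print(batch["input_ids"], batch["labels"], batch["position_ids"])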
8 changes: 5 additions & 3 deletions src/llamafactory/data/preprocess.py
@@ -22,10 +22,10 @@
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
print_flatting_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

@@ -78,8 +78,10 @@ def __init__(self, data, **kwargs):
processor=processor,
data_args=data_args,
)

print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
if data_args.packing and data_args.flat_packing:
print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer)
else:
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
6 changes: 6 additions & 0 deletions src/llamafactory/data/processors/processor_utils.py
@@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
# filter out numbers that are larger than the capacity
numbers = [number for number in numbers if number <= capacity]
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []

@@ -43,6 +45,10 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack

# avoid an endless loop when nothing more can be packed
if remaining_capacity == capacity:
break

knapsacks.append(current_knapsack)

return knapsacks
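For context, a simplified re-statement of the patched knapsack routine, showing where the two new guards sit. This sketch uses a plain linear scan in place of the repository's binary-search helper (not shown in this hunk), so it illustrates the logic rather than reproducing the actual implementation.

from typing import List


def greedy_knapsack_sketch(numbers: List[int], capacity: int) -> List[List[int]]:
    # Guard 1 (added above): drop items that can never fit in any knapsack.
    numbers = [number for number in numbers if number <= capacity]
    numbers.sort()

    knapsacks = []
    while numbers:
        current_knapsack, remaining_capacity = [], capacity
        while True:
            # Pick the largest number that still fits (linear scan for clarity).
            index = next((i for i in range(len(numbers) - 1, -1, -1) if numbers[i] <= remaining_capacity), None)
            if index is None:
                break
            remaining_capacity -= numbers[index]
            current_knapsack.append(numbers.pop(index))

        # Guard 2 (added above): if nothing was packed, stop instead of looping forever.
        if remaining_capacity == capacity:
            break

        knapsacks.append(current_knapsack)

    return knapsacks


# greedy_knapsack_sketch([3, 5, 7, 9, 40], capacity=10) -> [[9], [7, 3], [5]]
# (40 is filtered out up front rather than spinning the loop indefinitely)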
103 changes: 66 additions & 37 deletions src/llamafactory/data/processors/supervised.py
@@ -11,23 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template


logger = get_logger(__name__)


@@ -53,13 +51,16 @@ def _encode_supervised_example(
encoded_pairs = encoded_pairs[::-1] # high priority for last turns

for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
if total_length >= cutoff_len and cutoff_len > 0:
break

source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
Review thread on this line:

@AlongWY (Contributor Author), Sep 17, 2024:
This is what causes the Inst data to be truncated abnormally in #5426. Maybe consider introducing a new parameter that controls whether truncation is allowed? My samples are two-turn tool calls, and if they are truncated the model only learns to emit tool_calls without the final answer. The current truncation can also cut the user and assistant content in the middle: with the mistral template you end up with [INST] xxxxxxx while xxxxx[/INST] is gone, which is clearly incorrect.

Owner:
I don't think the problem is here? Non-packing shows the same behavior.

@AlongWY (Contributor Author), Sep 18, 2024:
Still, I do think we need a parameter to control this, because in some cases a sample must not be truncated in the middle.

Owner:
If we don't truncate the prompt, where would the assistant part go?

@AlongWY (Contributor Author):
Just skip it and drop that sample.

@AlongWY (Contributor Author):
Added a parameter to control whether truncation is allowed; by default truncation is disabled.

source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if cutoff_len > 0:
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
else:
source_len, target_len = len(source_ids), len(target_ids)

if train_on_prompt:
source_label = source_ids
@@ -112,7 +113,7 @@ def preprocess_supervised_dataset(
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
cutoff_len=data_args.cutoff_len if data_args.allow_truncation else 0,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
@@ -132,13 +133,16 @@ def preprocess_packed_supervised_dataset(
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
invalid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)

# reserve one token for padding; flat_packing does not need it
num_reserved = 0 if data_args.flat_packing else 1
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -154,13 +158,13 @@ def preprocess_packed_supervised_dataset(
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
cutoff_len=data_args.cutoff_len - num_reserved if data_args.allow_truncation else 0,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
if length > data_args.cutoff_len - num_reserved:
invalid_num += 1
else:
lengths.append(length)
length2indexes[length].append(valid_num)
@@ -170,36 +174,52 @@ def preprocess_packed_supervised_dataset(
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1

if invalid_num > 0:
logger.warning(
"Dropped lengthy {} example with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
)

model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # num_reserved is 0 when flat_packing is enabled
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")

if data_args.flat_packing:
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids.append(batch_input_ids[index])
packed_labels.append(batch_labels[index])
packed_images.append(batch_images[index])
packed_videos.append(batch_videos[index])
Review thread on this line:

@AlongWY (Contributor Author):
Deferred handling: position_ids are not returned here; they are assembled and returned in the collator.

else:
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

# flat_packing does not need attention masks
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

# flatting packing does not need padding
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["attention_mask"].append(packed_attention_masks)

model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
@@ -213,3 +233,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))


def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
input_ids = list(itertools.chain(*example["input_ids"]))
print("input_ids:\n{}".format(input_ids))
print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
print("labels:\n{}".format(tokenizer.decode(valid_labels), skip_special_tokens=False))
11 changes: 11 additions & 0 deletions src/llamafactory/hparams/data_args.py
@@ -105,6 +105,14 @@ class DataArguments:
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
flat_packing: bool = field(
default=False,
metadata={"help": "Enable sequence packing with flattening, need flash atten."}
)
allow_truncation: bool = field(
default=False,
metadata={"help": "Allow truncation when processing supervised examples."}
)
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
@@ -148,3 +156,6 @@ def split_arg(arg):

if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

if self.neat_packing and self.flat_packing:
raise ValueError("`neat_packing` is incompatible with `flat_packing`.")
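As a usage sketch, the new options can be combined like this. The configuration below is hypothetical and not taken from the PR; the field names match the dataclass above.

from llamafactory.hparams import DataArguments

data_args = DataArguments(
    cutoff_len=4096,
    packing=True,
    flat_packing=True,        # requires packing=True and a FlashAttention-2 model
    allow_truncation=False,   # default: examples are not truncated mid-turn (the packed path drops over-long ones)
    neat_packing=False,       # mutually exclusive with flat_packing (see the check above)
)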
39 changes: 28 additions & 11 deletions src/llamafactory/train/sft/workflow.py
@@ -17,21 +17,24 @@

from typing import TYPE_CHECKING, List, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...data import (
    SFTDataCollatorWith4DAttentionMask,
    SFTDataCollatorWithFlattingPacking,
    get_dataset,
    get_template_and_fix_tokenizer,
)
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...extras.logging import get_logger
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer


if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback

from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

logger = get_logger(__name__)


def run_sft(
model_args: "ModelArguments",
@@ -50,15 +53,29 @@ def run_sft(
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction

data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
if (
data_args.packing and
data_args.flat_packing and
(getattr(model.config, "_attn_implementation", None) != "flash_attention_2")
):
logger.warning("The `flat_packing` only support `flash_attention_2`! Maybe cause Out of memory!")

if data_args.packing and data_args.flat_packing:
data_collator = SFTDataCollatorWithFlattingPacking(
template=template,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
)
else:
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)

# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
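Closing note on the flash_attention_2 requirement: the flattened batch carries no attention mask, only position_ids that restart at 0 for every packed sequence, so the attention kernel must be able to treat those restarts as sequence boundaries. The sketch below only illustrates the idea of recovering boundaries from restarting position_ids; it is not the transformers or flash-attn implementation.

from typing import List, Tuple


def sequence_spans_from_position_ids(position_ids: List[int]) -> List[Tuple[int, int]]:
    """Recover (start, end) spans of packed sequences from restarting position_ids."""
    spans, start = [], 0
    for i, pos in enumerate(position_ids):
        if i > 0 and pos == 0:  # a position restart marks the start of a new sequence
            spans.append((start, i))
            start = i
    spans.append((start, len(position_ids)))
    return spans


# sequence_spans_from_position_ids([0, 1, 2, 3, 0, 1, 2]) -> [(0, 4), (4, 7)]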