1. support flat_packing
2. fix knapsack bug that may cause #5443
3. avoid wrongly truncating supervised examples
AlongWY committed Sep 18, 2024
1 parent 1a3e654 commit 7cab73b
Showing 7 changed files with 155 additions and 54 deletions.
3 changes: 2 additions & 1 deletion src/llamafactory/data/__init__.py
@@ -17,17 +17,18 @@
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
SFTDataCollatorWithFlattingPacking,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"SFTDataCollatorWithFlattingPacking",
"Role",
"split_dataset",
"get_dataset",
39 changes: 37 additions & 2 deletions src/llamafactory/data/collator.py
@@ -19,8 +19,7 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
from transformers import DataCollatorForSeq2Seq

from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase

if TYPE_CHECKING:
from transformers import ProcessorMixin
@@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
return features


@dataclass
class SFTDataCollatorWithFlattingPacking(DefaultDataCollator):
r"""
Data collator for flattened packing.
"""

tokenizer: PreTrainedTokenizerBase = None
label_pad_token_id: int = -100
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
return_position_ids: bool = True

def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]:
# TODO: multimodal inputs are not supported yet
if return_tensors is None:
return_tensors = self.return_tensors
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for instances in features:
for input_ids, labels in zip(instances["input_ids"], instances["labels"]):
ret["input_ids"] += input_ids
if is_labels_provided:
ret["labels"] += [self.label_pad_token_id] + labels[1:]
else:
ret["labels"] += [self.label_pad_token_id] + input_ids[1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(input_ids)))

assert len(ret["input_ids"]) == len(ret["labels"])

features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors)
return features


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
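For reference, a minimal standalone sketch of the flattening step on toy token ids (the hypothetical ids and the `default_data_collator` call only illustrate what the collator added above produces; this is not part of the commit):

from transformers import default_data_collator

# one packed example holding two already-tokenized sequences (hypothetical ids)
features = [{"input_ids": [[1, 5, 6, 2], [1, 7, 2]], "labels": [[1, 5, 6, 2], [1, 7, 2]]}]

ret = {"input_ids": [], "labels": [], "position_ids": []}
for instance in features:
    for input_ids, labels in zip(instance["input_ids"], instance["labels"]):
        ret["input_ids"] += input_ids
        ret["labels"] += [-100] + labels[1:]  # mask the first token of every sequence
        ret["position_ids"] += list(range(len(input_ids)))  # positions restart per sequence

batch = default_data_collator([ret], return_tensors="pt")
print(batch["input_ids"])     # tensor([[   1,    5,    6,    2,    1,    7,    2]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 3, 0, 1, 2]])
print(batch["labels"])        # tensor([[-100,    5,    6,    2, -100,    7,    2]])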
8 changes: 5 additions & 3 deletions src/llamafactory/data/preprocess.py
@@ -22,10 +22,10 @@
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
print_flatting_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

@@ -78,8 +78,10 @@ def __init__(self, data, **kwargs):
processor=processor,
data_args=data_args,
)

print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
if data_args.packing and data_args.flat_packing:
print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer)
else:
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
6 changes: 6 additions & 0 deletions src/llamafactory/data/processors/processor_utils.py
@@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
# filter out numbers that are larger than the capacity
numbers = [number for number in numbers if number <= capacity]
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []

@@ -43,6 +45,10 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack

# avoid endless loop
if remaining_capacity == capacity:
break

knapsacks.append(current_knapsack)

return knapsacks
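As a standalone illustration of the patched knapsack behaviour, the sketch below assumes the helper works as in the hunk above: lengths larger than the capacity are filtered out up front, and the outer loop stops once nothing fits, so it can no longer spin forever.

import bisect
from typing import List

def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    """Greedy grouping with binary search; mirrors the patched helper above."""
    numbers = [n for n in numbers if n <= capacity]  # drop lengths that can never fit
    numbers.sort()  # ascending order enables binary search
    knapsacks: List[List[int]] = []
    while numbers:
        current_knapsack: List[int] = []
        remaining_capacity = capacity
        while True:
            index = bisect.bisect_right(numbers, remaining_capacity) - 1  # largest fitting length
            if index < 0:
                break
            remaining_capacity -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        if remaining_capacity == capacity:  # nothing fit at all: avoid an endless loop
            break
        knapsacks.append(current_knapsack)
    return knapsacks

print(greedy_knapsack([7, 3, 3, 2, 9], capacity=8))  # [[7], [3, 3, 2]] -- 9 is discarded up front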
103 changes: 66 additions & 37 deletions src/llamafactory/data/processors/supervised.py
@@ -11,23 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin

from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template


logger = get_logger(__name__)


@@ -53,13 +51,16 @@ def _encode_supervised_example(
encoded_pairs = encoded_pairs[::-1] # high priority for last turns

for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
if total_length >= cutoff_len and cutoff_len > 0:
break

source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if cutoff_len > 0:
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
else:
source_len, target_len = len(source_ids), len(target_ids)

if train_on_prompt:
source_label = source_ids
@@ -112,7 +113,7 @@ def preprocess_supervised_dataset(
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
cutoff_len=data_args.cutoff_len if data_args.allow_truncation else 0,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
@@ -132,13 +133,16 @@ def preprocess_packed_supervised_dataset(
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
invalid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)

# reserve one token for padding (not needed with flat_packing)
num_reserved = 0 if data_args.flat_packing else 1
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -154,13 +158,13 @@ def preprocess_packed_supervised_dataset(
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
cutoff_len=data_args.cutoff_len - num_reserved if data_args.allow_truncation else 0,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
if length > data_args.cutoff_len - num_reserved:
invalid_num += 1
else:
lengths.append(length)
length2indexes[length].append(valid_num)
@@ -170,36 +174,52 @@ def preprocess_packed_supervised_dataset(
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1

if invalid_num > 0:
logger.warning(
"Dropped lengthy {} example with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
)

model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # num_reserved accounts for the padding token (0 with flat_packing)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")

if data_args.flat_packing:
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids.append(batch_input_ids[index])
packed_labels.append(batch_labels[index])
packed_images.append(batch_images[index])
packed_videos.append(batch_videos[index])
else:
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

# flat_packing does not need attention masks
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

# flat packing does not need padding
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["attention_mask"].append(packed_attention_masks)

model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
@@ -213,3 +233,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))


def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
input_ids = list(itertools.chain(*example["input_ids"]))
print("input_ids:\n{}".format(input_ids))
print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
print("labels:\n{}".format(tokenizer.decode(valid_labels), skip_special_tokens=False))
11 changes: 11 additions & 0 deletions src/llamafactory/hparams/data_args.py
@@ -105,6 +105,14 @@ class DataArguments:
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
flat_packing: bool = field(
default=False,
metadata={"help": "Enable sequence packing with flattening (requires FlashAttention-2)."},
)
allow_truncation: bool = field(
default=False,
metadata={"help": "Allow truncation when processing supervised examples."},
)
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
@@ -148,3 +156,6 @@ def split_arg(arg):

if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

if self.neat_packing and self.flat_packing:
raise ValueError("`neat_packing` is incompatible with `flat_packing`.")
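A quick, illustrative sanity check of the new arguments (assuming `DataArguments` is importable from `llamafactory.hparams` as in the rest of the code base; the flag combinations shown are hypothetical):

from llamafactory.hparams import DataArguments

args = DataArguments(packing=True, flat_packing=True)  # flattened packing, truncation disabled
print(args.flat_packing, args.allow_truncation)  # True False

try:
    DataArguments(packing=True, neat_packing=True, flat_packing=True)  # mutually exclusive
except ValueError as err:
    print(err)  # `neat_packing` is incompatible with `flat_packing`.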
39 changes: 28 additions & 11 deletions src/llamafactory/train/sft/workflow.py
@@ -17,21 +17,24 @@

from typing import TYPE_CHECKING, List, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...data import (
SFTDataCollatorWith4DAttentionMask,
SFTDataCollatorWithFlattingPacking,
get_dataset,
get_template_and_fix_tokenizer,
)
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...extras.logging import get_logger
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer


if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback

from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

logger = get_logger(__name__)


def run_sft(
model_args: "ModelArguments",
Expand All @@ -50,15 +53,29 @@ def run_sft(
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction

data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
if (
data_args.packing and
data_args.flat_packing and
(getattr(model.config, "_attn_implementation", None) != "flash_attention_2")
):
logger.warning("The `flat_packing` only support `flash_attention_2`! Maybe cause Out of memory!")

if data_args.packing and data_args.flat_packing:
data_collator = SFTDataCollatorWithFlattingPacking(
template=template,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
)
else:
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)

# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
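For background on why the flattened batches carry `position_ids` but no attention mask, here is a rough sketch of how sequence boundaries can be recovered from position ids that restart at zero; this only illustrates the idea behind variable-length FlashAttention kernels and is not the exact code path the trainer runs:

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """[0, 1, 2, 3, 0, 1, 2] -> cumulative sequence lengths [0, 4, 7]."""
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()  # each restart opens a sequence
    total = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, total]).to(torch.int32)

print(cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 3, 0, 1, 2]])))
# tensor([0, 4, 7], dtype=torch.int32)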
