Commit
option to pad to max length
chiragjn committed Dec 28, 2023
1 parent 69ddccc commit 79dad36
Showing 1 changed file with 3 additions and 1 deletion.
data_utils.py: 3 additions & 1 deletion
@@ -200,7 +200,7 @@ def construct_dataset(self, input_batch):
 class CausalDatasetBuilder(DatasetBuilder):
     """Builds generative dataset for Causal LM."""
 
-    def __init__(self, tokenizer, max_length, train_on_prompt=True, pad_to_max_length=False):
+    def __init__(self, tokenizer, max_length, pad_to_max_length=False, train_on_prompt=True):
         super().__init__(tokenizer, max_length)
         self.train_on_prompt = train_on_prompt
         self.pad_to_max_length = pad_to_max_length
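The body of construct_dataset is not part of this hunk, so how the flag is consumed is not visible here. As a minimal sketch, assuming the builder tokenizes with a Hugging Face tokenizer, a pad_to_max_length flag is typically mapped onto the tokenizer's padding argument; the encode helper and the gpt2 checkpoint below are illustrative, not from this repository.

# Hypothetical sketch, not from this diff: how a pad_to_max_length flag is
# commonly threaded into tokenization for a causal LM dataset.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

def encode(text: str, max_length: int, pad_to_max_length: bool):
    # padding="max_length" pads every example out to max_length;
    # padding=False keeps each example at its natural length.
    return tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length" if pad_to_max_length else False,
    )

print(len(encode("hello world", max_length=16, pad_to_max_length=True)["input_ids"]))  # 16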
@@ -303,13 +303,15 @@ def build_dataset(
     eval_data,
     tokenizer,
     max_length: int,
+    pad_to_max_length: bool,
     train_on_prompt: bool,
 ):
     # TODO (chiragjn): This should not be loading the entire dataset in memory all at once. Make this streaming
     # TODO (chiragjn): Add dataset packing to increase training efficiency
     builder = CausalDatasetBuilder(
         tokenizer=tokenizer,
         max_length=max_length,
         pad_to_max_length=pad_to_max_length,
         train_on_prompt=train_on_prompt,
     )
     dataset_dict = DatasetDict(train=Dataset.from_list(train_data), eval=Dataset.from_list(eval_data))
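Callers of build_dataset now have to pass the new flag as well. A hypothetical call site follows, assuming train_data is the parameter preceding eval_data, that records carry prompt/completion fields, and that build_dataset returns the processed DatasetDict; none of these details are shown in this hunk.

# Hypothetical call, not part of the commit: the record keys and the
# train_data parameter name are assumptions for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

dataset_dict = build_dataset(
    train_data=[{"prompt": "Translate to French: cat", "completion": "chat"}],
    eval_data=[{"prompt": "Translate to French: dog", "completion": "chien"}],
    tokenizer=tokenizer,
    max_length=512,
    pad_to_max_length=True,   # pad every tokenized example to max_length
    train_on_prompt=False,    # assumed semantics: exclude prompt tokens from the loss
)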
