
Commit a57ca60

delete
1 parent 4da941c commit a57ca60


tests/trainer/test_trainer.py

Lines changed: 20 additions & 26 deletions
@@ -163,6 +163,22 @@
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
 
 
+def get_dataset(file_path, tokenizer, max_len):
+    dataset = datasets.load_dataset("text", file_path)
+
+    # Define tokenization function
+    def tokenize_function(examples):
+        tokenized = tokenizer(examples["text"], add_special_tokens=True, truncation=True, max_length=max_len)
+        # Add labels as a copy of input_ids
+        tokenized["labels"] = tokenized["input_ids"].copy()
+        return tokenized
+
+    # Apply tokenization and remove original text column
+    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
+
+    return tokenized_dataset
+
+
 class StoreLossCallback(TrainerCallback):
     """
     Simple callback to store the loss.
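
Note on the new helper: as written, get_dataset hands file_path to datasets.load_dataset as the second positional argument, which the datasets library interprets as the builder configuration name rather than as the data files. Below is a minimal, hypothetical sketch of the same tokenize-and-label flow that routes the path through data_files and selects the "train" split; the name get_dataset_sketch, the data_files keyword, and the split selection are illustrative assumptions, not part of this commit.

import datasets
from transformers import AutoTokenizer

def get_dataset_sketch(file_path, tokenizer, max_len):
    # Load the plain-text fixture; each line becomes one "text" example.
    dataset = datasets.load_dataset("text", data_files=file_path)["train"]

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"], add_special_tokens=True, truncation=True, max_length=max_len
        )
        # Causal-LM style: labels are a copy of the input ids.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    # Tokenize in batches and drop the raw "text" column.
    return dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Hypothetical usage mirroring the call sites below; PATH_SAMPLE_TEXT is the
# fixture path defined at the top of this test file.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
train_dataset = get_dataset_sketch(PATH_SAMPLE_TEXT, tokenizer, tokenizer.max_len_single_sentence)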
@@ -1528,13 +1544,7 @@ def test_multiple_peft_adapters(self):
         tiny_model = get_peft_model(tiny_model, peft_config, "adapter1")
         tiny_model.add_adapter("adapter2", peft_config)
 
-        train_dataset = LineByLineTextDataset(
-            tokenizer=tokenizer,
-            file_path=PATH_SAMPLE_TEXT,
-            block_size=tokenizer.max_len_single_sentence,
-        )
-        for example in train_dataset.examples:
-            example["labels"] = example["input_ids"]
+        train_dataset = get_dataset(PATH_SAMPLE_TEXT, tokenizer, tokenizer.max_len_single_sentence)
 
         tokenizer.pad_token = tokenizer.eos_token
 
@@ -3754,13 +3764,7 @@ def test_trainer_eval_multiple(self):
         MODEL_ID = "openai-community/gpt2"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
         model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
-        dataset = LineByLineTextDataset(
-            tokenizer=tokenizer,
-            file_path=PATH_SAMPLE_TEXT,
-            block_size=tokenizer.max_len_single_sentence,
-        )
-        for example in dataset.examples:
-            example["labels"] = example["input_ids"]
+        dataset = get_dataset(PATH_SAMPLE_TEXT, tokenizer, tokenizer.max_len_single_sentence)
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(
                 output_dir=tmp_dir,
@@ -3784,11 +3788,7 @@ def test_trainer_eval_multiple(self):
     def test_trainer_eval_lm(self):
         MODEL_ID = "distilbert/distilroberta-base"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-        dataset = LineByLineTextDataset(
-            tokenizer=tokenizer,
-            file_path=PATH_SAMPLE_TEXT,
-            block_size=tokenizer.max_len_single_sentence,
-        )
+        dataset = get_dataset(PATH_SAMPLE_TEXT, tokenizer, tokenizer.max_len_single_sentence)
         self.assertEqual(len(dataset), 31)
 
     def test_training_iterable_dataset(self):
@@ -4975,13 +4975,7 @@ def test_trainer_works_without_model_config(self):
         model = BasicTextGenerationModel(vocab_size=tokenizer.vocab_size, hidden_size=32)
         # Note that this class does not have a config attribute
 
-        train_dataset = LineByLineTextDataset(
-            tokenizer=tokenizer,
-            file_path=PATH_SAMPLE_TEXT,
-            block_size=tokenizer.max_len_single_sentence,
-        )
-        for example in train_dataset.examples:
-            example["labels"] = example["input_ids"]
+        train_dataset = get_dataset(PATH_SAMPLE_TEXT, tokenizer, tokenizer.max_len_single_sentence)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             training_args = TrainingArguments(
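
For context on how the rewritten call sites plug into a Trainer run, here is a hedged sketch assuming the public Trainer, TrainingArguments, and DataCollatorForLanguageModeling APIs. The specific argument choices (max_steps, batch size 1, report_to) are illustrative and not part of this diff; get_dataset and PATH_SAMPLE_TEXT refer to the helper and fixture path defined in the file above (the get_dataset_sketch variant from the earlier note would work the same way).

import tempfile

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

MODEL_ID = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# get_dataset and PATH_SAMPLE_TEXT come from tests/trainer/test_trainer.py above.
train_dataset = get_dataset(PATH_SAMPLE_TEXT, tokenizer, tokenizer.max_len_single_sentence)

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = TrainingArguments(
        output_dir=tmp_dir,
        per_device_train_batch_size=1,  # batch size 1 sidesteps padding of the precomputed labels
        max_steps=2,
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        # mlm=False gives plain causal-LM collation.
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()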
