Commit f66059f: clean

natolambert committed Jun 12, 2024
1 parent c9e0cc1 commit f66059f
Showing 2 changed files with 7 additions and 16 deletions.
15 changes: 1 addition & 14 deletions rewardbench/dpo.py
@@ -146,20 +146,7 @@ def tokenize_row(self, feature) -> Dict:
                 batch[f"{k}{type_key}"] = tokens
 
         else:
-            chosen_tokens = self.tokenizer(
-                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
-            )
-            rejected_tokens = self.tokenizer(
-                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
-            )
-            prompt_tokens = self.tokenizer(
-                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
-            )
-
-            batch["chosen_labels"] = chosen_tokens["input_ids"]
-            batch["rejected_labels"] = rejected_tokens["input_ids"]
-            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
-            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
+            raise ValueError("Encoder-decoder models are not supported yet.")
 
         return batch

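For context, a minimal sketch of the behavior this hunk introduces, assuming a trainer-style object dpo built for an encoder-decoder model; the feature values are illustrative, and only the ValueError message comes from the diff above:

feature = {"prompt": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
try:
    batch = dpo.tokenize_row(feature)
except ValueError as err:
    # The encoder-decoder branch no longer tokenizes; it raises immediately.
    print(err)  # Encoder-decoder models are not supported yet.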
8 changes: 6 additions & 2 deletions scripts/run_dpo.py
@@ -181,9 +181,13 @@ def main():
     # tokenize dataset
     column_names = list(dataset.features)
 
-    import ipdb; ipdb.set_trace()
+    import ipdb
+
+    ipdb.set_trace()
     tokenized_dataset = dataset.map(dpo.tokenize_row, remove_columns=column_names)
-    import ipdb; ipdb.set_trace()
+    import ipdb
+
+    ipdb.set_trace()
     dataloader = torch.utils.data.DataLoader(
         tokenized_dataset,
         batch_size=BATCH_SIZE,
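For context, a minimal sketch of the map-then-load pattern this hunk steps through (the ipdb.set_trace() calls drop into a debugger on either side of the map). The toy dataset, the BATCH_SIZE value, and the stand-in tokenize function are assumptions, not the script's real configuration:

import torch
from datasets import Dataset

BATCH_SIZE = 2  # illustrative; run_dpo.py defines its own batch size

# Toy preference data with the prompt/chosen/rejected columns the diff references.
dataset = Dataset.from_dict(
    {"prompt": ["What is 2 + 2?"], "chosen": ["4"], "rejected": ["5"]}
)
column_names = list(dataset.features)

def tokenize_row(feature):
    # Stand-in for dpo.tokenize_row: emit fixed-length ids so default collation works.
    return {"prompt_input_ids": [0, 1], "chosen_labels": [2, 3], "rejected_labels": [4, 5]}

tokenized_dataset = dataset.map(tokenize_row, remove_columns=column_names)
dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=BATCH_SIZE)
for batch in dataloader:
    print(batch)  # each field arrives as a list of per-position tensors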
