Skip to content

Commit

Permalink
[wip] Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Feb 19, 2024
1 parent 7d0c459 commit 962872b
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions analysis/per_token_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def main():
def _tokenify_string(string):
    """Tokenize *string* and build its cumulative prefix substrings.

    Returns a tuple ``(cumulative_texts, tokens)`` where ``tokens`` is the
    raw token list produced by the enclosing scope's ``tokenizer`` and
    ``cumulative_texts[i]`` is the detokenized text of ``tokens[:i + 1]``,
    i.e. the string as reconstructed after consuming the first ``i + 1``
    tokens. Useful for scoring per-token rewards over growing prefixes.
    """
    # NOTE(review): `tokenizer` is a closure variable defined in the
    # surrounding function (not visible in this fragment).
    tokens = tokenizer.tokenize(string)
    # Re-detokenize every prefix so each substring is valid surface text
    # (plain ''.join of subword tokens would leave marker artifacts).
    cumulative_texts = [tokenizer.convert_tokens_to_string(tokens[: i + 1]) for i, _ in enumerate(tokens)]
    # Diff residue fix: the scraped view showed an unreachable earlier
    # `return cumulative_texts`; the intended (post-commit) contract
    # returns both the prefixes and the raw tokens.
    return cumulative_texts, tokens

substrings = _tokenify_string(args.text)
substrings, tokens = _tokenify_string(args.text)
dataset = Dataset.from_list([{"text": substring} for substring in substrings])

# Load reward model pipeline
Expand Down Expand Up @@ -194,8 +194,8 @@ def _tokenify_string(string):
)

# Report the results
for reward, token in zip(per_token_rewards, substrings):
print(f"Reward: {round(reward, 3)} | Substring: {token}")
for reward, span in zip(per_token_rewards, substrings):
print(f"Reward: {round(reward, 3)} | Substring: {span}")

# Save the results
save_results(
Expand All @@ -204,6 +204,7 @@ def _tokenify_string(string):
model=args.model,
chat_template=args.chat_template,
substrings=substrings,
tokens=tokens,
rewards=per_token_rewards,
)

Expand Down Expand Up @@ -306,6 +307,7 @@ def save_results(
model: str,
chat_template: str,
substrings: List[str],
tokens: List[str],
rewards: List[str],
):
# Hash the text first using base16
Expand All @@ -329,6 +331,7 @@ def save_results(
"chat_template": chat_template,
"model_chat_hash": model_chat_hash.decode(),
"substrings": substrings,
"tokens": tokens,
"rewards": rewards,
}

Expand Down

0 comments on commit 962872b

Please sign in to comment.