Skip to content

Commit

Permalink
Add linechart capability for per-token rewards (#33)
Browse files Browse the repository at this point in the history
* Add linechart capability

* Add argument for not aligning tokens

* Fix docstring
  • Loading branch information
ljvmiranda921 authored Feb 22, 2024
1 parent 2dbe89c commit cf82f2a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 40 deletions.
29 changes: 21 additions & 8 deletions analysis/draw_per_token_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

# Draw the per token reward

import argparse
import json
from pathlib import Path
from typing import List
import argparse

import numpy as np
import spacy_alignments as tokenizations
Expand Down Expand Up @@ -45,6 +45,16 @@ def get_args():
default=[8, 8],
help="Control the figure size when plotting.",
)
parser.add_argument(
"--line_chart",
action="store_true",
help="Draw a line chart instead of a heatmap.",
)
parser.add_argument(
"--do_not_align_tokens",
action="store_true",
help="If set, then tokens will not be aligned. May cause issues in the plot.",
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -85,19 +95,22 @@ def main():
whitespace_tokenizer = lambda x: x.split(" ") # noqa
reference_tokens = whitespace_tokenizer(text)

for _, results in rewards.items():
results["aligned_rewards"] = align_tokens(
reference_tokens=reference_tokens,
predicted_tokens=results["tokens"],
rewards=results["rewards"],
)
if not args.do_not_align_tokens:
for _, results in rewards.items():
results["aligned_rewards"] = align_tokens(
reference_tokens=reference_tokens,
predicted_tokens=results["tokens"],
rewards=results["rewards"],
)

reward_key = "rewards" if args.do_not_align_tokens else "aligned_rewards"
draw_per_token_reward(
tokens=reference_tokens,
rewards=[reward["aligned_rewards"] for _, reward in rewards.items()],
rewards=[reward[reward_key] for _, reward in rewards.items()],
model_names=[model_name for model_name, _ in rewards.items()],
output_path=args.output_path,
figsize=args.figsize,
line_chart=args.line_chart,
)


Expand Down
72 changes: 40 additions & 32 deletions herm/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

# Module for visualizing datasets and post-hoc analyses.

from pathlib import Path
from collections import Counter
from pathlib import Path
from typing import List, Optional, Tuple

import datasets
Expand All @@ -32,6 +32,7 @@ def draw_per_token_reward(
font_size: int = 12,
output_path: Path = None,
figsize: Tuple[int, int] = (12, 12),
line_chart: bool = False,
) -> "matplotlib.axes.Axes":
"""Draw a heatmap that combines the rewards
Expand All @@ -41,7 +42,8 @@ def draw_per_token_reward(
font_size (int)
output_path (Optional[Path]): if set, then save the figure in the specified path.
figsize (Tuple[int, int]): control the figure size when plotting.
RETURNS (matplotlib.axes.Axes): an Axes class containing the heatmap.
line_chart (bool): if set, will draw a line chart instead of a figure.
RETURNS (matplotlib.axes.Axes): an Axes class containing the figure.
"""
fig, ax = plt.subplots(figsize=figsize)
matplotlib.rcParams.update(
Expand All @@ -52,43 +54,49 @@ def draw_per_token_reward(
}
)
rewards = np.array(rewards)
im = ax.imshow(
rewards,
cmap="viridis",
vmax=np.max(rewards),
vmin=np.min(rewards),
)
fig.colorbar(im, ax=ax, orientation="horizontal", aspect=20, location="bottom")
ax.set_xticks(np.arange(len(tokens)), [f'"{token}"' for token in tokens])
ax.set_yticks(np.arange(len(model_names)), model_names)

# Add text
avg = np.mean(rewards)
for i in range(len(model_names)):
for j in range(len(tokens)):
color = "k" if rewards[i, j] >= avg else "w"
ax.text(
j,
i,
round(rewards[i, j], 4),
ha="center",
va="center",
color=color,
)

# Make it look better
ax.xaxis.tick_top()
ax.tick_params(left=False, top=False)
ax.spines[["right", "top", "left", "bottom"]].set_visible(False)
if not line_chart:
im = ax.imshow(
rewards,
cmap="viridis",
vmax=np.max(rewards),
vmin=np.min(rewards),
)
fig.colorbar(im, ax=ax, orientation="horizontal", aspect=20, location="bottom")
ax.set_xticks(np.arange(len(tokens)), [f'"{token}"' for token in tokens])
ax.set_yticks(np.arange(len(model_names)), model_names)

# Add text
avg = np.mean(rewards)
for i in range(len(model_names)):
for j in range(len(tokens)):
color = "k" if rewards[i, j] >= avg else "w"
ax.text(j, i, round(rewards[i, j], 4), ha="center", va="center", color=color)

# Make it look better
ax.xaxis.tick_top()
ax.tick_params(left=False, top=False)
ax.spines[["right", "top", "left", "bottom"]].set_visible(False)
else:
print("Drawing line chart")
idxs = np.arange(0, len(tokens))
for model_name, per_token_rewards in zip(model_names, rewards):
ax.plot(idxs, per_token_rewards, label=model_name, marker="x")

ax.legend(loc="upper left")
ax.set_xticks(np.arange(len(tokens)), [f'"{token}"' for token in tokens])
ax.set_xlabel("Tokens")
ax.set_ylabel("Reward")
ax.spines[["right", "top"]].set_visible(False)

# Added information
title = "Cumulative substring rewards"
ax.set_title(title, pad=20)

# fig.tight_layout()
fig.subplots_adjust(left=0.5)
if not line_chart:
fig.subplots_adjust(left=0.5)
if output_path:
print(f"Saving per-token-reward heatmap to {output_path}")
print(f"Saving per-token-reward plot to {output_path}")
plt.savefig(output_path, transparent=True, dpi=120)

plt.show()
Expand Down

0 comments on commit cf82f2a

Please sign in to comment.