Add Log Prob Visualiser (#36)
neelnanda-io authored Jan 7, 2023
1 parent 9fad349 commit 1914745
Showing 16 changed files with 6,440 additions and 7,827 deletions.
5,441 changes: 2,721 additions & 2,720 deletions python/Demonstration.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions python/circuitsvis/__init__.py
@@ -4,6 +4,7 @@
import circuitsvis.attention
import circuitsvis.examples
import circuitsvis.tokens
import circuitsvis.logits

__version__ = version("circuitsvis")

94 changes: 94 additions & 0 deletions python/circuitsvis/logits.py
@@ -0,0 +1,94 @@
"""Log Prob visualization"""
from typing import Callable, List, Union

import numpy as np
import torch
from circuitsvis.utils.render import RenderedHTML, render

ArrayRank1 = Union[List[float], np.ndarray, torch.Tensor]
ArrayRank2 = Union[List[List[float]], np.ndarray, torch.Tensor]
ArrayRank3 = Union[List[List[List[float]]], np.ndarray, torch.Tensor]
IntArrayRank1 = Union[List[int], np.ndarray, torch.Tensor]


def token_log_probs(
    token_indices: torch.Tensor,
    log_probs: torch.Tensor,
    to_string: Callable[[int], str],
    top_k: int = 10,
) -> RenderedHTML:
    """
    Takes the log probs for a model on some text. Outputs the tokens coloured
    by their log prob; on hover, shows the top K tokens the model predicted
    for that position, and where the true token ranked among them.

    The intended use case is to help debug and explore a model's outputs.

    Args:
        token_indices: Tensor of token indices (i.e. integers) of shape [N,].
            Assumed to begin with a Beginning of Sequence (BOS) token, which
            is not shown in the visualization.
        log_probs: Log probabilities for predicting the next token. Tensor of
            shape [N, d_vocab].
        to_string: A function mapping a token (as an integer) to its string
            value.
        top_k: How many top predicted tokens to show per position.

    Returns:
        Rendered HTML log prob visualization.
    """
    if len(token_indices.shape) == 2:
        # Remove batch dimension from token indices
        token_indices = token_indices.squeeze(0)

    if len(log_probs.shape) == 3:
        # Remove batch dimension from log probs
        log_probs = log_probs.squeeze(0)

    assert len(log_probs.shape) == 2, \
        f"Log Probs shape must be 2D: {log_probs.shape}"
    assert len(token_indices.shape) == 1, \
        f"Tokens shape must be 1D: {token_indices.shape}"
    assert token_indices.size(0) == log_probs.size(0), \
        f"Number of tokens and log prob vectors must be identical, " \
        f"{log_probs.shape}, {token_indices.shape}"

    # Drop the final position's log probs, since we don't know what the next
    # token is after the final position!
    log_probs = log_probs[:-1]

    prompt = [to_string(index.item()) for index in token_indices]

    # Sort the log probs along the d_vocab dimension (descending), keeping
    # the sorted vocab indices
    _sorted_log_prob_values, sorted_log_prob_indices = log_probs.sort(
        dim=-1, descending=True
    )

    # Get the top K log probs and indices for each position
    # Shapes are [N - 1, K]
    top_k_log_probs, top_k_indices = log_probs.topk(top_k, dim=-1)

    # Get the token values (i.e. strings) for the top K tokens per position
    top_k_tokens = [
        [to_string(token) for token in current_top_k_tokens]
        for current_top_k_tokens in top_k_indices.tolist()
    ]

    # Slightly cursed code to get the rank of the correct token at each
    # position. .nonzero() on a 2D tensor returns an [X, 2] tensor, where X is
    # the number of non-zero elements and each row is the pair of indices of
    # one such element. We only want the index along the d_vocab dimension, so
    # we take column 1.
    # We don't care about predicting the BOS token, so we use token_indices[1:]
    correct_token_rank = (
        sorted_log_prob_indices == token_indices[1:, None]
    ).nonzero()[:, 1]
    assert len(correct_token_rank) == len(token_indices) - 1, \
        "Some token indices were missing from sorted_log_prob_indices"

    # Gets the log probs for the correct next token. gather requires its index
    # tensor to have the same rank as the input, hence token_indices[1:, None]
    # to form an [N - 1, 1] column, and the squeeze to drop that dimension
    # again afterwards.
    correct_token_log_prob = log_probs.gather(
        index=token_indices[1:, None], dim=-1
    ).squeeze(1)

    return render(
        "TokenLogProbs",
        prompt=prompt,
        topKLogProbs=top_k_log_probs,
        topKTokens=top_k_tokens,
        correctTokenRank=correct_token_rank,
        correctTokenLogProb=correct_token_log_prob,
    )
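
A minimal usage sketch (not part of this diff), assuming a TransformerLens-style model whose to_string method maps a token id to its string; the model choice and variable names are illustrative:

# Illustrative usage of token_log_probs; assumes transformer_lens is installed
from transformer_lens import HookedTransformer

from circuitsvis.logits import token_log_probs

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("The quick brown fox")  # shape [1, N], BOS prepended
log_probs = model(tokens).log_softmax(dim=-1)    # shape [1, N, d_vocab]
html = token_log_probs(tokens, log_probs, model.to_string, top_k=10)
# The returned RenderedHTML displays inline in a notebook; str(html) gives
# the raw HTML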
4 changes: 2 additions & 2 deletions python/circuitsvis/tests/snapshots/snap_test_activations.py
@@ -9,7 +9,7 @@

snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/>
<script crossorigin type="module">
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.34.0/dist/cdn/esm.js";
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.0.0/dist/cdn/esm.js";
render(
"circuits-vis-mock",
TextNeuronActivations,
@@ -19,7 +19,7 @@

snapshots['TestTextNeuronActivations.test_single_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/>
<script crossorigin type="module">
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.34.0/dist/cdn/esm.js";
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.0.0/dist/cdn/esm.js";
render(
"circuits-vis-mock",
TextNeuronActivations,
python/circuitsvis/tests/snapshots/snap_test_topk_tokens.py
@@ -9,7 +9,7 @@

snapshots['TestTopk.test_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/>
<script crossorigin type="module">
import { render, TopkTokens } from "https://unpkg.com/circuitsvis@1.34.0/dist/cdn/esm.js";
import { render, TopkTokens } from "https://unpkg.com/circuitsvis@1.0.0/dist/cdn/esm.js";
render(
"circuits-vis-mock",
TopkTokens,
4 changes: 2 additions & 2 deletions python/circuitsvis/tests/test_activations.py
@@ -5,8 +5,8 @@

class TestTextNeuronActivations:
    def test_single_matches_snapshot(self, snapshot, monkeypatch):
        # Monkeypatch uuid4 to always return the same uuid
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
        monkeypatch.setattr(circuitsvis, "__version__", "1.0.0")

        res = text_neuron_activations(
            tokens=["a", "b"],
@@ -17,8 +17,8 @@ def test_single_matches_snapshot(self, snapshot, monkeypatch):
        snapshot.assert_match(str(res))

    def test_multi_matches_snapshot(self, snapshot, monkeypatch):
        # Monkeypatch uuid4 to always return the same uuid
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
        monkeypatch.setattr(circuitsvis, "__version__", "1.0.0")

        res = text_neuron_activations(
            tokens=[["a", "b"], ["c", "d", "e"]],
1 change: 1 addition & 0 deletions python/circuitsvis/tests/test_topk_tokens.py
@@ -7,6 +7,7 @@ class TestTopk:
    def test_matches_snapshot(self, snapshot, monkeypatch):
        # Monkeypatch uuid4 to always return the same uuid
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
        monkeypatch.setattr(circuitsvis, "__version__", "1.0.0")
        res = topk_tokens(
            tokens=[["a", "b", "c", "d", "e"], ["f", "g", "h"]],
            activations=[