diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb
index 6a52fac..ecb864f 100644
--- a/python/Demonstration.ipynb
+++ b/python/Demonstration.ipynb
@@ -24,11 +24,11 @@
"metadata": {},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
+ "/workspaces/CircuitsVis/python/.venv/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
@@ -38,12 +38,17 @@
"%autoreload 2\n",
"\n",
"# Imports\n",
- "import numpy as np\n",
"from circuitsvis.attention import attention_patterns, attention_pattern\n",
"from circuitsvis.activations import text_neuron_activations\n",
"from circuitsvis.examples import hello\n",
+ "from circuitsvis.logits import token_log_probs\n",
"from circuitsvis.tokens import colored_tokens\n",
- "from circuitsvis.topk_tokens import topk_tokens"
+ "from circuitsvis.topk_tokens import topk_tokens\n",
+ "\n",
+ "import numpy as np\n",
+ "import random\n",
+ "import string\n",
+ "import torch"
]
},
{
@@ -63,71 +68,70 @@
]
},
{
- "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Text Neuron Activations (single sample)"
+ "#### Text Neuron Activations"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 10,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -4971,11 +4975,17 @@
]
},
{
- "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Text Neuron Activations (multiple samples)"
+ "### Attention"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Attention Pattern (single head)"
]
},
{
@@ -4986,56 +4996,56 @@
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 3,
@@ -9870,29 +9880,16 @@
}
],
"source": [
- "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n",
- "n_layers = 3\n",
- "n_neurons_per_layer = 4\n",
- "activations = []\n",
- "for sample in tokens:\n",
- " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n",
- " activations.append(sample_activations)\n",
- "text_neuron_activations(tokens=tokens, activations=activations)\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Attention"
+ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n",
+ "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(8,8)))\n",
+ "attention_pattern(tokens=tokens, attention=attention)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Attention Pattern (single head)"
+ "#### Attention Patterns"
]
},
{
@@ -9903,56 +9900,56 @@
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 4,
@@ -14788,15 +14785,24 @@
],
"source": [
"tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n",
- "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(8,8)))\n",
- "attention_pattern(tokens=tokens, attention=attention)"
+ "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(12,8,8)))\n",
+ "attention_patterns(tokens=tokens, attention=attention)"
]
},
{
+ "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Attention Patterns"
+ "### Logits"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Token Log Probs"
]
},
{
@@ -14807,56 +14813,56 @@
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 5,
@@ -19691,9 +19697,11 @@
}
],
"source": [
- "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n",
- "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(12,8,8)))\n",
- "attention_patterns(tokens=tokens, attention=attention)"
+ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' example']\n",
+ "token_indicies = torch.randint(0, 50000, (len(tokens),))\n",
+ "log_probs = torch.rand((len(tokens), 50000), dtype=torch.float32)\n",
+ "to_string = lambda x: ''.join(random.choice(string.ascii_lowercase) for i in range(10))\n",
+ "token_log_probs(token_indices=token_indicies, log_probs=log_probs, to_string=to_string)"
]
},
{
@@ -19718,56 +19726,56 @@
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 6,
@@ -24616,62 +24624,62 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 11,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -29519,20 +29527,13 @@
"layer_labels = [2, 7, 9]\n",
"topk_tokens(tokens=tokens, activations=activations, max_k=7, first_dimension_name=\"Layer\", third_dimension_name=\"Neuron\", first_dimension_labels=layer_labels)"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
"kernelspec": {
- "display_name": "circuitsvis-env",
+ "display_name": ".venv",
"language": "python",
- "name": "circuitsvis-env"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -29544,11 +29545,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.8 (main, Nov 15 2022, 20:55:06) [GCC 10.2.1 20210110]"
+ "version": "3.10.8"
},
"vscode": {
"interpreter": {
- "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee"
+ "hash": "f15401ba0f322aa12af871c8a7369e4b5c15cd04566131c2cee112fb01acfcde"
}
}
},
diff --git a/python/circuitsvis/__init__.py b/python/circuitsvis/__init__.py
index 02490a0..6a492fc 100644
--- a/python/circuitsvis/__init__.py
+++ b/python/circuitsvis/__init__.py
@@ -4,6 +4,7 @@
import circuitsvis.attention
import circuitsvis.examples
import circuitsvis.tokens
+import circuitsvis.logits
__version__ = version("circuitsvis")
diff --git a/python/circuitsvis/logits.py b/python/circuitsvis/logits.py
new file mode 100644
index 0000000..974b3e8
--- /dev/null
+++ b/python/circuitsvis/logits.py
@@ -0,0 +1,98 @@
+"""Log Prob visualization"""
+from typing import Callable, List, Union
+
+import numpy as np
+import torch
+from circuitsvis.utils.render import RenderedHTML, render
+
+ArrayRank1 = Union[List[float], np.ndarray, torch.Tensor]
+ArrayRank2 = Union[List[List[float]], np.ndarray, torch.Tensor]
+ArrayRank3 = Union[List[List[List[float]]], np.ndarray, torch.Tensor]
+IntArrayRank1 = Union[List[int], np.ndarray, torch.Tensor]
+
+
+def token_log_probs(
+ token_indices: torch.Tensor,
+ log_probs: torch.Tensor,
+ to_string: Callable[[int], str],
+ top_k: int = 10,
+) -> RenderedHTML:
+ """
+ Takes the log probs for a model on some text. Outputs the tokens coloured by
+ their log prob, and on hover shows the top K tokens the model predicted for
+ each position, along with where the true token ranked among them.
+
+ The intended use case is to help debug and explore a model's outputs.
+
+ Args:
+ token_indices: Tensor of token indices (ie integers) of shape [N,].
+ Assumed to begin with a Beginning of Sequence (BOS) token, which is not
+ shown in the visualization.
+ log_probs: Log Probabilities for predicting the next token. Tensor of
+ shape [N, d_vocab].
+ to_string: A function mapping tokens (as integers) to their string value.
+ top_k: How many of the top predicted tokens to show on hover
+
+ Returns:
+ Html: Log prob visualization
+ """
+ if len(token_indices.shape) == 2:
+ # Remove batch dimension from token indices
+ token_indices = token_indices.squeeze(0)
+
+ if len(log_probs.shape) == 3:
+ # Remove batch dimension from log probs
+ log_probs = log_probs.squeeze(0)
+
+ assert len(log_probs.shape) == 2, f"Log Probs shape must be 2D: {log_probs.shape}"
+ assert len(token_indices.shape) == 1, f"Tokens shape must be 1D: {token_indices.shape}"
+ assert token_indices.size(0) == log_probs.size(0), \
+ f"Number of tokens and log prob vectors must be identical, {log_probs.shape}, {token_indices.shape}"
+
+ # Drop the final position of log probs, since we don't know what the next
+ # token is after the final position!
+ log_probs = log_probs[:-1]
+
+ prompt = [to_string(index.item()) for index in token_indices]
+
+ # Sort the log probs along the d_vocab dimension (descending), keeping the
+ # sorted vocab indices so we can find the rank of the correct token below
+ _sorted_log_prob_values, sorted_log_prob_indices = log_probs.sort(
+ dim=-1, descending=True)
+
+ # Get the top K log probs and indices for each position
+ # Shapes are [N, K]
+ top_k_log_probs, top_k_indices = log_probs.topk(top_k, dim=-1)
+
+ # Get the token values (ie strings) for the top K tokens per position
+ top_k_tokens = [[to_string(token) for token in current_top_k_tokens]
+ for current_top_k_tokens in top_k_indices.tolist()]
+
+ # Slightly cursed code to get the rank of the correct token at each position
+ # .nonzero on a 2D array returns a [X, 2] array - X is the number of
+ # non-zero elements, and each has the pair of indices corresponding to it.
+ # We only want the index on the d_vocab direction, so we take 1
+ # We don't care about predicting the BOS token, so we do token_indices[1:]
+ correct_token_rank = (sorted_log_prob_indices == token_indices[1:, None]).nonzero()[:, 1]
+ assert len(correct_token_rank) == len(token_indices) - 1, \
+ "Some token indices were missing from sorted_log_prob_indices"
+
+ # Gets the log probs for the correct next token. Weird indexing is necessary
+ # to use gather.
+ correct_token_log_prob = log_probs.gather(
+ index=token_indices[1:, None], dim=-1).squeeze(1)
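+ # e.g. (hypothetical values) with token_indices = [BOS, 2, 9], this picks
+ # log_probs[0, 2] and log_probs[1, 9]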
+
+ return render(
+ "TokenLogProbs",
+ prompt=prompt,
+ topKLogProbs=top_k_log_probs,
+ topKTokens=top_k_tokens,
+ correctTokenRank=correct_token_rank,
+ correctTokenLogProb=correct_token_log_prob,
+ )
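A minimal usage sketch for the new `token_log_probs` helper, assuming a TransformerLens-style model. The `model` object and its `to_tokens`/`to_string` helpers are assumptions for illustration, not part of this PR; any [N] tensor of token indices, [N, d_vocab] tensor of log probs, and int-to-string function will do.

```python
import torch

from circuitsvis.logits import token_log_probs

# `model` is a hypothetical TransformerLens-style language model (an
# assumption, not part of this PR). tokens starts with BOS, shape [1, N].
tokens = model.to_tokens("Hello world")
logits = model(tokens)  # shape [1, N, d_vocab]
log_probs = torch.log_softmax(logits, dim=-1)  # convert logits to log probs

# Batch dimensions of size 1 are squeezed away inside token_log_probs
html = token_log_probs(
    token_indices=tokens,
    log_probs=log_probs,
    to_string=model.to_string,  # must map an int token index to its string
    top_k=10,
)
html  # renders inline in a notebook; str(html) gives the raw HTML
```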
diff --git a/python/circuitsvis/tests/snapshots/snap_test_activations.py b/python/circuitsvis/tests/snapshots/snap_test_activations.py
index cad910d..ed37212 100644
--- a/python/circuitsvis/tests/snapshots/snap_test_activations.py
+++ b/python/circuitsvis/tests/snapshots/snap_test_activations.py
@@ -9,7 +9,7 @@
snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''