diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb
index 6a52fac..ecb864f 100644
--- a/python/Demonstration.ipynb
+++ b/python/Demonstration.ipynb
@@ -24,11 +24,11 @@
    "metadata": {},
    "outputs": [
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      " %reload_ext autoreload\n"
+      "/workspaces/CircuitsVis/python/.venv/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n"
      ]
     }
    ],
@@ -38,12 +38,17 @@
     "%autoreload 2\n",
     "\n",
     "# Imports\n",
-    "import numpy as np\n",
     "from circuitsvis.attention import attention_patterns, attention_pattern\n",
     "from circuitsvis.activations import text_neuron_activations\n",
     "from circuitsvis.examples import hello\n",
+    "from circuitsvis.logits import token_log_probs\n",
     "from circuitsvis.tokens import colored_tokens\n",
-    "from circuitsvis.topk_tokens import topk_tokens"
+    "from circuitsvis.topk_tokens import topk_tokens\n",
+    "\n",
+    "import numpy as np\n",
+    "import random\n",
+    "import string\n",
+    "import torch"
    ]
   },
   {
@@ -63,71 +68,70 @@
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "#### Text Neuron Activations (single sample)"
+    "#### Text Neuron Activations"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       "
\n",
+       "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 10, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -4971,11 +4975,17 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "#### Text Neuron Activations (multiple samples)" + "### Attention" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Attention Pattern (single head)" ] }, { @@ -4986,56 +4996,56 @@ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -9870,29 +9880,16 @@ } ], "source": [ - "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n", - "n_layers = 3\n", - "n_neurons_per_layer = 4\n", - "activations = []\n", - "for sample in tokens:\n", - " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n", - " activations.append(sample_activations)\n", - "text_neuron_activations(tokens=tokens, activations=activations)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Attention" + "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", + "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(8,8)))\n", + "attention_pattern(tokens=tokens, attention=attention)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Pattern (single head)" + "#### Attention Patterns" ] }, { @@ -9903,56 +9900,56 @@ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -14788,15 +14785,24 @@ ], "source": [ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", - "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(8,8)))\n", - "attention_pattern(tokens=tokens, attention=attention)" + "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(12,8,8)))\n", + "attention_patterns(tokens=tokens, attention=attention)" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Patterns" + "### Logits" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Token Log Probs" ] }, { @@ -14807,56 +14813,56 @@ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -19691,9 +19697,11 @@ } ], "source": [ - "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", - "attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(12,8,8)))\n", - "attention_patterns(tokens=tokens, attention=attention)" + "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' example']\n", + "token_indicies = torch.randint(0, 50000, (len(tokens),))\n", + "log_probs = torch.rand((len(tokens), 50000), dtype=torch.float32)\n", + "to_string = lambda x: ''.join(random.choice(string.ascii_lowercase) for i in range(10))\n", + "token_log_probs(token_indices=token_indicies, log_probs=log_probs, to_string=to_string)" ] }, { @@ -19718,56 +19726,56 @@ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -24616,62 +24624,62 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 11, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -29519,20 +29527,13 @@ "layer_labels = [2, 7, 9]\n", "topk_tokens(tokens=tokens, activations=activations, max_k=7, first_dimension_name=\"Layer\", third_dimension_name=\"Neuron\", first_dimension_labels=layer_labels)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "circuitsvis-env", + "display_name": ".venv", "language": "python", - "name": "circuitsvis-env" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -29544,11 +29545,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8 (main, Nov 15 2022, 20:55:06) [GCC 10.2.1 20210110]" + "version": "3.10.8" }, "vscode": { "interpreter": { - "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee" + "hash": "f15401ba0f322aa12af871c8a7369e4b5c15cd04566131c2cee112fb01acfcde" } } }, diff --git a/python/circuitsvis/__init__.py b/python/circuitsvis/__init__.py index 02490a0..6a492fc 100644 --- a/python/circuitsvis/__init__.py +++ b/python/circuitsvis/__init__.py @@ -4,6 +4,7 @@ import circuitsvis.attention import circuitsvis.examples import circuitsvis.tokens +import circuitsvis.logits __version__ = version("circuitsvis") diff --git a/python/circuitsvis/logits.py b/python/circuitsvis/logits.py new file mode 100644 index 0000000..974b3e8 --- /dev/null +++ b/python/circuitsvis/logits.py @@ -0,0 +1,94 @@ +"""Log Prob visualization""" +from typing import Callable, List, Union + +import numpy as np +import torch +from circuitsvis.utils.render import RenderedHTML, render + +ArrayRank1 = Union[List[float], np.ndarray, torch.Tensor] +ArrayRank2 = Union[List[List[float]], np.ndarray, torch.Tensor] +ArrayRank3 = Union[List[List[List[float]]], np.ndarray, torch.Tensor] +IntArrayRank1 = Union[List[int], np.ndarray, torch.Tensor] + + +def token_log_probs( + token_indices: torch.Tensor, + log_probs: torch.Tensor, + to_string: Callable[[int], str], + top_k: int = 10, +) -> RenderedHTML: + """ + Takes the log probs for a model on some text. Outputs the tokens coloured by + the log prob, and on hover shows you the top K tokens that the model guessed + for that position, and where the true token ranked in that. + + The intended use case is to help debug and explore a model's outputs. + + Args: + token_indices: Tensor of token indices (ie integers) of shape [N,]. + Assumed to begin with a Beginning of Sequence (BOS) token, which is not + shown in the visualization. + log_probs: Log Probabilities for predicting the next token. Tensor of + shape [N, d_vocab]. 
+ to_string: A function mapping tokens (as integers) to their string value + top_k: How many logits to show + + Returns: + Html: Log prob visualization + """ + if len(token_indices.shape) == 2: + # Remove batch dimension from token indices + token_indices = token_indices.squeeze(0) + + if len(log_probs.shape) == 3: + # Remove batch dimension from log probs + log_probs = log_probs.squeeze(0) + + assert len( + log_probs.shape) == 2, f"Log Probs shape must be 2D: {log_probs.shape}" + assert len( + token_indices.shape) == 1, f"Tokens shape must be 1D: {token_indices.shape}" + assert token_indices.size(0) == log_probs.size( + 0), f"Number of tokens and log prob vectors must be identical, {log_probs.shape}, {token_indices.shape}" + + # Drop the final dimension of log probs, since we don't know what the next + # token is for the final position! + log_probs = log_probs[:-1] + + prompt = [to_string(index.item()) for index in token_indices] + + # Sort log probs and values along the d_vocab dimension + _sorted_log_prob_values, sorted_log_prob_indices = log_probs.sort( + dim=-1, descending=True) + + # Get the top K log probs and indices for each position + # Shapes are [N, K] + top_k_log_probs, top_k_indices = log_probs.topk(top_k, dim=-1) + + # Get the token values (ie strings) for the top K tokens per position + top_k_tokens = [[to_string(token) for token in current_top_k_tokens] + for current_top_k_tokens in top_k_indices.tolist()] + + # Slightly cursed code to get the rank of the correct token at each position + # .nonzero on a 2D array returns a [X, 2] array - X is the number of + # non-zero elements, and each has the pair of indices corresponding to it. + # We only want the index on the d_vocab direction, so we take 1 + # We don't care about predicting the BOS token, so we do token_indices[1:] + correct_token_rank = (sorted_log_prob_indices == + token_indices[1:, None]).nonzero()[:, 1] + assert len(correct_token_rank) == (len(token_indices) - + 1), "Some token indices were missing from sorted_log_prob_indices" + + # Gets the log probs for the correct next token. Weird indexing is necessary + # to use gather. + correct_token_log_prob = log_probs.gather( + index=token_indices[1:, None], dim=-1).squeeze(1) + + return render( + "TokenLogProbs", + prompt=prompt, + topKLogProbs=top_k_log_probs, + topKTokens=top_k_tokens, + correctTokenRank=correct_token_rank, + correctTokenLogProb=correct_token_log_prob, + ) diff --git a/python/circuitsvis/tests/snapshots/snap_test_activations.py b/python/circuitsvis/tests/snapshots/snap_test_activations.py index cad910d..ed37212 100644 --- a/python/circuitsvis/tests/snapshots/snap_test_activations.py +++ b/python/circuitsvis/tests/snapshots/snap_test_activations.py @@ -9,7 +9,7 @@ snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''