diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 6e92d2507..c9999b042 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -133,6 +133,7 @@ jobs: - "Main_Demo" # - "No_Position_Experiment" - "Othello_GPT" + # - "T5" steps: - uses: actions/checkout@v3 - name: Install Poetry diff --git a/demos/T5.ipynb b/demos/T5.ipynb new file mode 100644 index 000000000..363073ad1 --- /dev/null +++ b/demos/T5.ipynb @@ -0,0 +1,724 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1686188/569054096.py:18: DeprecationWarning:\n", + "\n", + "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "\n", + "/tmp/ipykernel_1686188/569054096.py:19: DeprecationWarning:\n", + "\n", + "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "DEBUG_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install 
transformer_lens\n", + " %pip install torchtyping\n", + " # Install my janky personal plotting utils\n", + " %pip install git+https://github.com/TransformerLensOrg/neel-plotly.git\n", + " # Install another version of node that makes PySvelte work way faster\n", + " %pip install circuitsvis\n", + " # Needed for PySvelte to work, v3 came out and broke things...\n", + " %pip install typeguard==2.13.3" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", + "import plotly.io as pio\n", + "\n", + "if IN_COLAB or not DEBUG_MODE:\n", + " # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dontsov/.local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning:\n", + "\n", + "IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "\n", + "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning:\n", + "\n", + "urllib3 (2.2.1) or chardet (3.0.4) doesn't match a supported version!\n", + "\n", + "WARNING:root:Support for T5 in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", + "If using T5 for interpretability research, keep in mind that T5 has some significant architectural differences to GPT. 
The major one is that T5 is an Encoder-Decoder modelAlso, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Moving model to device: cuda\n", + "Loaded pretrained model t5-small into HookedTransformer\n" + ] + } + ], + "source": [ + "# Imports\n", + "import torch\n", + "\n", + "from transformers import AutoTokenizer\n", + "from transformer_lens import HookedEncoderDecoder\n", + "\n", + "model_name = \"t5-small\"\n", + "model = HookedEncoderDecoder.from_pretrained(model_name)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.set_grad_enabled(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## basic sanity check - model generates smth" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "generated token: \"Bonjour\", token id: 21845\n", + "generated token: \",\", token id: 6\n", + "generated token: \"comment\", token id: 1670\n", + "generated token: \"\", token id: 3\n", + "generated token: \"êtes\", token id: 6738\n", + "generated token: \"-\", token id: 18\n", + "generated token: \"vous\", token id: 3249\n", + "generated token: \"?\", token id: 58\n", + "generated token: \"\", token id: 1\n", + "translate English to French: Hello, how are you? \n", + " Bonjour, comment êtes-vous?\n" + ] + } + ], + "source": [ + "prompt = \"translate English to French: Hello, how are you? 
\"\n", + "inputs = tokenizer(prompt, return_tensors=\"pt\")\n", + "input_ids = inputs[\"input_ids\"]\n", + "attention_mask = inputs[\"attention_mask\"]\n", + "decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)\n", + "\n", + "\n", + "while True:\n", + " logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n", + " # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n", + "\n", + " token_idx = torch.argmax(logits[0, -1, :]).item()\n", + " print(\"generated token: \\\"\", tokenizer.decode(token_idx), \"\\\", token id: \", token_idx, sep=\"\")\n", + "\n", + " # append token to decoder_input_ids\n", + " decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]]).to(input_ids.device)], dim=-1)\n", + "\n", + " # break if End-Of-Sequence token generated\n", + " if token_idx == tokenizer.eos_token_id:\n", + " break\n", + "\n", + "print(prompt, \"\\n\", tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### visualise encoder patterns" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import circuitsvis as cv\n", + "# Testing that the library works\n", + "cv.examples.hello(\"Neel\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"translate English to French: Hello, how are you? \"\n", + "inputs = tokenizer(prompt, return_tensors=\"pt\")\n", + "input_ids = inputs[\"input_ids\"]\n", + "attention_mask = inputs[\"attention_mask\"]\n", + "\n", + "\n", + "logits,cache = model.run_with_cache(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids, remove_batch_dim=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hook_embed\n", + "encoder.0.hook_resid_pre\n", + "encoder.0.ln1.hook_scale\n", + "encoder.0.ln1.hook_normalized\n", + "encoder.0.attn.hook_q\n", + "encoder.0.attn.hook_k\n", + "encoder.0.attn.hook_v\n", + "encoder.0.attn.hook_attn_scores\n", + "encoder.0.attn.hook_pattern\n", + "encoder.0.attn.hook_z\n", + "encoder.0.hook_attn_out\n", + "encoder.0.hook_resid_mid\n", + "encoder.0.ln2.hook_scale\n", + "encoder.0.ln2.hook_normalized\n", + "encoder.0.mlp.hook_pre\n", + "encoder.0.mlp.hook_post\n", + "encoder.0.hook_mlp_out\n", + "encoder.0.hook_resid_post\n", + "encoder.1.hook_resid_pre\n", + "encoder.1.ln1.hook_scale\n", + "encoder.1.ln1.hook_normalized\n", + "encoder.1.attn.hook_q\n", + "encoder.1.attn.hook_k\n", + "encoder.1.attn.hook_v\n", + "encoder.1.attn.hook_attn_scores\n", + "encoder.1.attn.hook_pattern\n", + "encoder.1.attn.hook_z\n", + "encoder.1.hook_attn_out\n", + "encoder.1.hook_resid_mid\n", + "encoder.1.ln2.hook_scale\n", + "encoder.1.ln2.hook_normalized\n", + "encoder.1.mlp.hook_pre\n", + "encoder.1.mlp.hook_post\n", + "encoder.1.hook_mlp_out\n", + 
"encoder.1.hook_resid_post\n", + "encoder.2.hook_resid_pre\n", + "encoder.2.ln1.hook_scale\n", + "encoder.2.ln1.hook_normalized\n", + "encoder.2.attn.hook_q\n", + "encoder.2.attn.hook_k\n", + "encoder.2.attn.hook_v\n", + "encoder.2.attn.hook_attn_scores\n", + "encoder.2.attn.hook_pattern\n", + "encoder.2.attn.hook_z\n", + "encoder.2.hook_attn_out\n", + "encoder.2.hook_resid_mid\n", + "encoder.2.ln2.hook_scale\n", + "encoder.2.ln2.hook_normalized\n", + "encoder.2.mlp.hook_pre\n", + "encoder.2.mlp.hook_post\n", + "encoder.2.hook_mlp_out\n", + "encoder.2.hook_resid_post\n", + "encoder.3.hook_resid_pre\n", + "encoder.3.ln1.hook_scale\n", + "encoder.3.ln1.hook_normalized\n", + "encoder.3.attn.hook_q\n", + "encoder.3.attn.hook_k\n", + "encoder.3.attn.hook_v\n", + "encoder.3.attn.hook_attn_scores\n", + "encoder.3.attn.hook_pattern\n", + "encoder.3.attn.hook_z\n", + "encoder.3.hook_attn_out\n", + "encoder.3.hook_resid_mid\n", + "encoder.3.ln2.hook_scale\n", + "encoder.3.ln2.hook_normalized\n", + "encoder.3.mlp.hook_pre\n", + "encoder.3.mlp.hook_post\n", + "encoder.3.hook_mlp_out\n", + "encoder.3.hook_resid_post\n", + "encoder.4.hook_resid_pre\n", + "encoder.4.ln1.hook_scale\n", + "encoder.4.ln1.hook_normalized\n", + "encoder.4.attn.hook_q\n", + "encoder.4.attn.hook_k\n", + "encoder.4.attn.hook_v\n", + "encoder.4.attn.hook_attn_scores\n", + "encoder.4.attn.hook_pattern\n", + "encoder.4.attn.hook_z\n", + "encoder.4.hook_attn_out\n", + "encoder.4.hook_resid_mid\n", + "encoder.4.ln2.hook_scale\n", + "encoder.4.ln2.hook_normalized\n", + "encoder.4.mlp.hook_pre\n", + "encoder.4.mlp.hook_post\n", + "encoder.4.hook_mlp_out\n", + "encoder.4.hook_resid_post\n", + "encoder.5.hook_resid_pre\n", + "encoder.5.ln1.hook_scale\n", + "encoder.5.ln1.hook_normalized\n", + "encoder.5.attn.hook_q\n", + "encoder.5.attn.hook_k\n", + "encoder.5.attn.hook_v\n", + "encoder.5.attn.hook_attn_scores\n", + "encoder.5.attn.hook_pattern\n", + "encoder.5.attn.hook_z\n", + "encoder.5.hook_attn_out\n", + 
"encoder.5.hook_resid_mid\n", + "encoder.5.ln2.hook_scale\n", + "encoder.5.ln2.hook_normalized\n", + "encoder.5.mlp.hook_pre\n", + "encoder.5.mlp.hook_post\n", + "encoder.5.hook_mlp_out\n", + "encoder.5.hook_resid_post\n", + "encoder_final_ln.hook_scale\n", + "encoder_final_ln.hook_normalized\n", + "decoder.0.hook_resid_pre\n", + "decoder.0.ln1.hook_scale\n", + "decoder.0.ln1.hook_normalized\n", + "decoder.0.attn.hook_q\n", + "decoder.0.attn.hook_k\n", + "decoder.0.attn.hook_v\n", + "decoder.0.attn.hook_attn_scores\n", + "decoder.0.attn.hook_pattern\n", + "decoder.0.attn.hook_z\n", + "decoder.0.hook_attn_out\n", + "decoder.0.hook_resid_mid\n", + "decoder.0.ln2.hook_scale\n", + "decoder.0.ln2.hook_normalized\n", + "decoder.0.cross_attn.hook_q\n", + "decoder.0.cross_attn.hook_k\n", + "decoder.0.cross_attn.hook_v\n", + "decoder.0.cross_attn.hook_attn_scores\n", + "decoder.0.cross_attn.hook_pattern\n", + "decoder.0.cross_attn.hook_z\n", + "decoder.0.hook_cross_attn_out\n", + "decoder.0.hook_resid_mid_cross\n", + "decoder.0.ln3.hook_scale\n", + "decoder.0.ln3.hook_normalized\n", + "decoder.0.mlp.hook_pre\n", + "decoder.0.mlp.hook_post\n", + "decoder.0.hook_mlp_out\n", + "decoder.0.hook_resid_post\n", + "decoder.1.hook_resid_pre\n", + "decoder.1.ln1.hook_scale\n", + "decoder.1.ln1.hook_normalized\n", + "decoder.1.attn.hook_q\n", + "decoder.1.attn.hook_k\n", + "decoder.1.attn.hook_v\n", + "decoder.1.attn.hook_attn_scores\n", + "decoder.1.attn.hook_pattern\n", + "decoder.1.attn.hook_z\n", + "decoder.1.hook_attn_out\n", + "decoder.1.hook_resid_mid\n", + "decoder.1.ln2.hook_scale\n", + "decoder.1.ln2.hook_normalized\n", + "decoder.1.cross_attn.hook_q\n", + "decoder.1.cross_attn.hook_k\n", + "decoder.1.cross_attn.hook_v\n", + "decoder.1.cross_attn.hook_attn_scores\n", + "decoder.1.cross_attn.hook_pattern\n", + "decoder.1.cross_attn.hook_z\n", + "decoder.1.hook_cross_attn_out\n", + "decoder.1.hook_resid_mid_cross\n", + "decoder.1.ln3.hook_scale\n", + 
"decoder.1.ln3.hook_normalized\n", + "decoder.1.mlp.hook_pre\n", + "decoder.1.mlp.hook_post\n", + "decoder.1.hook_mlp_out\n", + "decoder.1.hook_resid_post\n", + "decoder.2.hook_resid_pre\n", + "decoder.2.ln1.hook_scale\n", + "decoder.2.ln1.hook_normalized\n", + "decoder.2.attn.hook_q\n", + "decoder.2.attn.hook_k\n", + "decoder.2.attn.hook_v\n", + "decoder.2.attn.hook_attn_scores\n", + "decoder.2.attn.hook_pattern\n", + "decoder.2.attn.hook_z\n", + "decoder.2.hook_attn_out\n", + "decoder.2.hook_resid_mid\n", + "decoder.2.ln2.hook_scale\n", + "decoder.2.ln2.hook_normalized\n", + "decoder.2.cross_attn.hook_q\n", + "decoder.2.cross_attn.hook_k\n", + "decoder.2.cross_attn.hook_v\n", + "decoder.2.cross_attn.hook_attn_scores\n", + "decoder.2.cross_attn.hook_pattern\n", + "decoder.2.cross_attn.hook_z\n", + "decoder.2.hook_cross_attn_out\n", + "decoder.2.hook_resid_mid_cross\n", + "decoder.2.ln3.hook_scale\n", + "decoder.2.ln3.hook_normalized\n", + "decoder.2.mlp.hook_pre\n", + "decoder.2.mlp.hook_post\n", + "decoder.2.hook_mlp_out\n", + "decoder.2.hook_resid_post\n", + "decoder.3.hook_resid_pre\n", + "decoder.3.ln1.hook_scale\n", + "decoder.3.ln1.hook_normalized\n", + "decoder.3.attn.hook_q\n", + "decoder.3.attn.hook_k\n", + "decoder.3.attn.hook_v\n", + "decoder.3.attn.hook_attn_scores\n", + "decoder.3.attn.hook_pattern\n", + "decoder.3.attn.hook_z\n", + "decoder.3.hook_attn_out\n", + "decoder.3.hook_resid_mid\n", + "decoder.3.ln2.hook_scale\n", + "decoder.3.ln2.hook_normalized\n", + "decoder.3.cross_attn.hook_q\n", + "decoder.3.cross_attn.hook_k\n", + "decoder.3.cross_attn.hook_v\n", + "decoder.3.cross_attn.hook_attn_scores\n", + "decoder.3.cross_attn.hook_pattern\n", + "decoder.3.cross_attn.hook_z\n", + "decoder.3.hook_cross_attn_out\n", + "decoder.3.hook_resid_mid_cross\n", + "decoder.3.ln3.hook_scale\n", + "decoder.3.ln3.hook_normalized\n", + "decoder.3.mlp.hook_pre\n", + "decoder.3.mlp.hook_post\n", + "decoder.3.hook_mlp_out\n", + "decoder.3.hook_resid_post\n", + 
"decoder.4.hook_resid_pre\n", + "decoder.4.ln1.hook_scale\n", + "decoder.4.ln1.hook_normalized\n", + "decoder.4.attn.hook_q\n", + "decoder.4.attn.hook_k\n", + "decoder.4.attn.hook_v\n", + "decoder.4.attn.hook_attn_scores\n", + "decoder.4.attn.hook_pattern\n", + "decoder.4.attn.hook_z\n", + "decoder.4.hook_attn_out\n", + "decoder.4.hook_resid_mid\n", + "decoder.4.ln2.hook_scale\n", + "decoder.4.ln2.hook_normalized\n", + "decoder.4.cross_attn.hook_q\n", + "decoder.4.cross_attn.hook_k\n", + "decoder.4.cross_attn.hook_v\n", + "decoder.4.cross_attn.hook_attn_scores\n", + "decoder.4.cross_attn.hook_pattern\n", + "decoder.4.cross_attn.hook_z\n", + "decoder.4.hook_cross_attn_out\n", + "decoder.4.hook_resid_mid_cross\n", + "decoder.4.ln3.hook_scale\n", + "decoder.4.ln3.hook_normalized\n", + "decoder.4.mlp.hook_pre\n", + "decoder.4.mlp.hook_post\n", + "decoder.4.hook_mlp_out\n", + "decoder.4.hook_resid_post\n", + "decoder.5.hook_resid_pre\n", + "decoder.5.ln1.hook_scale\n", + "decoder.5.ln1.hook_normalized\n", + "decoder.5.attn.hook_q\n", + "decoder.5.attn.hook_k\n", + "decoder.5.attn.hook_v\n", + "decoder.5.attn.hook_attn_scores\n", + "decoder.5.attn.hook_pattern\n", + "decoder.5.attn.hook_z\n", + "decoder.5.hook_attn_out\n", + "decoder.5.hook_resid_mid\n", + "decoder.5.ln2.hook_scale\n", + "decoder.5.ln2.hook_normalized\n", + "decoder.5.cross_attn.hook_q\n", + "decoder.5.cross_attn.hook_k\n", + "decoder.5.cross_attn.hook_v\n", + "decoder.5.cross_attn.hook_attn_scores\n", + "decoder.5.cross_attn.hook_pattern\n", + "decoder.5.cross_attn.hook_z\n", + "decoder.5.hook_cross_attn_out\n", + "decoder.5.hook_resid_mid_cross\n", + "decoder.5.ln3.hook_scale\n", + "decoder.5.ln3.hook_normalized\n", + "decoder.5.mlp.hook_pre\n", + "decoder.5.mlp.hook_post\n", + "decoder.5.hook_mlp_out\n", + "decoder.5.hook_resid_post\n", + "decoder_final_ln.hook_scale\n", + "decoder_final_ln.hook_normalized\n" + ] + } + ], + "source": [ + "# the usual way of indexing cache via 
cache[\"pattern\",0,\"attn\"] does not work\n", + "# because it uses cache[\"blocks.0....\"] indexing\n", + "# T5 is implemented as separate stacks of blocks for the encoder and decoder\n", + "# so indexing is cache[\"encoder.0..\"], cache[\"decoder.0..\"] \n", + "# let's see what is in the cache and choose the right key for the encoder attention pattern on layer 0\n", + "print(\"\\n\".join(cache.keys()))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "encoder_attn_pattern = cache[\"encoder.0.attn.hook_pattern\"]\n", + "input_str_tokens = [w.lstrip(\"▁\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "cv.attention.attention_patterns(tokens=input_str_tokens, attention=encoder_attn_pattern)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### visualise decoder pattern" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['', '▁Bonjour', ',', '▁comment', '▁', 'êtes', '-', 'vous', '?', '']" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "decoder_str_tokens = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])\n", + "decoder_str_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "decoder_attn_pattern = cache[\"decoder.0.attn.hook_pattern\"]\n", + "cv.attention.attention_patterns(tokens=decoder_str_tokens, attention=decoder_attn_pattern)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## topk tokens visualisation" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# list of samples of shape (n_layers, n_tokens, n_neurons) for each sample\n", + "# I take the activations after the MLP layer\n", + "# you can also pass the activations after the attention layer (hook_attn_out),\n", + "# after the cross attention layer (hook_cross_attn_out) or after the mlp layer (hook_mlp_out)\n", + "activations = [\n", + " torch.stack([cache[f\"decoder.{layer}.hook_mlp_out\"] for layer in range(model.cfg.n_layers)]).cpu().numpy()\n", + " ]\n", + "\n", + "# list of samples of shape (n_tokens)\n", + "tokens = [decoder_str_tokens]\n", + "\n", + "# if we use an arbitrary selection of layers, then change the layer labels accordingly; for now just pass the layer index\n", + "layer_labels = [i for i in range(model.cfg.n_layers)]\n", + "\n", + "\n", + "cv.topk_tokens.topk_tokens(\n", + " tokens=tokens,\n", + " activations=activations, \n", + " max_k=10, \n", + " first_dimension_name=\"Layer\", \n", + " first_dimension_labels=layer_labels,\n", + " third_dimension_name=\"Neuron\",\n", + ")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/test.ipynb b/demos/test.ipynb deleted file mode 100644 index a3340b402..000000000 --- a/demos/test.ipynb +++ /dev/null @@ -1,78 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "OSError", - "evalue": "You are trying to access a gated repo.\nMake sure to have access to it at 
https://huggingface.co/meta-llama/Meta-Llama-3-8B.\n401 Client Error. (Request ID: Root=1-662aa17c-23adf8063ca56f12201bef1d;fee43197-2550-4397-a2e4-114eedae301c)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json.\nAccess to model meta-llama/Meta-Llama-3-8B is restricted. You must be authenticated to access it.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/utils/_errors.py:304\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 304\u001b[0m \u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/requests/models.py:1021\u001b[0m, in \u001b[0;36mResponse.raise_for_status\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m http_error_msg:\n\u001b[0;32m-> 1021\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(http_error_msg, response\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m)\n", - "\u001b[0;31mHTTPError\u001b[0m: 401 Client Error: Unauthorized for url: https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mGatedRepoError\u001b[0m Traceback (most recent call last)", - "File 
\u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/utils/hub.py:398\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 397\u001b[0m \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 398\u001b[0m resolved_file \u001b[38;5;241m=\u001b[39m \u001b[43mhf_hub_download\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 399\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath_or_repo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 400\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 401\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msubfolder\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 402\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 403\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 404\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 405\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 406\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 407\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 408\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 409\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 410\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 411\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 412\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m GatedRepoError \u001b[38;5;28;01mas\u001b[39;00m e:\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:119\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 117\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py:1403\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, headers, legacy_cache_layout, endpoint)\u001b[0m\n\u001b[1;32m 1401\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(head_call_error, RepositoryNotFoundError) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(head_call_error, GatedRepoError):\n\u001b[1;32m 1402\u001b[0m \u001b[38;5;66;03m# Repo not found or gated => let's raise the actual error\u001b[39;00m\n\u001b[0;32m-> 1403\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m head_call_error\n\u001b[1;32m 1404\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1405\u001b[0m \u001b[38;5;66;03m# Otherwise: most likely a connection issue or Hub downtime => let's warn the user\u001b[39;00m\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py:1261\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, headers, legacy_cache_layout, endpoint)\u001b[0m\n\u001b[1;32m 1260\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1261\u001b[0m metadata \u001b[38;5;241m=\u001b[39m 
\u001b[43mget_hf_file_metadata\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1262\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1263\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1264\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1265\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43metag_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1266\u001b[0m \u001b[43m \u001b[49m\u001b[43mlibrary_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlibrary_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1267\u001b[0m \u001b[43m \u001b[49m\u001b[43mlibrary_version\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlibrary_version\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1268\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1270\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m EntryNotFoundError \u001b[38;5;28;01mas\u001b[39;00m http_error:\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# Cache the non-existence of the file and raise\u001b[39;00m\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:119\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 117\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, 
kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py:1674\u001b[0m, in \u001b[0;36mget_hf_file_metadata\u001b[0;34m(url, token, proxies, timeout, library_name, library_version, user_agent, headers)\u001b[0m\n\u001b[1;32m 1673\u001b[0m \u001b[38;5;66;03m# Retrieve metadata\u001b[39;00m\n\u001b[0;32m-> 1674\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43m_request_wrapper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1675\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mHEAD\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1676\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1677\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1678\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_redirects\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1679\u001b[0m \u001b[43m \u001b[49m\u001b[43mfollow_relative_redirects\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1680\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1681\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1682\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1683\u001b[0m hf_raise_for_status(r)\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py:369\u001b[0m, in \u001b[0;36m_request_wrapper\u001b[0;34m(method, url, follow_relative_redirects, **params)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m follow_relative_redirects:\n\u001b[0;32m--> 369\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43m_request_wrapper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 370\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 371\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 372\u001b[0m \u001b[43m \u001b[49m\u001b[43mfollow_relative_redirects\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 373\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 374\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;66;03m# If redirection, we redirect only relative paths.\u001b[39;00m\n\u001b[1;32m 377\u001b[0m \u001b[38;5;66;03m# This is useful in case of a renamed repository.\u001b[39;00m\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/file_download.py:393\u001b[0m, in \u001b[0;36m_request_wrapper\u001b[0;34m(method, url, follow_relative_redirects, **params)\u001b[0m\n\u001b[1;32m 392\u001b[0m response \u001b[38;5;241m=\u001b[39m 
get_session()\u001b[38;5;241m.\u001b[39mrequest(method\u001b[38;5;241m=\u001b[39mmethod, url\u001b[38;5;241m=\u001b[39murl, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams)\n\u001b[0;32m--> 393\u001b[0m \u001b[43mhf_raise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 394\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/huggingface_hub/utils/_errors.py:321\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m 318\u001b[0m message \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;241m.\u001b[39mstatus_code\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m Client Error.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot access gated repo for url \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;241m.\u001b[39murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 320\u001b[0m )\n\u001b[0;32m--> 321\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m GatedRepoError(message, response) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 323\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m error_message \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAccess to this resource is disabled.\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "\u001b[0;31mGatedRepoError\u001b[0m: 401 Client Error. 
(Request ID: Root=1-662aa17c-23adf8063ca56f12201bef1d;fee43197-2550-4397-a2e4-114eedae301c)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json.\nAccess to model meta-llama/Meta-Llama-3-8B is restricted. You must be authenticated to access it.", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformer_lens\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m HookedTransformer\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m hf_model \u001b[38;5;241m=\u001b[39m \u001b[43mAutoModelForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmeta-llama/Meta-Llama-3-8B\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmps\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat16\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m HookedTransformer\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmeta-llama/Meta-Llama-3-8B\u001b[39m\u001b[38;5;124m\"\u001b[39m, hf_model\u001b[38;5;241m=\u001b[39mhf_model)\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:523\u001b[0m, in 
\u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantization_config\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 521\u001b[0m _ \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantization_config\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 523\u001b[0m config, kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mAutoConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 524\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 525\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[43m \u001b[49m\u001b[43mcode_revision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcode_revision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 528\u001b[0m \u001b[43m \u001b[49m\u001b[43m_commit_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 529\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 530\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 531\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;66;03m# if torch_dtype=auto was passed here, ensure to pass it on\u001b[39;00m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs_orig\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch_dtype\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py:1138\u001b[0m, in \u001b[0;36mAutoConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 1135\u001b[0m trust_remote_code \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrust_remote_code\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 1136\u001b[0m code_revision \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcode_revision\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m-> 1138\u001b[0m config_dict, unused_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mPretrainedConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_config_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1139\u001b[0m has_remote_code \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto_map\u001b[39m\u001b[38;5;124m\"\u001b[39m 
\u001b[38;5;129;01min\u001b[39;00m config_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAutoConfig\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto_map\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 1140\u001b[0m has_local_code \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict \u001b[38;5;129;01mand\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;129;01min\u001b[39;00m CONFIG_MAPPING\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/configuration_utils.py:631\u001b[0m, in \u001b[0;36mPretrainedConfig.get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 629\u001b[0m original_kwargs \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(kwargs)\n\u001b[1;32m 630\u001b[0m \u001b[38;5;66;03m# Get config dict associated with the base config file\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m config_dict, kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_config_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict:\n\u001b[1;32m 633\u001b[0m original_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m 
config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/configuration_utils.py:686\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 682\u001b[0m configuration_file \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_configuration_file\u001b[39m\u001b[38;5;124m\"\u001b[39m, CONFIG_NAME)\n\u001b[1;32m 684\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 685\u001b[0m \u001b[38;5;66;03m# Load from local folder or from cache or download from model Hub and cache\u001b[39;00m\n\u001b[0;32m--> 686\u001b[0m resolved_config_file \u001b[38;5;241m=\u001b[39m \u001b[43mcached_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 687\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 688\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfiguration_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 689\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 690\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 691\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 692\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 693\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 694\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 695\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43m_commit_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 700\u001b[0m commit_hash \u001b[38;5;241m=\u001b[39m extract_commit_hash(resolved_config_file, commit_hash)\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m \u001b[38;5;66;03m# Raise any environment error raise by `cached_file`. 
It will have a helpful error message adapted to\u001b[39;00m\n\u001b[1;32m 703\u001b[0m \u001b[38;5;66;03m# the original exception.\u001b[39;00m\n", - "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/utils/hub.py:416\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m resolved_file \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _raise_exceptions_for_gated_repo:\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m resolved_file\n\u001b[0;32m--> 416\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 417\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are trying to access a gated repo.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mMake sure to have access to it at \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 418\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://huggingface.co/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mstr\u001b[39m(e)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 419\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 420\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m RepositoryNotFoundError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 
421\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 422\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not a local folder and is not a valid model identifier \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 423\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlisted on \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://huggingface.co/models\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mIf this is a private repository, make sure to pass a token \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 424\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhaving permission to this repo either by logging in with `huggingface-cli login` or by passing \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 425\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`token=`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 426\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n", - "\u001b[0;31mOSError\u001b[0m: You are trying to access a gated repo.\nMake sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B.\n401 Client Error. (Request ID: Root=1-662aa17c-23adf8063ca56f12201bef1d;fee43197-2550-4397-a2e4-114eedae301c)\n\nCannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json.\nAccess to model meta-llama/Meta-Llama-3-8B is restricted. You must be authenticated to access it." 
- ] - } - ], - "source": [ - "from transformers import AutoModelForCausalLM\n", - "import os\n", - "from transformer_lens import HookedTransformer\n", - "import torch\n", - "\n", - "os.environ[\"HF_TOKEN\"] = \"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\"\n", - "\n", - "hf_model = AutoModelForCausalLM.from_pretrained(\n", - " \"meta-llama/Meta-Llama-3-8B\", device_map=\"mps\", torch_dtype=torch.float16\n", - ")\n", - "HookedTransformer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\", hf_model=hf_model)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/acceptance/test_hooked_encoder_decoder.py b/tests/acceptance/test_hooked_encoder_decoder.py new file mode 100644 index 000000000..ed1dcfe16 --- /dev/null +++ b/tests/acceptance/test_hooked_encoder_decoder.py @@ -0,0 +1,337 @@ +import pytest +import torch +from jaxtyping import Float +from torch.testing import assert_close +from transformers import AutoTokenizer, T5ForConditionalGeneration + +from transformer_lens import HookedEncoderDecoder + +MODEL_NAME = "t5-small" + + +@pytest.fixture(scope="module") +def our_model(): + return HookedEncoderDecoder.from_pretrained(MODEL_NAME, device="cpu") + + +@pytest.fixture(scope="module") +def huggingface_model(): + return T5ForConditionalGeneration.from_pretrained(MODEL_NAME).eval() + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture +def hello_world_tokens(tokenizer): + return tokenizer("Hello, world!", return_tensors="pt")["input_ids"] + + +@pytest.fixture +def decoder_input_ids(tokenizer): + return 
torch.LongTensor([[tokenizer.pad_token_id]]) + + +def test_full_model(our_model, huggingface_model, tokenizer, decoder_input_ids): + sequences = ["Hello, world!", "this is another sequence of tokens"] + + tokenized = tokenizer(sequences, return_tensors="pt", padding=True) + decoder_ids = torch.stack([decoder_input_ids[0]] * len(sequences), dim=0) + input_ids = tokenized["input_ids"] + + attention_mask = tokenized["attention_mask"] + + huggingface_model_out = huggingface_model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_ids, + ).logits + our_model_out = our_model( + input_ids, + decoder_input=decoder_ids, + one_zero_attention_mask=attention_mask, + ) + assert_close(huggingface_model_out, our_model_out, rtol=1.3e-6, atol=4e-5) + + +def test_encoder(our_model, huggingface_model, hello_world_tokens): + our_embeds = our_model.embed(hello_world_tokens) + pos_bias = our_model.encoder[0].attn.compute_relative_attention_bias( + hello_world_tokens.shape[1], hello_world_tokens.shape[1] + ) + + for our_layer in our_model.encoder: + our_embeds = our_layer(resid_pre=our_embeds, position_bias=pos_bias) + + our_encoder_out = our_model.encoder_final_ln(our_embeds) + + huggingface_encoder_out = huggingface_model.encoder(hello_world_tokens).last_hidden_state + + assert_close(our_encoder_out, huggingface_encoder_out, rtol=1.3e-6, atol=4e-5) + + +def test_decoder(our_model, huggingface_model, hello_world_tokens, decoder_input_ids): + encoder_hidden = huggingface_model.encoder(hello_world_tokens)[0] + + embeds = our_model.embed(decoder_input_ids) + pos_bias = our_model.decoder[0].attn.compute_relative_attention_bias( + decoder_input_ids.shape[1], decoder_input_ids.shape[1] + ) + for layer in our_model.decoder: + embeds = layer(embeds, encoder_hidden_states=encoder_hidden, position_bias=pos_bias) + + our_decoder_out = our_model.decoder_final_ln(embeds) + hf_decoder_out = huggingface_model.decoder( + decoder_input_ids, 
encoder_hidden_states=encoder_hidden + )[0] + + assert_close(our_decoder_out, hf_decoder_out, rtol=1.3e-6, atol=4e-5) + + +def test_embed_one_sentence(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.encoder.embed_tokens + our_embed = our_model.embed + + huggingface_embed_out = huggingface_embed(hello_world_tokens)[0] + our_embed_out = our_embed(hello_world_tokens).squeeze(0) + assert_close(huggingface_embed_out, our_embed_out) + + +def test_relative_attention_bias(our_model, huggingface_model, hello_world_tokens): + # it is used only in self attention of first layer of encoder + huggingface_embed = huggingface_model.encoder.embed_tokens + huggingface_attn = huggingface_model.encoder.block[0].layer[0].SelfAttention + our_attn = our_model.encoder[0].attn + + assert huggingface_attn.has_relative_attention_bias + assert our_attn.has_relative_attention_bias + assert ( + our_attn.relative_attention_num_buckets == huggingface_attn.relative_attention_num_buckets + ) + assert ( + our_attn.relative_attention_max_distance == huggingface_attn.relative_attention_max_distance + ) + assert_close(our_attn.rel_pos_bias.weight, huggingface_attn.relative_attention_bias.weight) + + input_len = hello_world_tokens.shape[1] + our_bias = our_attn.compute_relative_attention_bias(input_len, input_len) + hf_bias = huggingface_attn.compute_bias(input_len, input_len) + assert_close(our_bias, hf_bias, rtol=1e-5, atol=1e-5) + + embed_out = huggingface_embed(hello_world_tokens) + + huggingface_attn_out = huggingface_attn(embed_out)[0] + our_attn_out = our_attn(embed_out, embed_out, embed_out, position_bias=our_bias) + + assert_close(our_attn_out, huggingface_attn_out, rtol=7e-4, atol=1e-5) + + +def test_relative_attention_layer(our_model, huggingface_model, hello_world_tokens): + # it is used only in self attention of first layer of encoder + hf_block = huggingface_model.encoder.block[0].layer[0] + our_block = our_model.encoder[0] + resid = 
huggingface_model.encoder.embed_tokens(hello_world_tokens) + + input_len = hello_world_tokens.shape[1] + our_bias = our_block.attn.compute_relative_attention_bias(input_len, input_len) + resid_norm = our_block.ln1(resid) + our_out = resid + our_block.attn(resid_norm, resid_norm, resid_norm, position_bias=our_bias) + + hf_out = hf_block(resid)[0] + assert_close(our_out, hf_out, rtol=1.3e-6, atol=4e-5) + + +def test_attention(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.encoder.embed_tokens + huggingface_attn = huggingface_model.encoder.block[1].layer[0].SelfAttention + + embed_out = huggingface_embed(hello_world_tokens) + our_attn = our_model.encoder[1].attn + + our_attn_out = our_attn(embed_out, embed_out, embed_out) + huggingface_attn_out = huggingface_attn(embed_out)[0] + + assert_close(our_attn_out, huggingface_attn_out, rtol=5e-4, atol=1e-5) + + +def test_decoder_attention(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.decoder.embed_tokens + huggingface_attn = huggingface_model.decoder.block[1].layer[0].SelfAttention + + embed_out = huggingface_embed(hello_world_tokens) + our_attn = our_model.decoder[1].attn + + our_attn_out = our_attn(embed_out, embed_out, embed_out) + huggingface_attn_out = huggingface_attn(embed_out)[0] + assert_close(our_attn_out, huggingface_attn_out, rtol=3e-4, atol=1e-5) + + +def test_attention_layer(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.encoder.embed_tokens + huggingface_attn = huggingface_model.encoder.block[1].layer[0] + + embed_out = huggingface_embed(hello_world_tokens) + our_attn = our_model.encoder[1].attn + norm_embed = our_model.encoder[1].ln1(embed_out) + our_attn_out = our_attn(norm_embed, norm_embed, norm_embed) + embed_out + + huggingface_attn_out = huggingface_attn(embed_out)[0] + assert_close(our_attn_out, huggingface_attn_out, rtol=2e-4, atol=1e-5) + + +def 
test_decoder_attention_layer(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.decoder.embed_tokens + huggingface_attn = huggingface_model.decoder.block[1].layer[0] + + embed_out = huggingface_embed(hello_world_tokens) + our_attn = our_model.decoder[1].attn + norm_embed = our_model.decoder[1].ln1(embed_out) + our_attn_out = our_attn(norm_embed, norm_embed, norm_embed) + embed_out + + huggingface_attn_out = huggingface_attn(embed_out)[0] + assert_close(our_attn_out, huggingface_attn_out, rtol=3e-4, atol=4e-5) + + +def test_cross_attention(our_model, huggingface_model, hello_world_tokens, decoder_input_ids): + encoder_hidden = huggingface_model.encoder(hello_world_tokens).last_hidden_state + decoder_hidden = huggingface_model.decoder.embed_tokens(decoder_input_ids) + + huggingface_cross_attn = huggingface_model.decoder.block[0].layer[1].EncDecAttention + our_cross_attn = our_model.decoder[0].cross_attn + + our_cross_attn_out = our_cross_attn(decoder_hidden, encoder_hidden, encoder_hidden) + huggingface_cross_attn_out = huggingface_cross_attn( + decoder_hidden, key_value_states=encoder_hidden + )[0] + assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5) + + +def test_cross_attention_layer(our_model, huggingface_model, hello_world_tokens, decoder_input_ids): + encoder_hidden = huggingface_model.encoder(hello_world_tokens).last_hidden_state + decoder_hidden = huggingface_model.decoder.embed_tokens(decoder_input_ids) + + hf_layer = huggingface_model.decoder.block[0].layer[1] + our_layer = our_model.decoder[0] + # assert ln weights are the same + assert_close(hf_layer.layer_norm.weight, our_layer.ln2.w) + + our_cross_attn_out = ( + our_layer.cross_attn(our_layer.ln2(decoder_hidden), encoder_hidden, encoder_hidden) + + decoder_hidden + ) + huggingface_cross_attn_out = hf_layer(decoder_hidden, key_value_states=encoder_hidden)[0] + assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, 
atol=1e-5) + + +def test_encoder_block(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.encoder.embed_tokens + huggingface_block = huggingface_model.encoder.block[1] + our_block = our_model.encoder[1] + + embed_out = huggingface_embed(hello_world_tokens) + + hf_out = huggingface_block(embed_out)[0] + our_out = our_block(embed_out) + + assert_close(our_out, hf_out, rtol=2e-4, atol=2e-5) + + +def test_decoder_block(our_model, huggingface_model, hello_world_tokens, decoder_input_ids): + huggingface_embed = huggingface_model.decoder.embed_tokens + huggingface_block = huggingface_model.decoder.block[1] + our_block = our_model.decoder[1] + + encoder_hidden = huggingface_model.encoder(hello_world_tokens)[0] + decoder_hidden = huggingface_model.decoder.block[0](huggingface_embed(decoder_input_ids))[0] + + our_out = our_block(decoder_hidden, encoder_hidden_states=encoder_hidden) + hf_out = huggingface_block(decoder_hidden, encoder_hidden_states=encoder_hidden)[0] + + assert_close(hf_out, our_out, rtol=2e-4, atol=2e-5) + + +def test_layernorm(our_model, huggingface_model, hello_world_tokens): + huggingface_embed = huggingface_model.encoder.embed_tokens + huggingface_layernorm = huggingface_model.encoder.block[0].layer[0].layer_norm + our_layernorm = our_model.encoder[0].ln1 + + embed_out = huggingface_embed(hello_world_tokens) + + our_layernorm_out = our_layernorm(embed_out) + huggingface_layernorm_out = huggingface_layernorm(embed_out) + assert_close(our_layernorm_out, huggingface_layernorm_out) + + +def test_unembed(our_model, huggingface_model, hello_world_tokens): + huggingface_model_hidden = huggingface_model.decoder(hello_world_tokens).last_hidden_state + + our_model_logits = our_model.unembed(huggingface_model_hidden) + huggingface_model_logits = huggingface_model.lm_head(huggingface_model_hidden) + + assert_close(our_model_logits, huggingface_model_logits, rtol=1.3e-3, atol=1e-5) + + +def test_run_with_cache(our_model, 
hello_world_tokens, decoder_input_ids): + logits, cache = our_model.run_with_cache(hello_world_tokens, decoder_input=decoder_input_ids) + + # check that an arbitrary subset of the keys exist and have the right shape + seq_len = 5 + generated_len = 1 + assert "hook_embed" in cache + assert cache["hook_embed"].shape == (1, seq_len, 512) + assert "encoder.1.attn.hook_v" in cache + assert cache["encoder.1.attn.hook_v"].shape == (1, seq_len, 8, 64) + assert "encoder.3.attn.hook_attn_scores" in cache + assert cache["encoder.3.attn.hook_attn_scores"].shape == (1, 8, seq_len, seq_len) + assert "decoder.0.cross_attn.hook_k" in cache + assert cache["decoder.0.cross_attn.hook_attn_scores"].shape == ( + 1, + 8, + generated_len, + seq_len, + ) + assert "decoder.3.hook_resid_post" in cache + assert cache["decoder.3.hook_resid_post"].shape == (1, generated_len, 512) + + +def test_from_pretrained_revision(): + """ + Check that the from_pretrained parameter `revision` (= git version) works + """ + + _ = HookedEncoderDecoder.from_pretrained(MODEL_NAME, revision="main") + + try: + _ = HookedEncoderDecoder.from_pretrained(MODEL_NAME, revision="inexistent_branch_name") + except: + pass + else: + raise AssertionError("Should have raised an error") + + +def test_predictions(our_model, huggingface_model, tokenizer, decoder_input_ids): + input_ids = tokenizer("My name is Wolfgang and I live in Berlin", return_tensors="pt")[ + "input_ids" + ] + + def get_predictions(logits: Float[torch.Tensor, "batch pos d_vocab"]): + predicted_tokens = logits[0].argmax(dim=-1) + return tokenizer.batch_decode(predicted_tokens) + + our_model_logits = our_model(input_ids, decoder_input=decoder_input_ids) + our_prediction = get_predictions(our_model_logits) + + huggingface_model_logits = huggingface_model( + input_ids, decoder_input_ids=decoder_input_ids + ).logits + huggingface_prediction = get_predictions(huggingface_model_logits) + + assert our_prediction == huggingface_prediction + + 
class HookedEncoderDecoder(HookedRootModule):
    """
    This class implements a T5 encoder-decoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The model does not include dropout, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
        - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
        - The model only accepts tokens as inputs, and not strings, or lists of strings
    """

    def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
        """Build the model from a config.

        Args:
            cfg: A HookedTransformerConfig, or a dict of its fields.
            tokenizer: Optional tokenizer. If omitted and `cfg.tokenizer_name` is set, it is
                loaded with `AutoTokenizer.from_pretrained` (using the `HF_TOKEN` env var).
            move_to_device: Whether to move the model to `cfg.device` after construction.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ValueError: If `cfg` is a string, if `cfg.n_devices != 1`, or if `cfg.d_vocab` is -1
                and there is no tokenizer to infer the vocabulary size from.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead."
            )
        self.cfg = cfg

        if self.cfg.n_devices != 1:
            raise ValueError("Multiple devices not supported for HookedEncoderDecoder")
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            if self.tokenizer is None:
                raise ValueError("Must provide a tokenizer if d_vocab is not provided")

            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        # One embedding table is shared by the encoder and the decoder inputs.
        self.embed = Embed(self.cfg)
        self.encoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=False)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.encoder_final_ln = RMSNorm(self.cfg)
        self.decoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=True)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.decoder_final_ln = RMSNorm(self.cfg)
        self.unembed = Unembed(self.cfg)

        self.hook_embed = HookPoint()

        if move_to_device:
            self.to(self.cfg.device)

        self.setup()

    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        decoder_input: Int[torch.Tensor, "batch decoder_pos"],
        return_type: Optional[str] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
        """Input must be a batch of tokens. Strings and lists of strings are not yet supported.

        Args:
            input: The encoder input tokens.
            decoder_input: The input to the decoder. This is the sequence of tokens that the model will generate, usually with a start token at the beginning.
            return_type: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits).
            one_zero_attention_mask: A binary mask which indicates which encoder tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. The same mask is applied to the decoder's cross attention.

        Returns:
            The logits, or None if `return_type` is None.
        """
        device = self.cfg.device

        tokens = input
        if tokens.device.type != device:
            tokens = tokens.to(device)
        # The decoder tokens go through the same embedding table, so they have to live on the
        # model's device as well; previously callers had to move them by hand.
        if decoder_input.device.type != device:
            decoder_input = decoder_input.to(device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(device)

        resid = self.hook_embed(self.embed(tokens))

        if one_zero_attention_mask is not None:
            # Turn the 1/0 mask into an additive one: 0 where attended, dtype-min where masked.
            additive_attention_mask = (
                repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            ) * torch.finfo(self.cfg.dtype).min
        else:
            additive_attention_mask = None

        query_len = key_len = input.shape[1]

        # The relative position bias is computed once from block 0 and shared by every block.
        encoder_positional_bias = self.encoder[0].attn.compute_relative_attention_bias(
            query_len, key_len, device=device
        )

        for encoder_block in self.encoder:
            resid = encoder_block(
                resid_pre=resid,
                additive_attention_mask=additive_attention_mask,
                position_bias=encoder_positional_bias,
            )

        encoder_resid = self.encoder_final_ln(resid)

        # NOTE(review): the decoder embedding does not pass through any HookPoint; `hook_embed`
        # only sees the encoder embedding.
        decoder_resid = self.embed(decoder_input)
        decoder_query_len = decoder_key_len = decoder_input.shape[1]
        # NOTE(review): the decoder bias is computed by the same call as the encoder bias; the
        # reference T5 implementations appear to bucket decoder self-attention
        # unidirectionally - confirm for decoder inputs longer than a few tokens.
        decoder_positional_bias = self.decoder[0].attn.compute_relative_attention_bias(
            decoder_query_len, decoder_key_len, device=device
        )

        for decoder_block in self.decoder:
            decoder_resid = decoder_block(
                resid_pre=decoder_resid,
                position_bias=decoder_positional_bias,
                encoder_hidden_states=encoder_resid,
                encoder_additive_attention_mask=additive_attention_mask,
            )

        decoder_resid = self.decoder_final_ln(decoder_resid)

        if return_type is None:
            # As documented: every hooked activation has been produced, skip the logits.
            return None

        if self.cfg.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            # Out-of-place so the tensor returned by the final norm is never mutated.
            decoder_resid = decoder_resid * self.cfg.d_model**-0.5

        return self.unembed(decoder_resid)

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move or cast the model via `devices.move_to_and_update_config`."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self):
        # Wrapper around cuda that also changes self.cfg.device
        return self.to("cuda")

    def cpu(self):
        # Wrapper around cpu that also changes self.cfg.device
        return self.to("cpu")

    def mps(self):
        # Wrapper around mps that also changes self.cfg.device
        return self.to("mps")

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model=None,
        device: Optional[str] = None,
        tokenizer=None,
        move_to_device=True,
        dtype=torch.float32,
        **from_pretrained_kwargs,
    ) -> HookedEncoderDecoder:
        """Loads in the pretrained weights from huggingface for a T5-style encoder-decoder.

        Unlike HookedTransformer, this does not yet do any preprocessing on the model.

        Args:
            model_name: An official model name, or a path that exists on disk.
            checkpoint_index: Forwarded to `loading.get_pretrained_model_config`.
            checkpoint_value: Forwarded to `loading.get_pretrained_model_config`.
            hf_model: Forwarded to `loading.get_pretrained_state_dict`.
            device: Device to put in the config.
            tokenizer: Optional tokenizer passed to the constructor.
            move_to_device: Whether to move the loaded model to `cfg.device`.
            dtype: Parameter dtype; overridden by a `torch_dtype` keyword argument if given.
            **from_pretrained_kwargs: Forwarded to the loading helpers.

        Raises:
            ValueError: If `load_in_8bit` or `load_in_4bit` is requested.
        """
        logging.warning(
            "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using T5 for interpretability research, keep in mind that T5 has some significant architectural "
            "differences to GPT. The major one is that T5 is an Encoder-Decoder model. "
            "Also, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm"
        )

        if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(
            "load_in_4bit", False
        ):
            raise ValueError("Quantization not supported")

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        name_or_path = (
            model_name if Path(model_name).exists() else loading.get_official_model_name(model_name)
        )

        cfg = loading.get_pretrained_model_config(
            name_or_path,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build without moving, load the weights, then move once.
        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoderDecoder")

        return model

    def _stack_block_params(self, module_name: str, param_name: str) -> torch.Tensor:
        """Stack `block.<module_name>.<param_name>` over all encoder blocks, then all decoder blocks."""
        return torch.stack(
            [
                getattr(getattr(cast(T5Block, block), module_name), param_name)
                for block in chain(self.encoder, self.decoder)
            ],
            dim=0,
        )

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.W_E

    @property
    def W_pos(self) -> None:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        raise NotImplementedError(
            "T5 does not have absolute positional embeddings. Uses relative positional embeddings instead."
        )

    # The stacked properties below cover each block's self-attention (`attn`) and MLP only;
    # decoder cross-attention weights are not included. The leading dimension runs over the
    # encoder blocks followed by the decoder blocks, i.e. it has length 2 * cfg.n_layers.

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "W_K")

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "W_Q")

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "W_V")

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "W_O")

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("mlp", "W_in")

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("mlp", "W_out")

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "b_K")

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "b_Q")

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "b_V")

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("attn", "b_O")

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("mlp", "b_in")

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all encoder blocks, then all decoder blocks"""
        return self._stack_block_params("mlp", "b_out")

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "EL{l}H{h}" for encoder heads followed by "DL{l}H{h}" for decoder heads, where l is the layer index and h is the head index."""
        return [
            f"{prefix}L{l}H{h}"
            for prefix in ("E", "D")
            for l in range(self.cfg.n_layers)
            for h in range(self.cfg.n_heads)
        ]
""" n_layers: int @@ -214,6 +220,10 @@ class HookedTransformerConfig: load_in_4bit: bool = False num_experts: Optional[int] = None experts_per_token: Optional[int] = None + relative_attention_max_distance: Optional[int] = None + relative_attention_num_buckets: Optional[int] = None + decoder_start_token_id: Optional[int] = None + tie_word_embeddings: bool = False def __post_init__(self): if self.n_heads == -1: diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index e2fb1484b..3710b434d 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -12,6 +12,7 @@ from .HookedTransformer import HookedTransformer from .SVDInterpreter import SVDInterpreter from .HookedEncoder import HookedEncoder +from .HookedEncoderDecoder import HookedEncoderDecoder from . import head_detector from . import loading_from_pretrained as loading from . import patching diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py index 47677426a..f077fadc6 100644 --- a/transformer_lens/components/__init__.py +++ b/transformer_lens/components/__init__.py @@ -4,6 +4,7 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" + # Independent classes from .abstract_attention import AbstractAttention from .layer_norm import LayerNorm @@ -27,3 +28,5 @@ from .bert_embed import BertEmbed from .moe import MoE from .transformer_block import TransformerBlock +from .t5_attention import T5Attention +from .t5_block import T5Block diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index cc22519c7..4494a70bc 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -79,7 +79,8 @@ def __init__( self.register_buffer("mask", causal_mask) elif self.attn_type == "local": # For local, this is banded, query - window_size < key <= query - assert isinstance(self.cfg.window_size, int) + if not isinstance(self.cfg.window_size, int): + raise ValueError("Window size must be an integer for local attention") self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size)) else: raise ValueError(f"Invalid attention type: {self.attn_type}") @@ -94,7 +95,8 @@ def __init__( else: self.attn_scale = 1.0 if self.cfg.scale_attn_by_inverse_layer_idx: - assert self.layer_id is not None # keep mypy happy + if self.layer_id is None: # keep mypy happy + raise ValueError("Layer ID must be provided to scale attention scores") self.attn_scale *= self.layer_id + 1 self.hook_k = HookPoint() # [batch, pos, head_index, d_head] @@ -113,7 +115,8 @@ def __init__( # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. 
See HookedTransformerConfig for details self.hook_rot_k = HookPoint() self.hook_rot_q = HookPoint() - assert self.cfg.rotary_dim is not None # keep mypy happy + if self.cfg.rotary_dim is None: # keep mypy happy + raise ValueError("Rotary dim must be provided for rotary positional embeddings") sin, cos = self.calculate_sin_cos_rotary( self.cfg.rotary_dim, self.cfg.n_ctx, @@ -127,6 +130,10 @@ def __init__( # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. self.alibi = None + elif self.cfg.positional_embedding_type == "relative_positional_bias": + # will be overwritten by the child T5Attention class + self.has_relative_attention_bias = False + @property def OV(self) -> FactoredMatrix: """ @@ -159,18 +166,19 @@ def forward( Float[torch.Tensor, "batch pos head_index d_model"], ], key_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - Float[torch.Tensor, "batch pos kv_head_index d_model"], + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], + Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], ], value_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - Float[torch.Tensor, "batch pos kv_head_index d_model"], + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], + Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], ], past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, - additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None, attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + position_bias: Optional[Float[torch.Tensor, "1 head_index pos 
kv_pos"]] = None, ) -> Float[torch.Tensor, "batch pos d_model"]: """ shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details @@ -218,7 +226,20 @@ def forward( attn_scores += self.alibi[ :, :query_ctx, :key_ctx ] # [batch, head_index, query_pos, key_pos] + elif self.cfg.positional_embedding_type == "relative_positional_bias": + if position_bias is None: + if self.has_relative_attention_bias: + raise ValueError("Positional bias is required for relative_positional_bias") + else: + position_bias = torch.zeros( + 1, + self.cfg.n_heads, + attn_scores.shape[2], + attn_scores.shape[3], + device=attn_scores.device, + ) + attn_scores += position_bias if self.cfg.attention_dir == "causal": # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. attn_scores = self.apply_causal_mask( @@ -293,17 +314,17 @@ def calculate_qkv_matrices( Float[torch.Tensor, "batch pos head_index d_model"], ], key_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], ], value_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], ], ) -> Tuple[ Float[torch.Tensor, "batch pos head_index d_head"], - Float[torch.Tensor, "batch pos head_index d_head"], - Float[torch.Tensor, "batch pos head_index d_head"], + Float[torch.Tensor, "batch kv_pos head_index d_head"], + Float[torch.Tensor, "batch kv_pos head_index d_head"], ]: if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: qkv_einops_string = "batch pos head_index d_model" @@ -337,7 +358,8 @@ def calculate_qkv_matrices( + self.b_Q ) # [batch, pos, head_index, d_head] if 
self.cfg.load_in_4bit: - assert isinstance(self.W_K, Params4bit) + if not isinstance(self.W_K, Params4bit): + raise ValueError("W_K must be a Params4bit object if load_in_4bit is True") k = self.hook_k( # call bitsandbytes method to dequantize and multiply bnb.matmul_4bit( @@ -362,7 +384,8 @@ def calculate_qkv_matrices( ) # [batch, pos, head_index, d_head] if self.cfg.load_in_4bit: - assert isinstance(self.W_V, Params4bit) + if not isinstance(self.W_V, Params4bit): + raise ValueError("W_V must be a Params4bit object if load_in_4bit is True") v = self.hook_v( # call bitsandbytes method to dequantize and multiply bnb.matmul_4bit( @@ -435,9 +458,10 @@ def apply_causal_mask( # If not caching, query_ctx_length == key_ctx_length key_ctx_length = attn_scores.size(-1) - assert ( - query_ctx_length + past_kv_pos_offset == key_ctx_length - ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." + if query_ctx_length + past_kv_pos_offset != key_ctx_length: + raise ValueError( + f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." 
class T5Attention(AbstractAttention):
    r"""
    T5 attention - with relative attention bias and cross-attention support.

    This implementation expects you to precompute the relative positional bias and then pass
    it to forward, like
    ```python
    attn = T5Attention(cfg, has_relative_attention_bias=True)
    positional_bias = attn.compute_relative_attention_bias(query_len, key_len, device=device)
    result = attn(query, key, value, position_bias=positional_bias)
    ```
    """

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        has_relative_attention_bias: bool = False,
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """
        Args:
            cfg: Model config, or a dict of its fields.
            has_relative_attention_bias: Whether this layer owns the bias embedding table
                (`rel_pos_bias`) used by `compute_relative_attention_bias`.
            attn_type: Forwarded to `AbstractAttention`.
            layer_id: Forwarded to `AbstractAttention`.

        Raises:
            ValueError: If a bias table is requested but `relative_attention_num_buckets` or
                `relative_attention_max_distance` is missing from the config.
        """
        super().__init__(cfg, attn_type, layer_id)
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.has_relative_attention_bias: bool = has_relative_attention_bias

        if self.has_relative_attention_bias:
            if (
                cfg.relative_attention_num_buckets is None
                or cfg.relative_attention_max_distance is None
            ):
                raise ValueError(
                    "You need to specify relative_attention_num_buckets and relative_attention_max_distance in config to use relative attention bias"
                )

            self.relative_attention_num_buckets = cfg.relative_attention_num_buckets
            self.relative_attention_max_distance = cfg.relative_attention_max_distance
            # One learned bias per (bucket, head).
            self.rel_pos_bias = nn.Embedding(self.relative_attention_num_buckets, self.cfg.n_heads)
            self.rel_pos_hook = HookPoint()  # [1, head_index, pos, kv_pos]

        # NOTE(review): W_K / W_V are allocated with torch.empty (uninitialised memory);
        # presumably they are always overwritten by loaded weights - confirm before training
        # from scratch.
        self.W_K = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
        )
        self.W_V = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
        )
        self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))
        self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))

    @staticmethod
    def _relative_position_bucket(
        relative_position: Int[torch.Tensor, "query_pos kv_pos"],
        bidirectional=True,
        num_buckets=32,
        max_distance=128,
    ) -> Int[torch.Tensor, "query_pos kv_pos"]:
        """
        added from
        https://github.com/huggingface/transformers/blob/e0c3cee17085914bbe505c159beeb8ae39bc37dd/src/transformers/models/t5/modeling_t5.py#L382
        which is adapted from
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593


        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing integer values in the range [0, num_buckets)
        """
        relative_buckets = torch.zeros_like(relative_position)

        if bidirectional:
            # Half of the buckets encode "key after query", the other half "key before query".
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        # (log(0) = -inf for a zero distance is harmless: that entry is "small", so the
        # torch.where below never selects it).
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_relative_attention_bias(
        self,
        query_length: int,
        key_length: int,
        device=None,
        bidirectional: bool = True,
    ) -> Float[torch.Tensor, "1 head_index pos kv_pos"]:
        """Compute binned relative position bias

        Args:
            query_length: Number of query positions.
            key_length: Number of key positions.
            device: Device for the position indices; defaults to the device of the bias table.
            bidirectional: Forwarded to `_relative_position_bucket`. Defaults to True, which
                was previously hard-coded. NOTE(review): the reference T5 implementations
                appear to use unidirectional bucketing for decoder self-attention - confirm
                whether decoder callers should pass False.

        Returns:
            The bias to add to the attention scores, passed through `rel_pos_hook`.

        Raises:
            ValueError: If this layer was built without a relative attention bias table.
        """
        if not self.has_relative_attention_bias:
            # Without this guard the failure is an AttributeError on the missing table.
            raise ValueError(
                "This attention layer was created with has_relative_attention_bias=False and has no bias table"
            )
        if device is None:
            device = self.rel_pos_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.rel_pos_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        # rel_pos_hook was created in __init__ but never invoked; route the bias through it so
        # it can be cached and intervened on like every other activation.
        return self.rel_pos_hook(values)
relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.rel_pos_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values diff --git a/transformer_lens/components/t5_block.py b/transformer_lens/components/t5_block.py new file mode 100644 index 000000000..1e5c8281b --- /dev/null +++ b/transformer_lens/components/t5_block.py @@ -0,0 +1,155 @@ +from typing import Optional + +import torch +import torch.nn as nn +from jaxtyping import Float + +from transformer_lens.components import MLP, RMSNorm, T5Attention +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry +from transformer_lens.utils import repeat_along_head_dimension + + +class T5Block(nn.Module): + """ + T5 decoder Block. Uses T5Layernorm, and T5attention insted of usual ones. + Also uses cross attention if is_decoder is True. 
+ """ + + def __init__(self, cfg: HookedTransformerConfig, block_index: int, is_decoder: bool): + super().__init__() + self.cfg = cfg + self.is_decoder = is_decoder + + self.ln1 = RMSNorm(cfg) + self.attn = T5Attention(cfg, has_relative_attention_bias=block_index == 0) + self.ln2 = RMSNorm(cfg) + if self.is_decoder: + self.cross_attn = T5Attention(cfg) + self.ln3 = RMSNorm(cfg) + self.mlp = MLP(cfg) # [batch, pos, n_heads] + + self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] + + self.hook_attn_in = HookPoint() # [batch, pos, d_model] + self.hook_attn_out = HookPoint() # [batch, pos, d_model] + if self.is_decoder: + self.hook_cross_attn_in = HookPoint() # [batch, pos, d_model] + self.hook_cross_attn_out = HookPoint() # [batch, pos, d_model] + self.hook_resid_mid_cross = HookPoint() # [batch, pos, d_model] + + self.hook_mlp_in = HookPoint() # [batch, pos, d_model] + self.hook_mlp_out = HookPoint() # [batch, pos, d_model] + self.hook_resid_pre = HookPoint() # [batch, pos, d_model] + self.hook_resid_mid = HookPoint() # [batch, pos, d_model] + self.hook_resid_post = HookPoint() # [batch, pos, d_model] + + def forward( + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + encoder_additive_attention_mask: Optional[ + Float[torch.Tensor, "batch 1 1 encoder_pos"] + ] = None, + position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None, + encoder_hidden_states: Optional[Float[torch.Tensor, "batch encoder_pos d_model"]] = None, + past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + """A single Transformer block. 
+ + Args: + resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model] + encoder_hidden_states (torch.Tensor): The hidden states of the encoder for cross attention - shape [batch, encoder_pos, d_model] + past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): A cache of previous keys and values, used only when generating text. Defaults to None. + additive_attention_mask (torch.Tensor, optional): The additive attention mask for padded tokens - shape [batch, 1, 1, pos]. Defaults to None. + + Returns: + torch.Tensor: The residual stream after this block - shape [batch, pos, d_model] + """ + resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model] + + attn_in = resid_pre + + if self.cfg.use_attn_in: + attn_in = self.hook_attn_in( + repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads) + ) + + if self.cfg.use_split_qkv_input: + n_kv_heads = ( + self.cfg.n_key_value_heads + if self.cfg.n_key_value_heads is not None + else self.cfg.n_heads + ) + query_input = self.hook_q_input( + repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads) + ) + key_input = self.hook_k_input( + repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads) + ) + value_input = self.hook_v_input( + repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads) + ) + else: + query_input = attn_in + key_input = attn_in + value_input = attn_in + + attn_out = self.hook_attn_out( + # hook the residual stream states that are used to calculate the + # queries, keys and values, independently. + # Then take the layer norm of these inputs, and pass these to the attention module.
+ self.attn( + query_input=self.ln1(query_input), + key_input=self.ln1(key_input), + value_input=self.ln1(value_input), + past_kv_cache_entry=past_kv_cache_entry, + additive_attention_mask=additive_attention_mask, + position_bias=position_bias, + ) + ) + + # [batch, pos, d_model] + + resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model] + + if self.is_decoder: + cross_attn_in = ( + resid_mid + if not self.cfg.use_attn_in + else self.hook_cross_attn_in(resid_mid.clone()) + ) + + if encoder_hidden_states is None: + raise ValueError("Encoder hidden states must be provided for cross attention!") + + cross_attn_out = self.hook_cross_attn_out( + self.cross_attn( + query_input=self.ln2(cross_attn_in), + key_input=encoder_hidden_states, + value_input=encoder_hidden_states, + additive_attention_mask=encoder_additive_attention_mask, + ) + ) + resid_mid_cross = self.hook_resid_mid_cross(resid_mid + cross_attn_out) + + mlp_in = ( + resid_mid_cross + if not self.cfg.use_hook_mlp_in + else self.hook_mlp_in(resid_mid_cross.clone()) + ) + + normalized_resid_mid = self.ln3(mlp_in) + else: + mlp_in = ( + resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone()) + ) + normalized_resid_mid = self.ln2(mlp_in) + + mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) # [batch, pos, d_model] + resid_post = self.hook_resid_post(mlp_in + mlp_out) # [batch, pos, d_model] + + return resid_post diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 67ade6490..1190337fb 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -63,7 +63,8 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): if not self.cfg.use_local_attn: self.attn = attention(cfg, "global", block_index) else: - assert self.cfg.attn_types is not None + if self.cfg.attn_types is None: + raise ValueError("attn_types must be 
set when using local attention") attn_type = self.cfg.attn_types[block_index] self.attn = attention(cfg, attn_type, block_index) if not self.cfg.attn_only: diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 10458d6b3..15fc5e7b2 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -7,12 +7,18 @@ import logging import os import re +from pathlib import Path from typing import Dict, Optional, Union, cast import einops import torch from huggingface_hub import HfApi -from transformers import AutoConfig, AutoModelForCausalLM, BertForPreTraining +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + BertForPreTraining, + T5ForConditionalGeneration, +) import transformer_lens.utils as utils from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -183,6 +189,9 @@ "01-ai/Yi-34B", "01-ai/Yi-6B-Chat", "01-ai/Yi-34B-Chat", + "google-t5/t5-small", + "google-t5/t5-base", + "google-t5/t5-large", "ai-forever/mGPT", ] """Official model names for models on HuggingFace.""" @@ -598,6 +607,9 @@ "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], + "google-t5/t5-small": ["t5-small"], + "google-t5/t5-base": ["t5-base"], + "google-t5/t5-large": ["t5-large"], "ai-forever/mGPT": ["mGPT"], } """Model aliases for models on HuggingFace.""" @@ -659,7 +671,12 @@ def convert_hf_model_config(model_name: str, **kwargs): Takes the official_model_name as an input. 
""" # In case the user passed in an alias - official_model_name = get_official_model_name(model_name) + if (Path(model_name) / "config.json").exists(): + logging.info("Loading model config from local directory") + official_model_name = model_name + else: + official_model_name = get_official_model_name(model_name) + # Load HuggingFace model config if "llama" in official_model_name.lower(): architecture = "LlamaForCausalLM" @@ -1168,6 +1185,25 @@ def convert_hf_model_config(model_name: str, **kwargs): "gated_mlp": True, "final_rms": True, } + elif architecture == "T5ForConditionalGeneration": + cfg_dict = { + "d_model": hf_config.d_model, + "d_head": hf_config.d_kv, + "n_heads": hf_config.num_heads, + "d_mlp": hf_config.d_ff, + "d_vocab": hf_config.vocab_size, + "n_layers": hf_config.num_layers, + "n_ctx": hf_config.max_length, + "eps": hf_config.layer_norm_epsilon, + "act_fn": hf_config.feed_forward_proj, + "positional_embedding_type": "relative_positional_bias", + "relative_attention_max_distance": hf_config.relative_attention_max_distance, + "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, + "decoder_start_token_id": hf_config.decoder_start_token_id, + "attention_dir": "bidirectional", + "use_attn_scale": False, + "tie_word_embeddings": hf_config.tie_word_embeddings, + } else: raise NotImplementedError(f"{architecture} is not currently supported.") # All of these models use LayerNorm @@ -1266,7 +1302,12 @@ def get_pretrained_model_config( Also given to other HuggingFace functions when compatible. 
""" - official_model_name = get_official_model_name(model_name) + if Path(model_name).exists(): + # If the model_name is a path, it's a local model + cfg_dict = convert_hf_model_config(model_name, **kwargs) + official_model_name = model_name + else: + official_model_name = get_official_model_name(model_name) if ( official_model_name.startswith("NeelNanda") or official_model_name.startswith("ArthurConmy") @@ -1422,7 +1463,11 @@ def get_pretrained_state_dict( if "torch_dtype" in kwargs: dtype = kwargs["torch_dtype"] del kwargs["torch_dtype"] - official_model_name = get_official_model_name(official_model_name) + if Path(official_model_name).exists(): + official_model_name = str(Path(official_model_name).resolve()) + logging.info(f"Loading model from local path {official_model_name}") + else: + official_model_name = get_official_model_name(official_model_name) if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( "trust_remote_code", False ): @@ -1488,6 +1533,13 @@ def get_pretrained_state_dict( token=huggingface_token, **kwargs, ) + elif "t5" in official_model_name: + hf_model = T5ForConditionalGeneration.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token, + **kwargs, + ) else: hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, @@ -1515,6 +1567,8 @@ def get_pretrained_state_dict( state_dict = convert_llama_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) + elif cfg.original_architecture == "T5ForConditionalGeneration": + state_dict = convert_t5_weights(hf_model, cfg) elif cfg.original_architecture == "MistralForCausalLM": state_dict = convert_mistral_weights(hf_model, cfg) elif cfg.original_architecture == "MixtralForCausalLM": @@ -2446,6 +2500,104 @@ def convert_bert_weights(bert, cfg: HookedTransformerConfig): return state_dict +def convert_t5_weights(t5, cfg: HookedTransformerConfig): + state_dict = { + 
"embed.W_E": t5.encoder.embed_tokens.weight, + "unembed.W_U": t5.encoder.embed_tokens.weight.T, + "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0] + .layer[0] + .SelfAttention.relative_attention_bias.weight, + } + + for l in range(cfg.n_layers): + block = t5.encoder.block[l] + state_dict[f"encoder.{l}.attn.W_Q"] = einops.rearrange( + block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"encoder.{l}.attn.W_K"] = einops.rearrange( + block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + + state_dict[f"encoder.{l}.attn.W_V"] = einops.rearrange( + block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + + state_dict[f"encoder.{l}.attn.W_O"] = einops.rearrange( + block.layer[0].SelfAttention.o.weight, + "m (i h) -> i h m", + i=cfg.n_heads, + ) + state_dict[f"encoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight + + # fixme DenseReluDense may be T5DenseGatedActDense instead + state_dict[f"encoder.{l}.mlp.W_in"] = einops.rearrange( + block.layer[1].DenseReluDense.wi.weight, "mlp model -> model mlp" + ) + + state_dict[f"encoder.{l}.mlp.W_out"] = einops.rearrange( + block.layer[1].DenseReluDense.wo.weight, "model mlp -> mlp model" + ) + state_dict[f"encoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight + + state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight + + state_dict["decoder.0.attn.rel_pos_bias.weight"] = ( + t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight + ) + + for l in range(cfg.n_layers): + block = t5.decoder.block[l] + state_dict[f"decoder.{l}.attn.W_Q"] = einops.rearrange( + block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + + state_dict[f"decoder.{l}.attn.W_K"] = einops.rearrange( + block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"decoder.{l}.attn.W_V"] = einops.rearrange( + block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + + 
state_dict[f"decoder.{l}.attn.W_O"] = einops.rearrange( + block.layer[0].SelfAttention.o.weight, + "m (i h) -> i h m", + i=cfg.n_heads, + ) + + state_dict[f"decoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight + + state_dict[f"decoder.{l}.cross_attn.W_Q"] = einops.rearrange( + block.layer[1].EncDecAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + + state_dict[f"decoder.{l}.cross_attn.W_K"] = einops.rearrange( + block.layer[1].EncDecAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + + state_dict[f"decoder.{l}.cross_attn.W_V"] = einops.rearrange( + block.layer[1].EncDecAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"decoder.{l}.cross_attn.W_O"] = einops.rearrange( + block.layer[1].EncDecAttention.o.weight, + "m (i h) -> i h m", + i=cfg.n_heads, + ) + state_dict[f"decoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight + + # fixme DenseReluDense may be T5DenseGatedActDense instead + state_dict[f"decoder.{l}.mlp.W_in"] = einops.rearrange( + block.layer[2].DenseReluDense.wi.weight, "mlp model -> model mlp" + ) + state_dict[f"decoder.{l}.mlp.W_out"] = einops.rearrange( + block.layer[2].DenseReluDense.wo.weight, "model mlp -> mlp model" + ) + state_dict[f"decoder.{l}.ln3.w"] = block.layer[2].layer_norm.weight + + state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight + + return state_dict + + def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): state_dict = {} diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index c8e5b78b7..f7de5d3c7 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -3,6 +3,7 @@ Utilities to get the correct device, and assist in distributing model layers across multiple devices. 
""" + from __future__ import annotations from typing import Optional, Union @@ -45,7 +46,11 @@ def get_device_for_block_index( def move_to_and_update_config( - model: Union["transformer_lens.HookedTransformer", "transformer_lens.HookedEncoder"], + model: Union[ + "transformer_lens.HookedTransformer", + "transformer_lens.HookedEncoder", + "transformer_lens.HookedEncoderDecoder", + ], device_or_dtype: Union[torch.device, str, torch.dtype], print_details=True, ):