Skip to content

Commit

Permalink
Add multi-sample text neuron activations support (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
danbraunai authored Jan 3, 2023
1 parent 32979cb commit 9fad349
Show file tree
Hide file tree
Showing 12 changed files with 7,542 additions and 2,413 deletions.
9,243 changes: 7,081 additions & 2,162 deletions python/Demonstration.ipynb

Large diffs are not rendered by default.

43 changes: 36 additions & 7 deletions python/circuitsvis/activations.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,57 @@
"""Activations visualizations"""
from typing import List, Union
from typing import List, Union, Optional

import numpy as np
import torch
from circuitsvis.utils.render import RenderedHTML, render


def text_neuron_activations(
tokens: List[str],
activations: Union[List[List[List[float]]], np.ndarray, torch.Tensor],
tokens: Union[List[str], List[List[str]]],
activations: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
first_dimension_name: Optional[str] = "Layer",
second_dimension_name: Optional[str] = "Neuron",
first_dimension_labels: Optional[List[str]] = None,
second_dimension_labels: Optional[List[str]] = None,
) -> RenderedHTML:
"""Show activations (colored by intensity) for each token in some text
"""Show activations (colored by intensity) for each token in a text or set
of texts.
Includes drop-downs for layer and neuron numbers.
Args:
tokens: List of tokens (e.g. `["A", "person"]`)
activations: Activations of the shape [tokens x layers x neurons]
tokens: List of tokens if single sample (e.g. `["A", "person"]`) or list of lists of tokens (e.g. `[[["A", "person"], ["is", "walking"]]]`)
activations: Activations of the shape [tokens x layers x neurons] if
single sample or list of [tokens x layers x neurons] if multiple samples
Returns:
Html: Text neuron activations visualization
"""
# Verify that activations and tokens have the right shape and convert to
# nested lists
if isinstance(activations, (np.ndarray, torch.Tensor)):
assert (
activations.ndim == 3
), "activations must be of shape [tokens x layers x neurons]"
activations_list = activations.tolist()
elif isinstance(activations, list):
activations_list = []
for act in activations:
assert (
act.ndim == 3
), "activations must be of shape [tokens x layers x neurons]"
activations_list.append(act.tolist())
else:
raise TypeError(
f"activations must be of type np.ndarray, torch.Tensor, or list, not {type(activations)}"
)

return render(
"TextNeuronActivations",
tokens=tokens,
activations=activations,
activations=activations_list,
firstDimensionName=first_dimension_name,
secondDimensionName=second_dimension_name,
firstDimensionLabels=first_dimension_labels,
secondDimensionLabels=second_dimension_labels,
)
16 changes: 13 additions & 3 deletions python/circuitsvis/tests/snapshots/snap_test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@

snapshots = Snapshot()

snapshots['TestTextNeuronActivations.test_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/>
snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/>
<script crossorigin type="module">
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.0.0/dist/cdn/esm.js";
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.34.0/dist/cdn/esm.js";
render(
"circuits-vis-mock",
TextNeuronActivations,
{"tokens": ["a", "b"], "activations": [[[0, 1], [0, 1]]]}
{"tokens": [["a", "b"], ["c", "d", "e"]], "activations": [[[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]], [[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]], [[0, 1, 1], [1, 1, 1]]]], "firstDimensionName": "Layer", "secondDimensionName": "Neuron"}
)
</script>'''

snapshots['TestTextNeuronActivations.test_single_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/>
<script crossorigin type="module">
import { render, TextNeuronActivations } from "https://unpkg.com/[email protected]/dist/cdn/esm.js";
render(
"circuits-vis-mock",
TextNeuronActivations,
{"tokens": ["a", "b"], "activations": [[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]], "firstDimensionName": "Layer", "secondDimensionName": "Neuron"}
)
</script>'''
32 changes: 27 additions & 5 deletions python/circuitsvis/tests/test_activations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
import circuitsvis
from circuitsvis.activations import text_neuron_activations
import circuitsvis.utils.render
import numpy as np
from circuitsvis.activations import text_neuron_activations


class TestTextNeuronActivations:
def test_matches_snapshot(self, snapshot, monkeypatch):
def test_single_matches_snapshot(self, snapshot, monkeypatch):
# Monkeypatch uuid4 to always return the same uuid
monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
monkeypatch.setattr(circuitsvis, "__version__", "1.0.0")

res = text_neuron_activations(
tokens=["a", "b"],
activations=np.array([[[0, 1], [0, 1]]])
activations=np.array(
[[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]]
), # [tokens (2) x layers (2) x neurons(3)]
)
snapshot.assert_match(str(res))

def test_multi_matches_snapshot(self, snapshot, monkeypatch):
# Monkeypatch uuid4 to always return the same uuid
monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")

res = text_neuron_activations(
tokens=[["a", "b"], ["c", "d", "e"]],
activations=[
np.array(
[[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]]
), # [tokens (2) x layers (2) x neurons(3)]
np.array(
[
[[0, 1, 0], [0, 1, 1]],
[[0, 1, 1], [1, 1, 1]],
[[0, 1, 1], [1, 1, 1]],
]
), # [tokens (3) x layers (2) x neurons(3)]
],
)
snapshot.assert_match(str(res))
20 changes: 16 additions & 4 deletions react/src/activations/TextNeuronActivations.stories.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import { ComponentStory, ComponentMeta } from "@storybook/react";
import React from "react";
import { mockActivations, mockTokens } from "./mocks/textNeuronActivations";
import {
mockActivations,
mockTokens,
neuronLabels
} from "./mocks/textNeuronActivations";
import { TextNeuronActivations } from "./TextNeuronActivations";

export default {
Expand All @@ -11,8 +15,16 @@ const Template: ComponentStory<typeof TextNeuronActivations> = (args) => (
<TextNeuronActivations {...args} />
);

export const SmallModelExample = Template.bind({});
SmallModelExample.args = {
export const MultipleSamples = Template.bind({});
MultipleSamples.args = {
tokens: mockTokens,
activations: mockActivations
activations: mockActivations,
secondDimensionLabels: neuronLabels
};

export const SingleSample = Template.bind({});
SingleSample.args = {
tokens: mockTokens[0],
activations: mockActivations[0],
secondDimensionLabels: neuronLabels
};
Loading

0 comments on commit 9fad349

Please sign in to comment.