-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add multi-sample text neuron activations support (#34)
- Loading branch information
1 parent
32979cb
commit 9fad349
Showing
12 changed files
with
7,542 additions
and
2,413 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,57 @@ | ||
"""Activations visualizations""" | ||
from typing import List, Union | ||
from typing import List, Union, Optional | ||
|
||
import numpy as np | ||
import torch | ||
from circuitsvis.utils.render import RenderedHTML, render | ||
|
||
|
||
def text_neuron_activations( | ||
tokens: List[str], | ||
activations: Union[List[List[List[float]]], np.ndarray, torch.Tensor], | ||
tokens: Union[List[str], List[List[str]]], | ||
activations: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]], | ||
first_dimension_name: Optional[str] = "Layer", | ||
second_dimension_name: Optional[str] = "Neuron", | ||
first_dimension_labels: Optional[List[str]] = None, | ||
second_dimension_labels: Optional[List[str]] = None, | ||
) -> RenderedHTML: | ||
"""Show activations (colored by intensity) for each token in some text | ||
"""Show activations (colored by intensity) for each token in a text or set | ||
of texts. | ||
Includes drop-downs for layer and neuron numbers. | ||
Args: | ||
tokens: List of tokens (e.g. `["A", "person"]`) | ||
activations: Activations of the shape [tokens x layers x neurons] | ||
tokens: List of tokens if single sample (e.g. `["A", "person"]`) or list of lists of tokens (e.g. `[[["A", "person"], ["is", "walking"]]]`) | ||
activations: Activations of the shape [tokens x layers x neurons] if | ||
single sample or list of [tokens x layers x neurons] if multiple samples | ||
Returns: | ||
Html: Text neuron activations visualization | ||
""" | ||
# Verify that activations and tokens have the right shape and convert to | ||
# nested lists | ||
if isinstance(activations, (np.ndarray, torch.Tensor)): | ||
assert ( | ||
activations.ndim == 3 | ||
), "activations must be of shape [tokens x layers x neurons]" | ||
activations_list = activations.tolist() | ||
elif isinstance(activations, list): | ||
activations_list = [] | ||
for act in activations: | ||
assert ( | ||
act.ndim == 3 | ||
), "activations must be of shape [tokens x layers x neurons]" | ||
activations_list.append(act.tolist()) | ||
else: | ||
raise TypeError( | ||
f"activations must be of type np.ndarray, torch.Tensor, or list, not {type(activations)}" | ||
) | ||
|
||
return render( | ||
"TextNeuronActivations", | ||
tokens=tokens, | ||
activations=activations, | ||
activations=activations_list, | ||
firstDimensionName=first_dimension_name, | ||
secondDimensionName=second_dimension_name, | ||
firstDimensionLabels=first_dimension_labels, | ||
secondDimensionLabels=second_dimension_labels, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,12 +7,22 @@ | |
|
||
snapshots = Snapshot() | ||
|
||
snapshots['TestTextNeuronActivations.test_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/> | ||
snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/> | ||
<script crossorigin type="module"> | ||
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.0.0/dist/cdn/esm.js"; | ||
import { render, TextNeuronActivations } from "https://unpkg.com/circuitsvis@1.34.0/dist/cdn/esm.js"; | ||
render( | ||
"circuits-vis-mock", | ||
TextNeuronActivations, | ||
{"tokens": ["a", "b"], "activations": [[[0, 1], [0, 1]]]} | ||
{"tokens": [["a", "b"], ["c", "d", "e"]], "activations": [[[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]], [[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]], [[0, 1, 1], [1, 1, 1]]]], "firstDimensionName": "Layer", "secondDimensionName": "Neuron"} | ||
) | ||
</script>''' | ||
|
||
snapshots['TestTextNeuronActivations.test_single_matches_snapshot 1'] = '''<div id="circuits-vis-mock" style="margin: 15px 0;"/> | ||
<script crossorigin type="module"> | ||
import { render, TextNeuronActivations } from "https://unpkg.com/[email protected]/dist/cdn/esm.js"; | ||
render( | ||
"circuits-vis-mock", | ||
TextNeuronActivations, | ||
{"tokens": ["a", "b"], "activations": [[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]], "firstDimensionName": "Layer", "secondDimensionName": "Neuron"} | ||
) | ||
</script>''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,38 @@ | ||
import circuitsvis | ||
from circuitsvis.activations import text_neuron_activations | ||
import circuitsvis.utils.render | ||
import numpy as np | ||
from circuitsvis.activations import text_neuron_activations | ||
|
||
|
||
class TestTextNeuronActivations: | ||
def test_matches_snapshot(self, snapshot, monkeypatch): | ||
def test_single_matches_snapshot(self, snapshot, monkeypatch): | ||
# Monkeypatch uuid4 to always return the same uuid | ||
monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") | ||
monkeypatch.setattr(circuitsvis, "__version__", "1.0.0") | ||
|
||
res = text_neuron_activations( | ||
tokens=["a", "b"], | ||
activations=np.array([[[0, 1], [0, 1]]]) | ||
activations=np.array( | ||
[[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]] | ||
), # [tokens (2) x layers (2) x neurons(3)] | ||
) | ||
snapshot.assert_match(str(res)) | ||
|
||
def test_multi_matches_snapshot(self, snapshot, monkeypatch): | ||
# Monkeypatch uuid4 to always return the same uuid | ||
monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") | ||
|
||
res = text_neuron_activations( | ||
tokens=[["a", "b"], ["c", "d", "e"]], | ||
activations=[ | ||
np.array( | ||
[[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]] | ||
), # [tokens (2) x layers (2) x neurons(3)] | ||
np.array( | ||
[ | ||
[[0, 1, 0], [0, 1, 1]], | ||
[[0, 1, 1], [1, 1, 1]], | ||
[[0, 1, 1], [1, 1, 1]], | ||
] | ||
), # [tokens (3) x layers (2) x neurons(3)] | ||
], | ||
) | ||
snapshot.assert_match(str(res)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.