diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb
index 5bf4cc3..52ac7c1 100644
--- a/python/Demonstration.ipynb
+++ b/python/Demonstration.ipynb
@@ -20,24 +20,31 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
+ "# Enable python import reloading\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
"# Imports\n",
"import numpy as np\n",
"from circuitsvis.attention import attention_patterns, attention_pattern\n",
"from circuitsvis.activations import text_neuron_activations\n",
+ "from circuitsvis.examples import hello\n",
"from circuitsvis.tokens import colored_tokens\n",
"from circuitsvis.topk_tokens import topk_tokens\n",
- "from circuitsvis.topk_samples import topk_samples\n",
- "\n",
- "# Hot reload (for development only)\n",
- "from IPython import get_ipython\n",
- "ip = get_ipython()\n",
- "if not ip.extension_manager.loaded:\n",
- " ip.extension_manager.load('autoreload')\n",
- " %autoreload 2"
+ "from circuitsvis.topk_samples import topk_samples"
]
},
{
@@ -66,62 +73,62 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 12,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 2,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -4978,62 +4981,62 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 13,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 3,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -9899,62 +9898,62 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 14,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 4,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -14807,62 +14802,62 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 15,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 5,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -19722,62 +19713,62 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 16,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 6,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -24631,62 +24618,62 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 17,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 7,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -29549,62 +29532,62 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 18,
- "metadata": {
- "text/html": {
- "Content-Type": "text/html"
- }
- },
+ "execution_count": 8,
+ "metadata": {},
"output_type": "execute_result"
}
],
@@ -34474,9 +34453,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv",
+ "display_name": "circuitsvis-env",
"language": "python",
- "name": "python3"
+ "name": "circuitsvis-env"
},
"language_info": {
"codemirror_mode": {
@@ -34488,7 +34467,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.10.8"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee"
+ }
}
},
"nbformat": 4,
diff --git a/python/circuitsvis/utils/render.py b/python/circuitsvis/utils/render.py
index ced00b6..547a9e6 100644
--- a/python/circuitsvis/utils/render.py
+++ b/python/circuitsvis/utils/render.py
@@ -3,7 +3,6 @@
import subprocess
import os
from pathlib import Path
-from typing import Dict, Tuple, Union
from urllib import request
from uuid import uuid4
@@ -43,29 +42,28 @@ def __init__(self, local_src: str, cdn_src: str):
self.local_src = local_src
self.cdn_src = cdn_src
- def _repr_html_(self) -> Tuple[str, Dict]:
+ def _repr_html_(self) -> str:
"""Jupyter/Colab HTML Representation
When Jupyter sees this method, it renders the HTML.
Returns:
- HTML for Jupyter/Colab
+ str: HTML for Jupyter/Colab
"""
- # Use local source if we're in dev mode, or offline. Otherwise use the CDN.
- src: str
- if is_in_dev_mode() or not internet_on():
- src = self.local_src
- else:
- src = self.cdn_src
+ # Use local source if we're in dev mode
+ if is_in_dev_mode():
+ return self.local_src
- # Return html
- mime = {"Content-Type": "text/html"}
- return src, mime
+ # Use local source if we're offline
+ if not internet_on():
+ return self.local_src
+
+ # Otherwise use the CDN
+ return self.cdn_src
def __html__(self) -> str:
"""Used by some tooling as an alternative to _repr_html_"""
- # Just return the source code (not the MIME data), as some tools may not support this.
- return self._repr_html_()[0]
+ return self._repr_html_()
def show_code(self) -> str:
"""Show the code as HTML source code
diff --git a/python/circuitsvis/utils/tests/test_render.py b/python/circuitsvis/utils/tests/test_render.py
index afe2971..f199292 100644
--- a/python/circuitsvis/utils/tests/test_render.py
+++ b/python/circuitsvis/utils/tests/test_render.py
@@ -37,7 +37,7 @@ def test_jupyter_renders(self):
html = RenderedHTML(src, src)
# Check the _repr_html_ method is defined (as Jupyter Lab displays this)
- assert html._repr_html_()[0] == src
+ assert html._repr_html_() == src
def test_show_code(self):
src = "Hi
"
@@ -85,10 +85,10 @@ def test_stringified_render_is_from_cdn(self, monkeypatch):
res = render("Hello", name="Bob")
assert str(res) == str(prod)
- def test_jupyter_version_is_from_local(self, monkeypatch):
+ def test_jupyter_verson_is_from_local(self, monkeypatch):
monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
monkeypatch.setattr(circuitsvis, "__version__", "1.0.0")
dev = render_local("Hello", name="Bob")
res = render("Hello", name="Bob")
- assert res._repr_html_()[0] == dev
+ assert res._repr_html_() == dev