diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb
index 52ac7c1..5bf4cc3 100644
--- a/python/Demonstration.ipynb
+++ b/python/Demonstration.ipynb
@@ -20,31 +20,24 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 19,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "# Enable python import reloading\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
"# Imports\n",
"import numpy as np\n",
"from circuitsvis.attention import attention_patterns, attention_pattern\n",
"from circuitsvis.activations import text_neuron_activations\n",
- "from circuitsvis.examples import hello\n",
"from circuitsvis.tokens import colored_tokens\n",
"from circuitsvis.topk_tokens import topk_tokens\n",
- "from circuitsvis.topk_samples import topk_samples"
+ "from circuitsvis.topk_samples import topk_samples\n",
+ "\n",
+ "# Hot reload (for development only)\n",
+ "from IPython import get_ipython\n",
+ "ip = get_ipython()\n",
+ "if not ip.extension_manager.loaded:\n",
+ " ip.extension_manager.load('autoreload')\n",
+ " %autoreload 2"
]
},
{
@@ -73,62 +66,62 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 2,
- "metadata": {},
+ "execution_count": 12,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -4981,62 +4978,62 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 3,
- "metadata": {},
+ "execution_count": 13,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -9898,62 +9899,62 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 4,
- "metadata": {},
+ "execution_count": 14,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -14802,62 +14807,62 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 5,
- "metadata": {},
+ "execution_count": 15,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -19713,62 +19722,62 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 6,
- "metadata": {},
+ "execution_count": 16,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -24618,62 +24631,62 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 7,
- "metadata": {},
+ "execution_count": 17,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -29532,62 +29549,62 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "\n",
+ "\n",
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 8,
- "metadata": {},
+ "execution_count": 18,
+ "metadata": {
+ "text/html": {
+ "Content-Type": "text/html"
+ }
+ },
"output_type": "execute_result"
}
],
@@ -34453,9 +34474,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "circuitsvis-env",
+ "display_name": ".venv",
"language": "python",
- "name": "circuitsvis-env"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -34467,12 +34488,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.8"
- },
- "vscode": {
- "interpreter": {
- "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee"
- }
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/python/circuitsvis/utils/render.py b/python/circuitsvis/utils/render.py
index 547a9e6..ced00b6 100644
--- a/python/circuitsvis/utils/render.py
+++ b/python/circuitsvis/utils/render.py
@@ -3,6 +3,7 @@
import subprocess
import os
from pathlib import Path
+from typing import Dict, Tuple, Union
from urllib import request
from uuid import uuid4
@@ -42,28 +43,29 @@ def __init__(self, local_src: str, cdn_src: str):
self.local_src = local_src
self.cdn_src = cdn_src
- def _repr_html_(self) -> str:
+ def _repr_html_(self) -> Tuple[str, Dict]:
"""Jupyter/Colab HTML Representation
When Jupyter sees this method, it renders the HTML.
Returns:
- str: HTML for Jupyter/Colab
+ HTML for Jupyter/Colab
"""
- # Use local source if we're in dev mode
- if is_in_dev_mode():
- return self.local_src
+ # Use local source if we're in dev mode, or offline. Otherwise use the CDN.
+ src: str
+ if is_in_dev_mode() or not internet_on():
+ src = self.local_src
+ else:
+ src = self.cdn_src
- # Use local source if we're offline
- if not internet_on():
- return self.local_src
-
- # Otherwise use the CDN
- return self.cdn_src
+ # Return html
+ mime = {"Content-Type": "text/html"}
+ return src, mime
def __html__(self) -> str:
"""Used by some tooling as an alternative to _repr_html_"""
- return self._repr_html_()
+ # Just return the source code (not the MIME data), as some tools may not support this.
+ return self._repr_html_()[0]
def show_code(self) -> str:
"""Show the code as HTML source code
diff --git a/python/circuitsvis/utils/tests/test_render.py b/python/circuitsvis/utils/tests/test_render.py
index f199292..afe2971 100644
--- a/python/circuitsvis/utils/tests/test_render.py
+++ b/python/circuitsvis/utils/tests/test_render.py
@@ -37,7 +37,7 @@ def test_jupyter_renders(self):
html = RenderedHTML(src, src)
# Check the _repr_html_ method is defined (as Jupyter Lab displays this)
- assert html._repr_html_() == src
+ assert html._repr_html_()[0] == src
def test_show_code(self):
src = "Hi
"
@@ -85,10 +85,10 @@ def test_stringified_render_is_from_cdn(self, monkeypatch):
res = render("Hello", name="Bob")
assert str(res) == str(prod)
- def test_jupyter_verson_is_from_local(self, monkeypatch):
+ def test_jupyter_version_is_from_local(self, monkeypatch):
monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
monkeypatch.setattr(circuitsvis, "__version__", "1.0.0")
dev = render_local("Hello", name="Bob")
res = render("Hello", name="Bob")
- assert res._repr_html_() == dev
+ assert res._repr_html_()[0] == dev