diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb index 52ac7c1..5bf4cc3 100644 --- a/python/Demonstration.ipynb +++ b/python/Demonstration.ipynb @@ -20,31 +20,24 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ - "# Enable python import reloading\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", "# Imports\n", "import numpy as np\n", "from circuitsvis.attention import attention_patterns, attention_pattern\n", "from circuitsvis.activations import text_neuron_activations\n", - "from circuitsvis.examples import hello\n", "from circuitsvis.tokens import colored_tokens\n", "from circuitsvis.topk_tokens import topk_tokens\n", - "from circuitsvis.topk_samples import topk_samples" + "from circuitsvis.topk_samples import topk_samples\n", + "\n", + "# Hot reload (for development only)\n", + "from IPython import get_ipython\n", + "ip = get_ipython()\n", + "if not ip.extension_manager.loaded:\n", + " ip.extension_manager.load('autoreload')\n", + " %autoreload 2" ] }, { @@ -73,62 +66,62 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, - "metadata": {}, + "execution_count": 12, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -4981,62 +4978,62 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, - "metadata": {}, + "execution_count": 13, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -9898,62 +9899,62 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 4, - "metadata": {}, + "execution_count": 14, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -14802,62 +14807,62 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 5, - "metadata": {}, + "execution_count": 15, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -19713,62 +19722,62 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 6, - "metadata": {}, + "execution_count": 16, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -24618,62 +24631,62 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 7, - "metadata": {}, + "execution_count": 17, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -29532,62 +29549,62 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 8, - "metadata": {}, + "execution_count": 18, + "metadata": { + "text/html": { + "Content-Type": "text/html" + } + }, "output_type": "execute_result" } ], @@ -34453,9 +34474,9 @@ ], "metadata": { "kernelspec": { - "display_name": "circuitsvis-env", + "display_name": ".venv", "language": "python", - "name": "circuitsvis-env" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -34467,12 +34488,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "vscode": { - "interpreter": { - "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee" - } + "version": "3.10.12" } }, "nbformat": 4, diff --git a/python/circuitsvis/utils/render.py b/python/circuitsvis/utils/render.py index 547a9e6..ced00b6 100644 --- a/python/circuitsvis/utils/render.py +++ b/python/circuitsvis/utils/render.py @@ -3,6 +3,7 @@ import subprocess import os from pathlib import Path +from typing import Dict, Tuple, Union from urllib import request from uuid import uuid4 @@ -42,28 +43,29 @@ def __init__(self, local_src: str, cdn_src: str): self.local_src = local_src self.cdn_src = cdn_src - def _repr_html_(self) -> str: + def _repr_html_(self) -> Tuple[str, Dict]: """Jupyter/Colab HTML Representation When Jupyter sees this method, it renders the HTML. Returns: - str: HTML for Jupyter/Colab + HTML for Jupyter/Colab """ - # Use local source if we're in dev mode - if is_in_dev_mode(): - return self.local_src + # Use local source if we're in dev mode, or offline. Otherwise use the CDN. + src: str + if is_in_dev_mode() or not internet_on(): + src = self.local_src + else: + src = self.cdn_src - # Use local source if we're offline - if not internet_on(): - return self.local_src - - # Otherwise use the CDN - return self.cdn_src + # Return html + mime = {"Content-Type": "text/html"} + return src, mime def __html__(self) -> str: """Used by some tooling as an alternative to _repr_html_""" - return self._repr_html_() + # Just return the source code (not the MIME data), as some tools may not support this. + return self._repr_html_()[0] def show_code(self) -> str: """Show the code as HTML source code diff --git a/python/circuitsvis/utils/tests/test_render.py b/python/circuitsvis/utils/tests/test_render.py index f199292..afe2971 100644 --- a/python/circuitsvis/utils/tests/test_render.py +++ b/python/circuitsvis/utils/tests/test_render.py @@ -37,7 +37,7 @@ def test_jupyter_renders(self): html = RenderedHTML(src, src) # Check the _repr_html_ method is defined (as Jupyter Lab displays this) - assert html._repr_html_() == src + assert html._repr_html_()[0] == src def test_show_code(self): src = "

Hi

" @@ -85,10 +85,10 @@ def test_stringified_render_is_from_cdn(self, monkeypatch): res = render("Hello", name="Bob") assert str(res) == str(prod) - def test_jupyter_verson_is_from_local(self, monkeypatch): + def test_jupyter_version_is_from_local(self, monkeypatch): monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") monkeypatch.setattr(circuitsvis, "__version__", "1.0.0") dev = render_local("Hello", name="Bob") res = render("Hello", name="Bob") - assert res._repr_html_() == dev + assert res._repr_html_()[0] == dev