diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb index 5bf4cc3..52ac7c1 100644 --- a/python/Demonstration.ipynb +++ b/python/Demonstration.ipynb @@ -20,24 +20,31 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ + "# Enable python import reloading\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", "# Imports\n", "import numpy as np\n", "from circuitsvis.attention import attention_patterns, attention_pattern\n", "from circuitsvis.activations import text_neuron_activations\n", + "from circuitsvis.examples import hello\n", "from circuitsvis.tokens import colored_tokens\n", "from circuitsvis.topk_tokens import topk_tokens\n", - "from circuitsvis.topk_samples import topk_samples\n", - "\n", - "# Hot reload (for development only)\n", - "from IPython import get_ipython\n", - "ip = get_ipython()\n", - "if not ip.extension_manager.loaded:\n", - " ip.extension_manager.load('autoreload')\n", - " %autoreload 2" + "from circuitsvis.topk_samples import topk_samples" ] }, { @@ -66,62 +73,62 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 12, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 2, + "metadata": {}, "output_type": "execute_result" } ], @@ -4978,62 +4981,62 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 13, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 3, + "metadata": {}, "output_type": "execute_result" } ], @@ -9899,62 +9898,62 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 14, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 4, + "metadata": {}, "output_type": "execute_result" } ], @@ -14807,62 +14802,62 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 15, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 5, + "metadata": {}, "output_type": "execute_result" } ], @@ -19722,62 +19713,62 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 16, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 6, + "metadata": {}, "output_type": "execute_result" } ], @@ -24631,62 +24618,62 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 17, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 7, + "metadata": {}, "output_type": "execute_result" } ], @@ -29549,62 +29532,62 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 18, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 8, + "metadata": {}, "output_type": "execute_result" } ], @@ -34474,9 +34453,9 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "circuitsvis-env", "language": "python", - "name": "python3" + "name": "circuitsvis-env" }, "language_info": { "codemirror_mode": { @@ -34488,7 +34467,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.8" + }, + "vscode": { + "interpreter": { + "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee" + } } }, "nbformat": 4, diff --git a/python/circuitsvis/utils/render.py b/python/circuitsvis/utils/render.py index ced00b6..547a9e6 100644 --- a/python/circuitsvis/utils/render.py +++ b/python/circuitsvis/utils/render.py @@ -3,7 +3,6 @@ import subprocess import os from pathlib import Path -from typing import Dict, Tuple, Union from urllib import request from uuid import uuid4 @@ -43,29 +42,28 @@ def __init__(self, local_src: str, cdn_src: str): self.local_src = local_src self.cdn_src = cdn_src - def _repr_html_(self) -> Tuple[str, Dict]: + def _repr_html_(self) -> str: """Jupyter/Colab HTML Representation When Jupyter sees this method, it renders the HTML. Returns: - HTML for Jupyter/Colab + str: HTML for Jupyter/Colab """ - # Use local source if we're in dev mode, or offline. Otherwise use the CDN. - src: str - if is_in_dev_mode() or not internet_on(): - src = self.local_src - else: - src = self.cdn_src + # Use local source if we're in dev mode + if is_in_dev_mode(): + return self.local_src - # Return html - mime = {"Content-Type": "text/html"} - return src, mime + # Use local source if we're offline + if not internet_on(): + return self.local_src + + # Otherwise use the CDN + return self.cdn_src def __html__(self) -> str: """Used by some tooling as an alternative to _repr_html_""" - # Just return the source code (not the MIME data), as some tools may not support this. - return self._repr_html_()[0] + return self._repr_html_() def show_code(self) -> str: """Show the code as HTML source code diff --git a/python/circuitsvis/utils/tests/test_render.py b/python/circuitsvis/utils/tests/test_render.py index afe2971..f199292 100644 --- a/python/circuitsvis/utils/tests/test_render.py +++ b/python/circuitsvis/utils/tests/test_render.py @@ -37,7 +37,7 @@ def test_jupyter_renders(self): html = RenderedHTML(src, src) # Check the _repr_html_ method is defined (as Jupyter Lab displays this) - assert html._repr_html_()[0] == src + assert html._repr_html_() == src def test_show_code(self): src = "

Hi

" @@ -85,10 +85,10 @@ def test_stringified_render_is_from_cdn(self, monkeypatch): res = render("Hello", name="Bob") assert str(res) == str(prod) - def test_jupyter_version_is_from_local(self, monkeypatch): + def test_jupyter_verson_is_from_local(self, monkeypatch): monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") monkeypatch.setattr(circuitsvis, "__version__", "1.0.0") dev = render_local("Hello", name="Bob") res = render("Hello", name="Bob") - assert res._repr_html_()[0] == dev + assert res._repr_html_() == dev