Skip to content

Commit

Permalink
Add rich display for python objects in ScalaPy
Browse files Browse the repository at this point in the history
  • Loading branch information
kiendang authored and alexarchambault committed Aug 18, 2022
1 parent e97b363 commit 161cf78
Show file tree
Hide file tree
Showing 7 changed files with 346 additions and 5 deletions.
19 changes: 19 additions & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,15 @@ class AlmondSpark(val crossScalaVersion: String) extends AlmondModule with Mima
// sources.in(Compile, doc) := Nil
}

class AlmondScalaPy(val crossScalaVersion: String) extends AlmondModule with Mima {
def ivyDeps = Agg(
Deps.jvmRepr
)
def compileIvyDeps = Agg(
Deps.scalapy
)
}

class AlmondRx(val crossScalaVersion: String) extends AlmondModule with Mima {
def compileModuleDeps = Seq(
scala.`scala-kernel-api`()
Expand Down Expand Up @@ -290,6 +299,7 @@ object scala extends Module {
object `scala-interpreter` extends Cross[ScalaInterpreter](ScalaVersions.all: _*)
object `scala-kernel` extends Cross[ScalaKernel] (ScalaVersions.all: _*)
object `scala-kernel-helper` extends Cross[ScalaKernelHelper](ScalaVersions.all.filter(_.startsWith("3.")): _*)
object `almond-scalapy` extends Cross[AlmondScalaPy] (ScalaVersions.binaries: _*)
object `almond-spark` extends Cross[AlmondSpark] (ScalaVersions.scala212)
object `almond-rx` extends Cross[AlmondRx] (ScalaVersions.scala212)
}
Expand Down Expand Up @@ -480,12 +490,21 @@ def validateExamples(matcher: String = "") = {
Some(m)
}

val sv0 = {
val prefix = sv.split('.').take(2).map(_ + ".").mkString
ScalaVersions.binaries.find(_.startsWith(prefix)).getOrElse {
sys.error(s"Can't find a Scala version in ${ScalaVersions.binaries} with the same binary version as $sv (prefix: $prefix)")
}
}

T.command {
val launcher = scala.`scala-kernel`(sv).launcher().path
val jupyterPath = T.dest / "jupyter"
val outputDir = T.dest / "output"
os.makeDir.all(outputDir)

scala.`almond-scalapy`(sv0).publishLocalNoFluff((baseRepoRoot / "{VERSION}").toString)()

val version = scala.`scala-kernel`(sv).publishVersion()
val repoRoot = baseRepoRoot / version

Expand Down
248 changes: 248 additions & 0 deletions examples/scalapy-displays.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {

},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mai.kien.python.Python\n",
"\n",
"\u001b[39m"
]
},
"execution_count": 1,
"metadata": {

},
"output_type": "execute_result"
}
],
"source": [
"import $ivy.`ai.kien::python-native-libs:0.2.1`\n",
"import ai.kien.python.Python\n",
"\n",
"Python().scalapyProperties.fold(\n",
" ex => throw new Exception(ex),\n",
" props => props.map { kv => println(kv); kv }.foreach(Function.tupled(System.setProperty _))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {

},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mme.shadaj.scalapy.py\n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mme.shadaj.scalapy.py.PyQuote\n",
"\u001b[39m\n",
"\u001b[32mimport \u001b[39m\u001b[36mme.shadaj.scalapy.py.SeqConverters\u001b[39m"
]
},
"execution_count": 2,
"metadata": {

},
"output_type": "execute_result"
}
],
"source": [
"import $ivy.`me.shadaj::scalapy-core:0.5.2`\n",
"import me.shadaj.scalapy.py\n",
"import me.shadaj.scalapy.py.PyQuote\n",
"import me.shadaj.scalapy.py.SeqConverters"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {

},
"outputs": [

],
"source": [
"almond.scalapy.initDisplay"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {

},
"outputs": [

],
"source": [
"// disable pprint so that the next line won't show any output\n",
"repl.pprinter() = repl.pprinter().copy(defaultHeight = 0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {

},
"outputs": [
{
"data": {
"text/plain": [
"......"
]
},
"execution_count": 5,
"metadata": {

},
"output_type": "execute_result"
}
],
"source": [
"val display = py.module(\"IPython.display\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {

},
"outputs": [
{
"data": {
"text/html": [
"<b>hello</b>"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.HTML(\"<b>hello</b>\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {

},
"outputs": [
{
"data": {
"text/latex": [
"\\begin{eqnarray}\n",
"\\nabla \\times \\vec{\\mathbf{B}} -\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{E}}}{\\partial t} & = \\frac{4\\pi}{c}\\vec{\\mathbf{j}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{E}} & = 4 \\pi \\rho \\\\\n",
"\\nabla \\times \\vec{\\mathbf{E}}\\, +\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{B}}}{\\partial t} & = \\vec{\\mathbf{0}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{B}} & = 0 \n",
"\\end{eqnarray}"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.Latex(\"\"\"\\begin{eqnarray}\n",
"\\nabla \\times \\vec{\\mathbf{B}} -\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{E}}}{\\partial t} & = \\frac{4\\pi}{c}\\vec{\\mathbf{j}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{E}} & = 4 \\pi \\rho \\\\\n",
"\\nabla \\times \\vec{\\mathbf{E}}\\, +\\, \\frac1c\\, \\frac{\\partial\\vec{\\mathbf{B}}}{\\partial t} & = \\vec{\\mathbf{0}} \\\\\n",
"\\nabla \\cdot \\vec{\\mathbf{B}} & = 0 \n",
"\\end{eqnarray}\"\"\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {

},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle a = b + c$"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.Math(\"a = b + c\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {

},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"# title\n",
"## subsec\n",
"foo\n"
]
},
"metadata": {

},
"output_type": "display_data"
}
],
"source": [
"display.Markdown(\"\"\"\n",
"# title\n",
"## subsec\n",
"foo\n",
"\"\"\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scala (sources)",
"language": "scala",
"name": "scala-debug"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".sc",
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import json as __almond_scalapy_json


def __almond_scalapy_format_display_data(obj, include):
repr_methods = ((t, m) for m, t in include if m in set(dir(obj)))
representations = ((t, getattr(obj, m)()) for t, m in repr_methods)

display_data = (
(t, (r[0], r[1]) if isinstance(r, tuple) and len(r) == 2 else (r, None))
for t, r in representations if r is not None
)
display_data = [(t, m, md) for t, (m, md) in display_data if m is not None]

data = [
(t, d if isinstance(d, str) else __almond_scalapy_json.dumps(d))
for t, d, _ in display_data
]
metadata = [
(t, md if isinstance(md, str) else __almond_scalapy_json.dumps(md))
for t, _, md in display_data if md is not None
]

return data, metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package almond

import java.{util => ju}
import jupyter.{Displayer, Displayers}
import me.shadaj.scalapy.interpreter.CPythonInterpreter
import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.{PyQuote, SeqConverters}
import scala.io.Source
import scala.jdk.CollectionConverters._

package object scalapy {
CPythonInterpreter.execManyLines(Source.fromResource("format_display_data.py").mkString)

def initDisplay: Unit = {
Displayers.register(
classOf[py.Any],
new Displayer[py.Any] {
def display(obj: py.Any): ju.Map[String, String] = {
val (data, _) = formatDisplayData(obj)
if (data.isEmpty) null else data.asJava
}
}
)
}

private val pyFormatDisplayData = py.Dynamic.global.__almond_scalapy_format_display_data

private def formatDisplayData(obj: py.Any): (Map[String, String], Map[String, String]) = {
val displayData = pyFormatDisplayData(obj, allReprMethods.toPythonCopy)
val data = displayData.bracketAccess(0).as[List[(String, String)]].toMap
val metadata = displayData.bracketAccess(1).as[List[(String, String)]].toMap

(data, metadata)
}

private val mimetypes = Map(
"svg" -> "image/svg+xml",
"png" -> "image/png",
"jpeg" -> "image/jpeg",
"html" -> "text/html",
"javascript" -> "application/javascript",
"markdown" -> "text/markdown",
"latex" -> "text/latex"
)

private lazy val allReprMethods: Seq[(String, String)] =
mimetypes.map { case (k, v) => s"_repr_${k}_" -> v }.toSeq
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ final class ReplApiImpl(
.asInstanceOf[Displayer[T]]
.display(value)
.asScala
.toMap
p.display(DisplayData(m))
Some(Iterator())
if (m == null) None
else {
p.display(DisplayData(m.toMap))
Some(Iterator())
}
} else
for (updatableResults <- updatableResultsOpt if (onChange.nonEmpty && custom.isEmpty) || (onChangeOrError.nonEmpty && custom.nonEmpty)) yield {

Expand Down Expand Up @@ -204,4 +206,3 @@ final class ReplApiImpl(
object ReplApiImpl {
private class Foo
}

Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ final case class Options(
val default =
if (defaultAutoDependencies)
Map(
Module.of("org.apache.spark", "*") -> Seq(Dependency.of(Module.of("sh.almond", s"almond-spark_$sbv"), Properties.version))
Module.of("org.apache.spark", "*") -> Seq(Dependency.of(Module.of("sh.almond", s"almond-spark_$sbv"), Properties.version)),
Module.of("me.shadaj", "scalapy*") -> Seq(Dependency.of(Module.of("sh.almond", s"almond-scalapy_$sbv"), Properties.version))
)
else
Map.empty[Module, Seq[Dependency]]
Expand Down
Loading

0 comments on commit 161cf78

Please sign in to comment.