[KYUUBI #5877] Support Python magic syntax for notebook usage
# 🔍 Description
## Issue References 🔗

Support Python magic syntax, for example:
```
%table
%json
%matplot
```
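
For example, in a notebook cell one might write something like the following (a rough sketch; `df` is a hypothetical variable created earlier in the Python session):

```
df = spark.sql("SELECT 1 AS id, 'kyuubi' AS name").collect()
%table df
```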

References:
https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-magics.html
https://github.com/jupyter-incubator/sparkmagic
https://github.com/apache/incubator-livy/blob/master/repl/src/main/resources/fake_shell.py
This pull request fixes #5877

## Describe Your Solution 🔧

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

## Types of changes 🔖

- [ ] Bugfix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

## Test Plan 🧪

Tested with the following Python code:
```
import matplotlib.pyplot as plt
plt.plot([3,4,5],[6,7,8])
%matplot plt;
```
<img width="1723" alt="image" src="https://github.com/apache/kyuubi/assets/6757692/9a1176c0-8eb0-4a64-83e4-35e74e33d2f0">

Decode the "image/png" payload and save it as a PNG:
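
A minimal sketch of that decoding step (assuming the base64 string from the reply has been copied into `encoded`):

```
import base64

encoded = "..."  # base64 string taken from the "image/png" field of the reply
with open("matplot.png", "wb") as f:
    f.write(base64.b64decode(encoded))
```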

![matplot](https://github.com/apache/kyuubi/assets/6757692/9139f9d3-7822-43b0-8959-261ed8e79d22)

#### Behavior Without This Pull Request ⚰️

#### Behavior With This Pull Request 🎉

#### Related Unit Tests

---

# Checklists
## 📝 Author Self Checklist

- [ ] My code follows the [style guidelines](https://kyuubi.readthedocs.io/en/master/contributing/code/style.html) of this project
- [ ] I have performed a self-review
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)

## 📝 Committer Pre-Merge Checklist

- [ ] Pull request title is okay.
- [ ] No license issues.
- [ ] Milestone correctly set?
- [ ] Test coverage is ok
- [ ] Assignees are selected.
- [ ] Minimum number of approvals
- [ ] No changes are requested

**Be nice. Be informative.**

Closes #5881 from turboFei/magic_command.

Closes #5877

6f2b193 [Fei Wang] ut
877c7d1 [Fei Wang] internal config
012dfe4 [Fei Wang] nit
3e0f324 [Fei Wang] except other exceptions
24352d2 [Fei Wang] raise execution error
0853161 [Fei Wang] raise ExecutionError instead of execute_reply_error
c058def [Fei Wang] add more ut
4da5215 [Fei Wang] Dumps python object to json at last
3512753 [Fei Wang] add ut for json and table
48735eb [Fei Wang] the data should be Map[String, Object]
3a3ba0a [Fei Wang] return other data fields
54d6800 [Fei Wang] reformat
87ded6e [Fei Wang] add config to disable
44f88ef [Fei Wang] add magic node back

Authored-by: Fei Wang <[email protected]>
Signed-off-by: Fei Wang <[email protected]>
turboFei committed Dec 21, 2023
1 parent 37620d4 commit 50910ae
Showing 4 changed files with 312 additions and 12 deletions.
@@ -16,13 +16,16 @@
#

import ast
import datetime
import decimal
import io
import json

import os
import re
import sys
import traceback
import base64
from glob import glob

if sys.version_info[0] < 3:
@@ -70,6 +73,8 @@

global_dict = {}

MAGIC_ENABLED = os.environ.get("MAGIC_ENABLED") == "true"


class NormalNode(object):
def __init__(self, code):
@@ -94,6 +99,36 @@ def execute(self):
raise ExecutionError(sys.exc_info())


class UnknownMagic(Exception):
pass


class MagicNode(object):
def __init__(self, line):
parts = line[1:].split(" ", 1)
if len(parts) == 1:
self.magic, self.rest = parts[0], ()
else:
self.magic, self.rest = parts[0], (parts[1],)

def execute(self):
if not self.magic:
raise UnknownMagic("magic command not specified")

try:
handler = magic_router[self.magic]
except KeyError:
raise UnknownMagic("unknown magic command '%s'" % self.magic)

try:
return handler(*self.rest)
except ExecutionError as e:
raise e
except Exception:
exc_type, exc_value, tb = sys.exc_info()
raise ExecutionError((exc_type, exc_value, None))


class ExecutionError(Exception):
def __init__(self, exc_info):
self.exc_info = exc_info
@@ -118,6 +153,14 @@ def parse_code_into_nodes(code):
try:
nodes.append(NormalNode(code))
except SyntaxError:
# It's possible we hit a syntax error because of a magic command. Split the code into
# chunks of 'normal code' and lines that start with a '%' (possibly magic commands):
# remove lines until we find a chunk that parses, then check whether the next line is
# a magic line.

# Split the code into chunks of normal code, and possibly magic code, which starts with
# a '%'.

normal = []
chunks = []
for i, line in enumerate(code.rstrip().split("\n")):
@@ -135,24 +178,22 @@ def parse_code_into_nodes(code):

# Convert the chunks into AST nodes. Let exceptions propagate.
for chunk in chunks:
# TODO: look back here when Jupyter and sparkmagic are supported
# if chunk.startswith('%'):
# nodes.append(MagicNode(chunk))

nodes.append(NormalNode(chunk))
if MAGIC_ENABLED and chunk.startswith("%"):
nodes.append(MagicNode(chunk))
else:
nodes.append(NormalNode(chunk))

return nodes


def execute_reply(status, content):
msg = {
return {
"msg_type": "execute_reply",
"content": dict(
content,
status=status,
),
}
return json.dumps(msg)


def execute_reply_ok(data):
@@ -211,6 +252,9 @@ def execute_request(content):
try:
for node in nodes:
result = node.execute()
except UnknownMagic:
exc_type, exc_value, tb = sys.exc_info()
return execute_reply_error(exc_type, exc_value, None)
except ExecutionError as e:
return execute_reply_error(*e.exc_info)

@@ -239,6 +283,171 @@ def execute_request(content):
return execute_reply_ok(result)


def magic_table_convert(value):
try:
converter = magic_table_types[type(value)]
except KeyError:
converter = magic_table_types[str]

return converter(value)


def magic_table_convert_seq(items):
last_item_type = None
converted_items = []

for item in items:
item_type, item = magic_table_convert(item)

if last_item_type is None:
last_item_type = item_type
elif last_item_type != item_type:
raise ValueError("value has inconsistent types")

converted_items.append(item)

return "ARRAY_TYPE", converted_items


def magic_table_convert_map(m):
last_key_type = None
last_value_type = None
converted_items = {}

for key, value in m.items():
key_type, key = magic_table_convert(key)
value_type, value = magic_table_convert(value)

if last_key_type is None:
last_key_type = key_type
elif last_key_type != key_type:
raise ValueError("value has inconsistent types")

if last_value_type is None:
last_value_type = value_type
elif last_value_type != value_type:
raise ValueError("value has inconsistent types")

converted_items[key] = value

return "MAP_TYPE", converted_items


magic_table_types = {
type(None): lambda x: ("NULL_TYPE", x),
bool: lambda x: ("BOOLEAN_TYPE", x),
int: lambda x: ("INT_TYPE", x),
float: lambda x: ("DOUBLE_TYPE", x),
str: lambda x: ("STRING_TYPE", str(x)),
datetime.date: lambda x: ("DATE_TYPE", str(x)),
datetime.datetime: lambda x: ("TIMESTAMP_TYPE", str(x)),
decimal.Decimal: lambda x: ("DECIMAL_TYPE", str(x)),
tuple: magic_table_convert_seq,
list: magic_table_convert_seq,
dict: magic_table_convert_map,
}


def magic_table(name):
try:
value = global_dict[name]
except KeyError:
exc_type, exc_value, tb = sys.exc_info()
raise ExecutionError((exc_type, exc_value, None))

if not isinstance(value, (list, tuple)):
value = [value]

headers = {}
data = []

for row in value:
cols = []
data.append(cols)

if "Row" == row.__class__.__name__:
row = row.asDict()

if not isinstance(row, (list, tuple, dict)):
row = [row]

if isinstance(row, (list, tuple)):
iterator = enumerate(row)
else:
iterator = sorted(row.items())

for name, col in iterator:
col_type, col = magic_table_convert(col)

try:
header = headers[name]
except KeyError:
header = {
"name": str(name),
"type": col_type,
}
headers[name] = header
else:
# Reject columns that have a different type (None values are allowed).
if col_type != "NULL_TYPE" and header["type"] != col_type:
if header["type"] == "NULL_TYPE":
header["type"] = col_type
else:
exc_type = Exception
exc_value = Exception("table rows have different types")
raise ExecutionError((exc_type, exc_value, None))

cols.append(col)

headers = [v for k, v in sorted(headers.items())]

return {
"application/vnd.livy.table.v1+json": {
"headers": headers,
"data": data,
}
}


def magic_json(name):
try:
value = global_dict[name]
except KeyError:
exc_type, exc_value, tb = sys.exc_info()
raise ExecutionError((exc_type, exc_value, None))

return {
"application/json": value,
}


def magic_matplot(name):
try:
value = global_dict[name]
fig = value.gcf()
imgdata = io.BytesIO()
fig.savefig(imgdata, format="png")
imgdata.seek(0)
encode = base64.b64encode(imgdata.getvalue())
if sys.version >= "3":
encode = encode.decode()

except:
exc_type, exc_value, tb = sys.exc_info()
raise ExecutionError((exc_type, exc_value, None))

return {
"image/png": encode,
}


magic_router = {
"table": magic_table,
"json": magic_json,
"matplot": magic_matplot,
}
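
For context, a rough sketch (not taken from the patch) of how a magic line flows through the handlers above, assuming a list named `df` exists in `global_dict`:

```
node = MagicNode("%table df")  # parsed as magic="table", rest=("df",)
result = node.execute()        # looks up magic_router["table"] and calls magic_table("df")
# result is a dict like:
# {"application/vnd.livy.table.v1+json": {"headers": [...], "data": [...]}}
```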


# get or create spark session
spark_session = kyuubi_util.get_spark_session(
os.environ.get("KYUUBI_SPARK_SESSION_UUID")
@@ -278,6 +487,22 @@ def main():
break

result = execute_request(content)

try:
result = json.dumps(result)
except ValueError:
result = json.dumps(
{
"msg_type": "inspect_reply",
"content": {
"status": "error",
"ename": "ValueError",
"evalue": "cannot json-ify %s" % response,
"traceback": [],
},
}
)

print(result, file=sys_stdout)
sys_stdout.flush()
clearOutputs()

---

@@ -37,7 +37,7 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.{KyuubiSQLException, Logging, Utils}
import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE}
import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE, ENGINE_SPARK_PYTHON_MAGIC_ENABLED}
import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY}
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
import org.apache.kyuubi.operation.{ArrayFetchIterator, OperationHandle, OperationState}
@@ -233,6 +233,7 @@ object ExecutePython extends Logging {
final val PY4J_REGEX = "py4j-[\\S]*.zip$".r
final val PY4J_PATH = "PY4J_PATH"
final val IS_PYTHON_APP_KEY = "spark.yarn.isPython"
final val MAGIC_ENABLED = "MAGIC_ENABLED"

private val isPythonGatewayStart = new AtomicBoolean(false)
private val kyuubiPythonPath = Utils.createTempDir()
@@ -280,6 +281,7 @@ object ExecutePython extends Logging {
}
env.put("KYUUBI_SPARK_SESSION_UUID", sessionId)
env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
env.put(MAGIC_ENABLED, getSessionConf(ENGINE_SPARK_PYTHON_MAGIC_ENABLED, spark).toString)
logger.info(
s"""
|launch python worker command: ${builder.command().asScala.mkString(" ")}
@@ -409,15 +411,24 @@ object PythonResponse {
}

case class PythonResponseContent(
data: Map[String, String],
data: Map[String, Object],
ename: String,
evalue: String,
traceback: Seq[String],
status: String) {
def getOutput(): String = {
Option(data)
.map(_.getOrElse("text/plain", ""))
.getOrElse("")
if (data == null) return ""

// If data contains no field other than `text/plain`, keep backward compatibility and
// return just that value; otherwise, return all the data as JSON.
if (data.filterNot(_._1 == "text/plain").isEmpty) {
data.get("text/plain").map {
case str: String => str
case obj => ExecutePython.toJson(obj)
}.getOrElse("")
} else {
ExecutePython.toJson(data)
}
}
def getEname(): String = {
Option(ename).getOrElse("")

---

@@ -3246,6 +3246,15 @@ object KyuubiConf {
.stringConf
.createWithDefault("bin/python")

val ENGINE_SPARK_PYTHON_MAGIC_ENABLED: ConfigEntry[Boolean] =
buildConf("kyuubi.engine.spark.python.magic.enabled")
.internal
.doc("Whether to enable pyspark magic node, which is helpful for notebook." +
" See details in KYUUBI #5877")
.version("1.9.0")
.booleanConf
.createWithDefault(true)
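
Since the entry defaults to true, the magic handling can presumably be turned off per session, e.g.:

```
kyuubi.engine.spark.python.magic.enabled=false
```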

val ENGINE_SPARK_REGISTER_ATTRIBUTES: ConfigEntry[Seq[String]] =
buildConf("kyuubi.engine.spark.register.attributes")
.internal
