diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py index e6fe7f92bcf..6729092f75d 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py +++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py @@ -16,6 +16,8 @@ # import ast +import datetime +import decimal import io import json @@ -23,6 +25,7 @@ import re import sys import traceback +import base64 from glob import glob if sys.version_info[0] < 3: @@ -70,6 +73,8 @@ global_dict = {} +MAGIC_ENABLED = os.environ.get("MAGIC_ENABLED") == "true" + class NormalNode(object): def __init__(self, code): @@ -94,6 +99,36 @@ def execute(self): raise ExecutionError(sys.exc_info()) +class UnknownMagic(Exception): + pass + + +class MagicNode(object): + def __init__(self, line): + parts = line[1:].split(" ", 1) + if len(parts) == 1: + self.magic, self.rest = parts[0], () + else: + self.magic, self.rest = parts[0], (parts[1],) + + def execute(self): + if not self.magic: + raise UnknownMagic("magic command not specified") + + try: + handler = magic_router[self.magic] + except KeyError: + raise UnknownMagic("unknown magic command '%s'" % self.magic) + + try: + return handler(*self.rest) + except ExecutionError as e: + raise e + except Exception: + exc_type, exc_value, tb = sys.exc_info() + raise ExecutionError((exc_type, exc_value, None)) + + class ExecutionError(Exception): def __init__(self, exc_info): self.exc_info = exc_info @@ -118,6 +153,14 @@ def parse_code_into_nodes(code): try: nodes.append(NormalNode(code)) except SyntaxError: + # It's possible we hit a syntax error because of a magic command. Split the code groups + # of 'normal code', and code that starts with a '%'. possibly magic code lines, and see + # if any of the lines. Remove lines until we find a node that parses, then check if the + # next line is a magic line. + + # Split the code into chunks of normal code, and possibly magic code, which starts with + # a '%'. + normal = [] chunks = [] for i, line in enumerate(code.rstrip().split("\n")): @@ -135,24 +178,22 @@ def parse_code_into_nodes(code): # Convert the chunks into AST nodes. Let exceptions propagate. for chunk in chunks: - # TODO: look back here when Jupyter and sparkmagic are supported - # if chunk.startswith('%'): - # nodes.append(MagicNode(chunk)) - - nodes.append(NormalNode(chunk)) + if MAGIC_ENABLED and chunk.startswith("%"): + nodes.append(MagicNode(chunk)) + else: + nodes.append(NormalNode(chunk)) return nodes def execute_reply(status, content): - msg = { + return { "msg_type": "execute_reply", "content": dict( content, status=status, ), } - return json.dumps(msg) def execute_reply_ok(data): @@ -211,6 +252,9 @@ def execute_request(content): try: for node in nodes: result = node.execute() + except UnknownMagic: + exc_type, exc_value, tb = sys.exc_info() + return execute_reply_error(exc_type, exc_value, None) except ExecutionError as e: return execute_reply_error(*e.exc_info) @@ -239,6 +283,171 @@ def execute_request(content): return execute_reply_ok(result) +def magic_table_convert(value): + try: + converter = magic_table_types[type(value)] + except KeyError: + converter = magic_table_types[str] + + return converter(value) + + +def magic_table_convert_seq(items): + last_item_type = None + converted_items = [] + + for item in items: + item_type, item = magic_table_convert(item) + + if last_item_type is None: + last_item_type = item_type + elif last_item_type != item_type: + raise ValueError("value has inconsistent types") + + converted_items.append(item) + + return "ARRAY_TYPE", converted_items + + +def magic_table_convert_map(m): + last_key_type = None + last_value_type = None + converted_items = {} + + for key, value in m: + key_type, key = magic_table_convert(key) + value_type, value = magic_table_convert(value) + + if last_key_type is None: + last_key_type = key_type + elif last_value_type != value_type: + raise ValueError("value has inconsistent types") + + if last_value_type is None: + last_value_type = value_type + elif last_value_type != value_type: + raise ValueError("value has inconsistent types") + + converted_items[key] = value + + return "MAP_TYPE", converted_items + + +magic_table_types = { + type(None): lambda x: ("NULL_TYPE", x), + bool: lambda x: ("BOOLEAN_TYPE", x), + int: lambda x: ("INT_TYPE", x), + float: lambda x: ("DOUBLE_TYPE", x), + str: lambda x: ("STRING_TYPE", str(x)), + datetime.date: lambda x: ("DATE_TYPE", str(x)), + datetime.datetime: lambda x: ("TIMESTAMP_TYPE", str(x)), + decimal.Decimal: lambda x: ("DECIMAL_TYPE", str(x)), + tuple: magic_table_convert_seq, + list: magic_table_convert_seq, + dict: magic_table_convert_map, +} + + +def magic_table(name): + try: + value = global_dict[name] + except KeyError: + exc_type, exc_value, tb = sys.exc_info() + raise ExecutionError((exc_type, exc_value, None)) + + if not isinstance(value, (list, tuple)): + value = [value] + + headers = {} + data = [] + + for row in value: + cols = [] + data.append(cols) + + if "Row" == row.__class__.__name__: + row = row.asDict() + + if not isinstance(row, (list, tuple, dict)): + row = [row] + + if isinstance(row, (list, tuple)): + iterator = enumerate(row) + else: + iterator = sorted(row.items()) + + for name, col in iterator: + col_type, col = magic_table_convert(col) + + try: + header = headers[name] + except KeyError: + header = { + "name": str(name), + "type": col_type, + } + headers[name] = header + else: + # Reject columns that have a different type. (allow none value) + if col_type != "NULL_TYPE" and header["type"] != col_type: + if header["type"] == "NULL_TYPE": + header["type"] = col_type + else: + exc_type = Exception + exc_value = Exception("table rows have different types") + raise ExecutionError((exc_type, exc_value, None)) + + cols.append(col) + + headers = [v for k, v in sorted(headers.items())] + + return { + "application/vnd.livy.table.v1+json": { + "headers": headers, + "data": data, + } + } + + +def magic_json(name): + try: + value = global_dict[name] + except KeyError: + exc_type, exc_value, tb = sys.exc_info() + raise ExecutionError((exc_type, exc_value, None)) + + return { + "application/json": value, + } + + +def magic_matplot(name): + try: + value = global_dict[name] + fig = value.gcf() + imgdata = io.BytesIO() + fig.savefig(imgdata, format="png") + imgdata.seek(0) + encode = base64.b64encode(imgdata.getvalue()) + if sys.version >= "3": + encode = encode.decode() + + except: + exc_type, exc_value, tb = sys.exc_info() + raise ExecutionError((exc_type, exc_value, None)) + + return { + "image/png": encode, + } + + +magic_router = { + "table": magic_table, + "json": magic_json, + "matplot": magic_matplot, +} + + # get or create spark session spark_session = kyuubi_util.get_spark_session( os.environ.get("KYUUBI_SPARK_SESSION_UUID") @@ -278,6 +487,22 @@ def main(): break result = execute_request(content) + + try: + result = json.dumps(result) + except ValueError: + result = json.dumps( + { + "msg_type": "inspect_reply", + "content": { + "status": "error", + "ename": "ValueError", + "evalue": "cannot json-ify %s" % response, + "traceback": [], + }, + } + ) + print(result, file=sys_stdout) sys_stdout.flush() clearOutputs() diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala index b3643a7ae43..d35c3fbd475 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.types.StructType import org.apache.kyuubi.{KyuubiSQLException, Logging, Utils} -import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE} +import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE, ENGINE_SPARK_PYTHON_MAGIC_ENABLED} import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY} import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._ import org.apache.kyuubi.operation.{ArrayFetchIterator, OperationHandle, OperationState} @@ -233,6 +233,7 @@ object ExecutePython extends Logging { final val PY4J_REGEX = "py4j-[\\S]*.zip$".r final val PY4J_PATH = "PY4J_PATH" final val IS_PYTHON_APP_KEY = "spark.yarn.isPython" + final val MAGIC_ENABLED = "MAGIC_ENABLED" private val isPythonGatewayStart = new AtomicBoolean(false) private val kyuubiPythonPath = Utils.createTempDir() @@ -280,6 +281,7 @@ object ExecutePython extends Logging { } env.put("KYUUBI_SPARK_SESSION_UUID", sessionId) env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH) + env.put(MAGIC_ENABLED, getSessionConf(ENGINE_SPARK_PYTHON_MAGIC_ENABLED, spark).toString) logger.info( s""" |launch python worker command: ${builder.command().asScala.mkString(" ")} @@ -409,15 +411,24 @@ object PythonResponse { } case class PythonResponseContent( - data: Map[String, String], + data: Map[String, Object], ename: String, evalue: String, traceback: Seq[String], status: String) { def getOutput(): String = { - Option(data) - .map(_.getOrElse("text/plain", "")) - .getOrElse("") + if (data == null) return "" + + // If data does not contains field other than `test/plain`, keep backward compatibility, + // otherwise, return all the data. + if (data.filterNot(_._1 == "text/plain").isEmpty) { + data.get("text/plain").map { + case str: String => str + case obj => ExecutePython.toJson(obj) + }.getOrElse("") + } else { + ExecutePython.toJson(data) + } } def getEname(): String = { Option(ename).getOrElse("") diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index 00c1b89956e..9dbc483b07c 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -3246,6 +3246,15 @@ object KyuubiConf { .stringConf .createWithDefault("bin/python") + val ENGINE_SPARK_PYTHON_MAGIC_ENABLED: ConfigEntry[Boolean] = + buildConf("kyuubi.engine.spark.python.magic.enabled") + .internal + .doc("Whether to enable pyspark magic node, which is helpful for notebook." + + " See details in KYUUBI #5877") + .version("1.9.0") + .booleanConf + .createWithDefault(true) + val ENGINE_SPARK_REGISTER_ATTRIBUTES: ConfigEntry[Seq[String]] = buildConf("kyuubi.engine.spark.register.attributes") .internal diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala index 16a7f728ea6..c723dcf4aa8 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala @@ -132,6 +132,61 @@ class PySparkTests extends WithKyuubiServer with HiveJDBCTestHelper { }) } + test("Support python magic syntax for python notebook") { + checkPythonRuntimeAndVersion() + withSessionConf()(Map(KyuubiConf.ENGINE_SPARK_PYTHON_MAGIC_ENABLED.key -> "true"))() { + withMultipleConnectionJdbcStatement()({ stmt => + val statement = stmt.asInstanceOf[KyuubiStatement] + statement.executePython("x = [[1, 'a'], [3, 'b']]") + + val resultSet1 = statement.executePython("%json x") + assert(resultSet1.next()) + val output1 = resultSet1.getString("output") + assert(output1 == "{\"application/json\":[[1,\"a\"],[3,\"b\"]]}") + + val resultSet2 = statement.executePython("%table x") + assert(resultSet2.next()) + val output2 = resultSet2.getString("output") + assert(output2 == "{\"application/vnd.livy.table.v1+json\":{" + + "\"headers\":[" + + "{\"name\":\"0\",\"type\":\"INT_TYPE\"},{\"name\":\"1\",\"type\":\"STRING_TYPE\"}" + + "]," + + "\"data\":[" + + "[1,\"a\"],[3,\"b\"]" + + "]}}") + + Seq("table", "json", "matplot").foreach { magic => + val e = intercept[KyuubiSQLException] { + statement.executePython(s"%$magic invalid_value") + }.getMessage + assert(e.contains("KeyError: 'invalid_value'")) + } + + statement.executePython("y = [[1, 2], [3, 'b']]") + var e = intercept[KyuubiSQLException] { + statement.executePython("%table y") + }.getMessage + assert(e.contains("table rows have different types")) + + e = intercept[KyuubiSQLException] { + statement.executePython("%magic_unknown") + }.getMessage + assert(e.contains("unknown magic command 'magic_unknown'")) + }) + } + + withSessionConf()(Map(KyuubiConf.ENGINE_SPARK_PYTHON_MAGIC_ENABLED.key -> "false"))() { + withMultipleConnectionJdbcStatement()({ stmt => + val statement = stmt.asInstanceOf[KyuubiStatement] + statement.executePython("x = [[1, 'a'], [3, 'b']]") + val e = intercept[KyuubiSQLException] { + statement.executePython("%json x") + }.getMessage + assert(e.contains("SyntaxError: invalid syntax")) + }) + } + } + private def runPySparkTest( pyCode: String, output: String): Unit = {