[SPARK-37516][PYTHON][SQL] Uses Python's standard string formatter for SQL API in PySpark

### What changes were proposed in this pull request?

This PR proposes to use [Python's standard string formatter](https://docs.python.org/3/library/string.html#custom-string-formatting) in `SparkSession.sql`, see also apache#34677.

### Why are the changes needed?

To improve usability in PySpark: `SparkSession.sql` now works together with Python's standard string formatter.

### Does this PR introduce _any_ user-facing change?

By default, there is no user-facing change. When `kwargs` is specified, yes, as shown in the examples below:

1. Attribute and item access on a `DataFrame` (standard Python formatter behavior):

    ```python
    mydf = spark.range(10)
    spark.sql("SELECT {tbl.id}, {tbl[id]} FROM {tbl}", tbl=mydf)
    ```

2. Substituting a `DataFrame` as a table reference:

    ```python
    mydf = spark.range(10)
    spark.sql("SELECT * FROM {tbl}", tbl=mydf)
    ```

3. Substituting a `Column` (explicit column references only; see the note after this list):

    ```python
    from pyspark.sql.functions import col

    mydf = spark.range(10)
    spark.sql("SELECT {c} FROM {tbl}", c=col("id"), tbl=mydf)
    ```

4. Leveraging other plain Python values through standard string formatting (a rough expansion sketch follows this list):

    ```python
    mydf = spark.range(10)
    spark.sql(
        "SELECT {col} FROM {mydf} WHERE id IN {x}",
        col=mydf.id, mydf=mydf, x=tuple(range(4)))
    ```
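Only plain column references are accepted for `Column` values; passing an expression such as `col("id") + 1` raises a `ValueError` (covered by the new unit test below). Conceptually, each `DataFrame` value is registered as a temporary view under a generated name, a plain column reference is rendered as its SQL form, and any other Python value is substituted by the standard formatter. A rough, illustrative expansion of example 4 (the real view name is a generated UUID and the view is dropped automatically after the query runs):

```python
# Illustrative sketch only; "_pyspark_view" stands in for the generated temp-view name.
mydf.createOrReplaceTempView("_pyspark_view")   # {mydf} -> temp view reference
spark.sql("SELECT id FROM _pyspark_view WHERE id IN (0, 1, 2, 3)")  # {col} -> id, {x} -> (0, 1, 2, 3)
spark.catalog.dropTempView("_pyspark_view")     # cleaned up by the formatter afterwards
```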

### How was this patch tested?

Doctests were added.
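For reference (a sketch, not part of the patch), the new doctests can be exercised locally because `pyspark/sql/session.py` runs its `_test()` helper when executed as a script:

```python
# Assumes a local PySpark dev environment; this spins up a local SparkSession
# and runs the doctests added to SparkSession.sql's docstring.
import runpy

runpy.run_module("pyspark.sql.session", run_name="__main__")
```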

Closes apache#34774 from HyukjinKwon/SPARK-37516.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed Dec 8, 2021
1 parent fdc276b commit 26f4953
Showing 5 changed files with 182 additions and 16 deletions.
10 changes: 5 additions & 5 deletions python/pyspark/pandas/sql_formatter.py
@@ -163,7 +163,7 @@ def sql(
return sql_processor.sql(query, index_col=index_col, **kwargs)

session = default_session()
formatter = SQLStringFormatter(session)
formatter = PandasSQLStringFormatter(session)
try:
sdf = session.sql(formatter.format(query, **kwargs))
finally:
@@ -178,7 +178,7 @@ def sql(
)


class SQLStringFormatter(string.Formatter):
class PandasSQLStringFormatter(string.Formatter):
"""
A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances
with basic Python objects. This object has to be clear after the use for single SQL
@@ -191,7 +191,7 @@ def __init__(self, session: SparkSession) -> None:
self._ref_sers: List[Tuple[Series, str]] = []

def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
ret = super(SQLStringFormatter, self).vformat(format_string, args, kwargs)
ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs)

for ref, n in self._ref_sers:
if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views):
@@ -200,7 +200,7 @@ def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str,
return ret

def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs)
return self._convert_value(obj, field_name), first

def _convert_value(self, val: Any, name: str) -> Optional[str]:
@@ -256,7 +256,7 @@ def _test() -> None:
globs["ps"] = pyspark.pandas
spark = (
SparkSession.builder.master("local[4]")
.appName("pyspark.pandas.sql_processor tests")
.appName("pyspark.pandas.sql_formatter tests")
.getOrCreate()
)
(failure_count, test_count) = doctest.testmod(
4 changes: 0 additions & 4 deletions python/pyspark/pandas/tests/test_sql.py
@@ -26,10 +26,6 @@ def test_error_variable_not_exist(self):
with self.assertRaisesRegex(KeyError, "variable_foo"):
ps.sql("select * from {variable_foo}")

def test_error_unsupported_type(self):
with self.assertRaisesRegex(KeyError, "some_dict"):
ps.sql("select * from {some_dict}")

def test_error_bad_sql(self):
with self.assertRaises(ParseException):
ps.sql("this is not valid sql")
90 changes: 84 additions & 6 deletions python/pyspark/sql/session.py
@@ -44,6 +44,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.sql_formatter import SQLStringFormatter
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import (
AtomicType,
@@ -924,23 +925,100 @@ def prepare(obj):
df._schema = struct
return df

def sql(self, sqlQuery: str) -> DataFrame:
def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame:
"""Returns a :class:`DataFrame` representing the result of the given query.
When ``kwargs`` is specified, this method formats the given string by using the Python
standard formatter.
.. versionadded:: 2.0.0
Parameters
----------
sqlQuery : str
SQL query string.
kwargs : dict
Other variables that the user wants to set that can be referenced in the query
.. versionchanged:: 3.3.0
Added optional argument ``kwargs`` to specify the mapping of variables in the query.
This feature is experimental and unstable.
Returns
-------
:class:`DataFrame`
Examples
--------
>>> df.createOrReplaceTempView("table1")
>>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')]
Executing a SQL query.
>>> spark.sql("SELECT * FROM range(10) where id > 7").show()
+---+
| id|
+---+
| 8|
| 9|
+---+
Executing a SQL query with variables as Python formatter standard.
>>> spark.sql(
... "SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9
... ).show()
+---+
| id|
+---+
| 8|
+---+
>>> mydf = spark.range(10)
>>> spark.sql(
... "SELECT {col} FROM {mydf} WHERE id IN {x}",
... col=mydf.id, mydf=mydf, x=tuple(range(4))).show()
+---+
| id|
+---+
| 0|
| 1|
| 2|
| 3|
+---+
>>> spark.sql('''
... SELECT m1.a, m2.b
... FROM {table1} m1 INNER JOIN {table2} m2
... ON m1.key = m2.key
... ORDER BY m1.a, m2.b''',
... table1=spark.createDataFrame([(1, "a"), (2, "b")], ["a", "key"]),
... table2=spark.createDataFrame([(3, "a"), (4, "b"), (5, "b")], ["b", "key"])).show()
+---+---+
| a| b|
+---+---+
| 1| 3|
| 2| 4|
| 2| 5|
+---+---+
Also, it is possible to query using class:`Column` from :class:`DataFrame`.
>>> mydf = spark.createDataFrame([(1, 4), (2, 4), (3, 6)], ["A", "B"])
>>> spark.sql("SELECT {df.A}, {df[B]} FROM {df}", df=mydf).show()
+---+---+
| A| B|
+---+---+
| 1| 4|
| 2| 4|
| 3| 6|
+---+---+
"""
return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)

formatter = SQLStringFormatter(self)
if len(kwargs) > 0:
sqlQuery = formatter.format(sqlQuery, **kwargs)
try:
return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
finally:
if len(kwargs) > 0:
formatter.clear()

def table(self, tableName: str) -> DataFrame:
"""Returns the specified table as a :class:`DataFrame`.
84 changes: 84 additions & 0 deletions python/pyspark/sql/sql_formatter.py
@@ -0,0 +1,84 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import string
import typing
from typing import Any, Optional, List, Tuple, Sequence, Mapping
import uuid

from py4j.java_gateway import is_instance_of

if typing.TYPE_CHECKING:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import lit


class SQLStringFormatter(string.Formatter):
"""
A standard ``string.Formatter`` in Python that can understand PySpark instances
with basic Python objects. This object has to be clear after the use for single SQL
query; cannot be reused across multiple SQL queries without cleaning.
"""

def __init__(self, session: "SparkSession") -> None:
self._session: "SparkSession" = session
self._temp_views: List[Tuple[DataFrame, str]] = []

def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
return self._convert_value(obj, field_name), first

def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
"""
Converts the given value into a SQL string.
"""
from pyspark import SparkContext
from pyspark.sql import Column, DataFrame

if isinstance(val, Column):
assert SparkContext._gateway is not None # type: ignore[attr-defined]

gw = SparkContext._gateway # type: ignore[attr-defined]
jexpr = val._jc.expr()
if is_instance_of(
gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
) or is_instance_of(
gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference"
):
return jexpr.sql()
else:
raise ValueError(
"%s in %s should be a plain column reference such as `df.col` "
"or `col('column')`" % (val, field_name)
)
elif isinstance(val, DataFrame):
for df, n in self._temp_views:
if df is val:
return n
df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "")
self._temp_views.append((val, df_name))
val.createOrReplaceTempView(df_name)
return df_name
elif isinstance(val, str):
return lit(val)._jc.expr().sql() # for escaped characters.
else:
return val

def clear(self) -> None:
for _, n in self._temp_views:
self._session.catalog.dropTempView(n)
self._temp_views = []
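For illustration, a minimal usage sketch of the new formatter on its own (assuming a local SparkSession; in practice `SparkSession.sql` drives it the same way: format the query, run it, then `clear()` the temporary views):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.sql_formatter import SQLStringFormatter

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(5)

formatter = SQLStringFormatter(spark)
try:
    # {tbl} is replaced with a generated temp-view name; {c} with the column's SQL ("id").
    query = formatter.format("SELECT {c} FROM {tbl} WHERE {c} > 2", c=col("id"), tbl=df)
    spark.sql(query).show()
finally:
    formatter.clear()  # drop the temporary views registered during format()
```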
10 changes: 9 additions & 1 deletion python/pyspark/sql/tests/test_session.py
@@ -20,6 +20,7 @@

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession, SQLContext, Row
from pyspark.sql.functions import col
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import PySparkTestCase

@@ -93,7 +94,7 @@ def test_get_active_session_when_no_active_session(self):
active = SparkSession.getActiveSession()
self.assertEqual(active, None)

def test_SparkSession(self):
def test_spark_session(self):
spark = SparkSession.builder.master("local").config("some-config", "v2").getOrCreate()
try:
self.assertEqual(spark.conf.get("some-config"), "v2")
@@ -105,6 +106,13 @@ def test_SparkSession(self):
spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
self.assertEqual(spark.table("table1").columns, ["name", "age"])
self.assertEqual(spark.range(3).count(), 3)

# SPARK-37516: Only plain column references work as variable in SQL.
self.assertEqual(
spark.sql("select {c} from range(1)", c=col("id")).first(), spark.range(1).first()
)
with self.assertRaisesRegex(ValueError, "Column"):
spark.sql("select {c} from range(10)", c=col("id") + 1)
finally:
spark.sql("DROP DATABASE test_db CASCADE")
spark.stop()
