feat(pyspark): add official support and ci testing with spark connect (ibis-project#10187)

## Description of changes

This PR adds testing for the PySpark Ibis backend running against Spark Connect.

This is done by running a Spark Connect instance as a docker compose service, similar to our other client-server-model backends.

The primary bit of functionality that isn't tested is UDFs (which means JSON unwrapping is also untested, because that's implemented as a UDF). UDFs effectively require a clone of the Python environment on the server, and that seems out of scope for initial Spark Connect support.
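
For reference, a minimal sketch of pointing the backend at the Spark Connect service added in this PR (assuming the `spark-connect` compose service is up and `pyspark` is installed with its connect extras; not part of the diff itself):

```python
import ibis
from pyspark.sql import SparkSession

# Build a Spark Connect session against the compose service; the URL matches
# the SPARK_REMOTE value used in the CI matrix below.
session = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Hand the remote session to the Ibis PySpark backend as usual.
con = ibis.pyspark.connect(session)
print(con.list_tables())
```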
cpcloud authored and ncclementi committed Sep 24, 2024
1 parent cd9ee1b commit b7135e7
Showing 25 changed files with 581 additions and 462 deletions.
1 change: 1 addition & 0 deletions .env
@@ -5,3 +5,4 @@ PGPASSWORD="postgres"
MYSQL_PWD="ibis"
MSSQL_SA_PASSWORD="1bis_Testing!"
DRUID_URL="druid://localhost:8082/druid/v2/sql"
SPARK_CONFIG=./docker/spark-connect/conf.properties
33 changes: 25 additions & 8 deletions .github/workflows/ibis-backends.yml
@@ -442,13 +442,9 @@ jobs:
- name: download backend data
run: just download-data

- name: show docker compose version
if: matrix.backend.services != null
run: docker compose version

- name: start services
if: matrix.backend.services != null
run: docker compose up --wait ${{ join(matrix.backend.services, ' ') }}
run: just up ${{ join(matrix.backend.services, ' ') }}

- name: install python
uses: actions/setup-python@v5
@@ -600,7 +596,7 @@ jobs:

- name: start services
if: matrix.backend.services != null
run: docker compose up --wait ${{ join(matrix.backend.services, ' ') }}
run: just up ${{ join(matrix.backend.services, ' ') }}

- name: install python
uses: actions/setup-python@v5
Expand Down Expand Up @@ -653,7 +649,7 @@ jobs:
run: docker compose logs

test_pyspark:
name: PySpark ${{ matrix.pyspark-minor-version }} ubuntu-latest python-${{ matrix.python-version }}
name: PySpark ${{ matrix.tag }} ${{ matrix.pyspark-minor-version }} ubuntu-latest python-${{ matrix.python-version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
@@ -665,19 +661,29 @@ jobs:
deps:
- "'pandas@<2'"
- "'numpy@<1.24'"
tag: local
- python-version: "3.11"
pyspark-version: "3.5.2"
pyspark-minor-version: "3.5"
deps:
- "'pandas@>2'"
- "'numpy@>1.24'"
tag: local
- python-version: "3.12"
pyspark-version: "3.5.2"
pyspark-minor-version: "3.5"
deps:
- "'pandas@>2'"
- "'numpy@>1.24'"
- setuptools
tag: local
- python-version: "3.12"
pyspark-version: "3.5.2"
pyspark-minor-version: "3.5"
deps:
- setuptools
tag: remote
SPARK_REMOTE: "sc://localhost:15002"
steps:
- name: checkout
uses: actions/checkout@v4
@@ -691,6 +697,10 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

- name: start services
if: matrix.tag == 'remote'
run: just up spark-connect

- name: download backend data
run: just download-data

@@ -730,7 +740,14 @@ jobs:
shell: bash
run: just download-iceberg-jar ${{ matrix.pyspark-minor-version }}

- name: run tests
- name: run spark connect tests
if: matrix.tag == 'remote'
run: just ci-check -m pyspark
env:
SPARK_REMOTE: ${{ matrix.SPARK_REMOTE }}

- name: run spark tests
if: matrix.tag == 'local'
run: just ci-check -m pyspark

- name: check that no untracked files were produced
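The local/remote split in the matrix above hinges on `SPARK_REMOTE`; conceptually, a test session can branch on that variable along these lines (a hypothetical sketch with invented fixture names, not the project's actual conftest):

```python
import os

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    # Hypothetical fixture: branch on SPARK_REMOTE (e.g. "sc://localhost:15002").
    if (remote := os.environ.get("SPARK_REMOTE")) is not None:
        # Spark Connect mode: the driver lives in the docker compose service.
        return SparkSession.builder.remote(remote).getOrCreate()
    # Local mode: in-process driver, as in the existing PySpark jobs.
    return SparkSession.builder.master("local[*]").getOrCreate()
```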
20 changes: 20 additions & 0 deletions compose.yaml
@@ -589,6 +589,24 @@ services:
networks:
- risingwave

spark-connect:
image: bitnami/spark:3.5.2
ports:
- 15002:15002
command: /opt/bitnami/spark/sbin/start-connect-server.sh --name ibis_testing --packages org.apache.spark:spark-connect_2.12:3.5.2,org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.5.2
healthcheck:
test:
- CMD-SHELL
- bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/15002; exit $$?;'
interval: 5s
retries: 6
volumes:
- spark-connect:/data
- $PWD/docker/spark-connect/conf.properties:/opt/bitnami/spark/conf/spark-defaults.conf:ro
# - $PWD/docker/spark-connect/log4j2.properties:/opt/bitnami/spark/conf/log4j2.properties:ro
networks:
- spark-connect

networks:
impala:
# docker defaults to naming networks "$PROJECT_$NETWORK" but the Java Hive
@@ -606,6 +624,7 @@ networks:
exasol:
flink:
risingwave:
spark-connect:

volumes:
clickhouse:
@@ -617,3 +636,4 @@ volumes:
exasol:
impala:
risingwave:
spark-connect:
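
The service's healthcheck probes the Connect port with bash's `/dev/tcp`; the same readiness check from the host side is just a TCP connection, for example (illustrative only, not part of the change):

```python
# Illustrative readiness check: confirm the Spark Connect gRPC port accepts
# TCP connections, mirroring the compose healthcheck above.
import socket

with socket.create_connection(("localhost", 15002), timeout=5):
    print("spark-connect is listening on port 15002")
```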
12 changes: 12 additions & 0 deletions docker/spark-connect/conf.properties
@@ -0,0 +1,12 @@
spark.driver.extraJavaOptions=-Duser.timezone=GMT
spark.executor.extraJavaOptions=-Duser.timezone=GMT
spark.jars.packages=org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.5.2
spark.sql.catalog.local.type=hadoop
spark.sql.catalog.local.warehouse=warehouse
spark.sql.catalog.local=org.apache.iceberg.spark.SparkCatalog
spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions
spark.sql.legacy.timeParserPolicy=LEGACY
spark.sql.session.timeZone=UTC
spark.sql.streaming.schemaInference=true
spark.ui.enabled=false
spark.ui.showConsoleProgress=false
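
These properties are mounted as the server's `spark-defaults.conf` (see the compose volume above); for a local, non-Connect session the equivalent settings can be passed programmatically, for example (a sketch, assuming the Iceberg runtime jar is available on the classpath):

```python
from pyspark.sql import SparkSession

# Sketch: the same defaults applied to a local SparkSession builder.
session = (
    SparkSession.builder.appName("ibis_testing")
    .config("spark.sql.session.timeZone", "UTC")
    .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
    .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
    .config("spark.sql.catalog.local.type", "hadoop")
    .config("spark.sql.catalog.local.warehouse", "warehouse")
    .config(
        "spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
    )
    .getOrCreate()
)
```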
68 changes: 68 additions & 0 deletions docker/spark-connect/log4j2.properties
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Set everything to be logged to the console
rootLogger.level = error
rootLogger.appenderRef.stdout.ref = console

# In the pattern layout configuration below, we specify an explicit `%ex` conversion
# pattern for logging Throwables. If this was omitted, then (by default) Log4J would
# implicitly add an `%xEx` conversion pattern which logs stacktraces with additional
# class packaging information. That extra information can sometimes add a substantial
# performance overhead, so we disable it in our default logging config.
# For more information, see SPARK-39361.
appender.console.type = Console
appender.console.name = console
appender.console.target = SYSTEM_ERR
appender.console.layout.type = PatternLayout
appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n%ex

# Set the default spark-shell/spark-sql log level to WARN. When running the
# spark-shell/spark-sql, the log level for these classes is used to overwrite
# the root logger's log level, so that the user can have different defaults
# for the shell and regular Spark apps.
logger.repl.name = org.apache.spark.repl.Main
logger.repl.level = error

logger.thriftserver.name = org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver
logger.thriftserver.level = error

# Settings to quiet third party logs that are too verbose
logger.jetty1.name = org.sparkproject.jetty
logger.jetty1.level = error
logger.jetty2.name = org.sparkproject.jetty.util.component.AbstractLifeCycle
logger.jetty2.level = error
logger.replexprTyper.name = org.apache.spark.repl.SparkIMain$exprTyper
logger.replexprTyper.level = error
logger.replSparkILoopInterpreter.name = org.apache.spark.repl.SparkILoop$SparkILoopInterpreter
logger.replSparkILoopInterpreter.level = error
logger.parquet1.name = org.apache.parquet
logger.parquet1.level = error
logger.parquet2.name = parquet
logger.parquet2.level = error

# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
logger.RetryingHMSHandler.name = org.apache.hadoop.hive.metastore.RetryingHMSHandler
logger.RetryingHMSHandler.level = fatal
logger.FunctionRegistry.name = org.apache.hadoop.hive.ql.exec.FunctionRegistry
logger.FunctionRegistry.level = error

# For deploying Spark ThriftServer
# SPARK-34128: Suppress undesirable TTransportException warnings involved in THRIFT-4805
appender.console.filter.1.type = RegexFilter
appender.console.filter.1.regex = .*Thrift error occurred during processing of message.*
appender.console.filter.1.onMatch = deny
appender.console.filter.1.onMismatch = neutral
38 changes: 21 additions & 17 deletions ibis/backends/pyspark/__init__.py
@@ -31,9 +31,15 @@
from ibis.util import deprecated

try:
from pyspark.errors import AnalysisException, ParseException
from pyspark.errors import ParseException
from pyspark.errors.exceptions.connect import SparkConnectGrpcException
except ImportError:
from pyspark.sql.utils import AnalysisException, ParseException
from pyspark.sql.utils import ParseException

# Use a dummy class for when spark connect is not available
class SparkConnectGrpcException(Exception):
pass


if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
@@ -186,13 +192,6 @@ def do_connect(
# Databricks Serverless compute only supports limited properties
# and any attempt to set unsupported properties will result in an error.
# https://docs.databricks.com/en/spark/conf.html
try:
from pyspark.errors.exceptions.connect import SparkConnectGrpcException
except ImportError:
# Use a dummy class for when spark connect is not available
class SparkConnectGrpcException(Exception):
pass

with contextlib.suppress(SparkConnectGrpcException):
self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

@@ -456,7 +455,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
df.createTempView(op.name)

def _finalize_memtable(self, name: str) -> None:
self._session.catalog.dropTempView(name)
"""No-op, otherwise a deadlock can occur when using Spark Connect."""
if isinstance(session := self._session, pyspark.sql.SparkSession):
session.catalog.dropTempView(name)

@contextlib.contextmanager
def _safe_raw_sql(self, query: str) -> Any:
@@ -579,16 +580,20 @@ def get_schema(

table_loc = self._to_sqlglot_table((catalog, database))
catalog, db = self._to_catalog_db_tuple(table_loc)
session = self._session
with self._active_catalog_database(catalog, db):
try:
df = self._session.table(table_name)
except AnalysisException as e:
if not self._session.catalog.tableExists(table_name):
df = session.table(table_name)
# this is intentionally included in the try block because when
# using spark connect, the table-not-found exception coming
# from the server will *NOT* be raised until the schema
# property is accessed
struct = PySparkType.to_ibis(df.schema)
except Exception as e:
if not session.catalog.tableExists(table_name):
raise com.TableNotFound(table_name) from e
raise

struct = PySparkType.to_ibis(df.schema)

return sch.Schema(struct)

def create_table(
@@ -752,7 +757,7 @@ def _create_cached_table(self, name, expr):
query = self.compile(expr)
t = self._session.sql(query).cache()
assert t.is_cached
t.createOrReplaceTempView(name)
t.createTempView(name)
# store the underlying spark dataframe so we can release memory when
# asked to, instead of when the session ends
self._cached_dataframes[name] = t
@@ -761,7 +766,6 @@ def _drop_cached_table(self, name):
def _drop_cached_table(self, name):
self._session.catalog.dropTempView(name)
t = self._cached_dataframes.pop(name)
assert t.is_cached
t.unpersist()
assert not t.is_cached

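As the comment in `get_schema` above notes, under Spark Connect `session.table()` on a missing table does not raise until the schema is accessed; from the caller's perspective the behavior is the same in both modes (a usage sketch, not part of the diff):

```python
import ibis
import ibis.common.exceptions as com

# Works the same against a local session or a Spark Connect session.
con = ibis.pyspark.connect()

try:
    con.get_schema("no_such_table")
except com.TableNotFound:
    print("table does not exist")
```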