Skip to content

Commit

Permalink
Add data source pyspark
Browse files Browse the repository at this point in the history
  • Loading branch information
ichuniq committed Nov 29, 2024
1 parent 689caf8 commit bd79542
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 0 deletions.
16 changes: 16 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class QueryPostgresDTO(QueryDTO):
connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field


class QueryPySparkDTO(QueryDTO):
connection_info: ConnectionUrl | PySparkConnectionInfo = connection_info_field


class QuerySnowflakeDTO(QueryDTO):
connection_info: SnowflakeConnectionInfo = connection_info_field

Expand Down Expand Up @@ -109,6 +113,17 @@ class PostgresConnectionInfo(BaseModel):
password: SecretStr


class PySparkConnectionInfo(BaseModel):
app_name: SecretStr = Field(examples=["wrenai"])
master: SecretStr = Field(
default="local[*]",
description="Spark master URL (e.g., 'local[*]', 'spark://master:7077')",
)
configs: dict[str, str] | None = Field(
default=None, description="Additional Spark configurations"
)


class SnowflakeConnectionInfo(BaseModel):
user: SecretStr
password: SecretStr
Expand Down Expand Up @@ -137,6 +152,7 @@ class TrinoConnectionInfo(BaseModel):
| MSSqlConnectionInfo
| MySqlConnectionInfo
| PostgresConnectionInfo
| PySparkConnectionInfo
| SnowflakeConnectionInfo
| TrinoConnectionInfo
)
Expand Down
19 changes: 19 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis
from google.oauth2 import service_account
from ibis import BaseBackend
from pyspark.sql import SparkSession

from app.model import (
BigQueryConnectionInfo,
Expand All @@ -16,13 +17,15 @@
MSSqlConnectionInfo,
MySqlConnectionInfo,
PostgresConnectionInfo,
PySparkConnectionInfo,
QueryBigQueryDTO,
QueryCannerDTO,
QueryClickHouseDTO,
QueryDTO,
QueryMSSqlDTO,
QueryMySqlDTO,
QueryPostgresDTO,
QueryPySparkDTO,
QuerySnowflakeDTO,
QueryTrinoDTO,
SnowflakeConnectionInfo,
Expand All @@ -37,6 +40,7 @@ class DataSource(StrEnum):
mssql = auto()
mysql = auto()
postgres = auto()
pyspark = auto()
snowflake = auto()
trino = auto()

Expand All @@ -60,6 +64,7 @@ class DataSourceExtension(Enum):
mssql = QueryMSSqlDTO
mysql = QueryMySqlDTO
postgres = QueryPostgresDTO
pyspark = QueryPySparkDTO
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO

Expand Down Expand Up @@ -143,6 +148,20 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
password=info.password.get_secret_value(),
)

@staticmethod
def get_pyspark_connection(info: PySparkConnectionInfo) -> BaseBackend:
builder = SparkSession.builder.appName(info.app_name.get_secret_value()).master(
info.master.get_secret_value()
)

if info.configs:
for key, value in info.configs.items():
builder = builder.config(key, value)

# Create or get existing Spark session
spark_session = builder.getOrCreate()
return ibis.pyspark.connect(session=spark_session)

@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
return ibis.snowflake.connect(
Expand Down
3 changes: 3 additions & 0 deletions ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ibis-framework = { version = "9.5.0", extras = [
"mssql",
"mysql",
"postgres",
"pyspark",
"snowflake",
"trino",
] }
Expand All @@ -42,6 +43,7 @@ sqlalchemy = "2.0.36"
pre-commit = "4.0.1"
ruff = "0.8.0"
trino = ">=0.321,<1"
pyspark = "3.5.1"
psycopg2 = ">=2.8.4,<3"
clickhouse-connect = "0.8.7"

Expand All @@ -54,6 +56,7 @@ markers = [
"mssql: mark a test as a mssql test",
"mysql: mark a test as a mysql test",
"postgres: mark a test as a postgres test",
"pyspark: mark a test as a pyspark test",
"snowflake: mark a test as a snowflake test",
"trino: mark a test as a trino test",
"beta: mark a test as a test for beta versions of the engine",
Expand Down
191 changes: 191 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_pyspark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import base64

# import os
import orjson
import pytest
from fastapi.testclient import TestClient

from app.main import app
from app.model.validator import rules

pytestmark = pytest.mark.pyspark

base_url = "/v2/connector/pyspark"

connection_info = {
"app_name": "MyApp",
"master": "local",
}

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"properties": {},
"refSql": "select * from tpch.orders",
"columns": [
{"name": "orderkey", "expression": "O_ORDERKEY", "type": "integer"},
{"name": "custkey", "expression": "O_CUSTKEY", "type": "integer"},
{
"name": "orderstatus",
"expression": "O_ORDERSTATUS",
"type": "varchar",
},
{
"name": "totalprice",
"expression": "O_TOTALPRICE",
"type": "float",
},
{"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"},
{
"name": "order_cust_key",
"expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)",
"type": "varchar",
},
{
"name": "timestamp",
"expression": "cast('2024-01-01T23:59:59' as timestamp)",
"type": "timestamp",
},
{
"name": "timestamptz",
"expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)",
"type": "timestamp",
},
{
"name": "test_null_time",
"expression": "cast(NULL as timestamp)",
"type": "timestamp",
},
],
"primaryKey": "orderkey",
},
],
}


@pytest.fixture
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


with TestClient(app) as client:
# def test_query(manifest_str):
# response = client.post(
# url=f"{base_url}/query",
# json={
# "connectionInfo": connection_info,
# "manifestStr": manifest_str,
# "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1',
# },
# )
# assert response.status_code == 200
# result = response.json()
# assert len(result["columns"]) == len(manifest["models"][0]["columns"])
# assert len(result["data"]) == 1
# assert result["data"][0] == [
# 1,
# 36901,
# "O",
# "173665.47",
# "1996-01-02",
# "1_36901",
# "2024-01-01 23:59:59.000000",
# "2024-01-01 23:59:59.000000 UTC",
# None,
# ]
# assert result["dtypes"] == {
# "orderkey": "int64",
# "custkey": "int64",
# "orderstatus": "object",
# "totalprice": "object",
# "orderdate": "object",
# "order_cust_key": "object",
# "timestamp": "object",
# "timestamptz": "object",
# "test_null_time": "datetime64[ns]",
# }

def test_query_without_manifest():
response = client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "manifestStr"]
assert result["detail"][0]["msg"] == "Field required"

def test_query_without_sql(manifest_str):
response = client.post(
url=f"{base_url}/query",
json={"connectionInfo": connection_info, "manifestStr": manifest_str},
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "sql"]
assert result["detail"][0]["msg"] == "Field required"

def test_query_without_connection_info(manifest_str):
response = client.post(
url=f"{base_url}/query",
json={
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "connectionInfo"]
assert result["detail"][0]["msg"] == "Field required"

# def test_query_with_dry_run(manifest_str):
# response = client.post(
# url=f"{base_url}/query",
# params={"dryRun": True},
# json={
# "connectionInfo": connection_info,
# "manifestStr": manifest_str,
# "sql": 'SELECT * FROM "Orders" LIMIT 1',
# },
# )
# assert response.status_code == 204

def test_query_with_dry_run_and_invalid_sql(manifest_str):
response = client.post(
url=f"{base_url}/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT * FROM X",
},
)
assert response.status_code == 422
assert response.text is not None

def test_validate_with_unknown_rule(manifest_str):
response = client.post(
url=f"{base_url}/validate/unknown_rule",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"parameters": {"modelName": "Orders", "columnName": "orderkey"},
},
)
assert response.status_code == 404
assert (
response.text
== f"The rule `unknown_rule` is not in the rules, rules: {rules}"
)

0 comments on commit bd79542

Please sign in to comment.