Skip to content

Commit

Permalink
Add column_dtypes to forcibly convert the type
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 committed Jun 5, 2024
1 parent 175942b commit 72f2a7d
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 239 deletions.
22 changes: 19 additions & 3 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from json import loads

import pandas as pd

from app.mdl.rewriter import rewrite
from app.model.data_source import DataSource, ConnectionInfo


class Connector:
def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, manifest_str: str):
def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, manifest_str: str, column_dtypes: dict[str, str]):
self.data_source = data_source
self.connection = self.data_source.get_connection(connection_info)
self.manifest_str = manifest_str
self.column_dtypes = column_dtypes

def query(self, sql) -> dict:
rewritten_sql = rewrite(self.manifest_str, sql)
Expand All @@ -21,13 +24,26 @@ def dry_run(self, sql) -> None:
except Exception as e:
raise QueryDryRunError(f'Exception: {type(e)}, message: {str(e)}')

@staticmethod
def _to_json(df):
def _to_json(self, df):
if self.column_dtypes:
self._to_specific_types(df, self.column_dtypes)
json_obj = loads(df.to_json(orient='split'))
del json_obj['index']
json_obj['dtypes'] = df.dtypes.apply(lambda x: x.name).to_dict()
return json_obj

def _to_specific_types(self, df: pd.DataFrame, column_dtypes: dict[str, str]):
for column, dtype in column_dtypes.items():
if dtype == 'datetime64':
df[column] = self._to_datetime_and_format(df[column])
else:
df[column] = df[column].astype(dtype)

@staticmethod
def _to_datetime_and_format(series: pd.Series) -> pd.Series:
series = pd.to_datetime(series, errors='coerce')
return series.apply(lambda d: d.strftime('%Y-%m-%d %H:%M:%S.%f' + (' %Z' if series.dt.tz is not None else '')) if not pd.isnull(d) else d)


class QueryDryRunError(Exception):
pass
1 change: 1 addition & 0 deletions ibis-server/app/model/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class IbisDTO(BaseModel):
sql: str
manifest_str: str = Field(alias="manifestStr", description="Base64 manifest")
column_dtypes: dict[str, str] | None = Field(alias="columnDtypes", description="If this field is set, it will forcibly convert the type.", default=None)


class PostgresDTO(IbisDTO):
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/routers/ibis/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@router.post("/query")
@log_dto
def query(dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str)
connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/routers/ibis/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@router.post("/query")
@log_dto
def query(dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str)
connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/routers/ibis/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@router.post("/query")
@log_dto
def query(dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str)
connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
Expand Down
190 changes: 95 additions & 95 deletions ibis-server/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ google-auth = "2.29.0"
httpx = "0.27.0"
python-dotenv = "1.0.1"
orjson = "3.10.3"
pandas = "2.2.2"

[tool.poetry.group.dev.dependencies]
pytest = "8.2.0"
testcontainers = {extras = ["postgres"], version = "4.5.0"}
sqlalchemy = "2.0.30"
pandas = "2.2.2"

[tool.pytest.ini_options]
addopts = "--strict-markers"
Expand Down
134 changes: 89 additions & 45 deletions ibis-server/tests/routers/ibis/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import base64

import orjson
import pytest
from fastapi.testclient import TestClient

Expand All @@ -8,38 +11,38 @@

@pytest.mark.bigquery
class TestBigquery:
manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"refSql": "select * from tpch_tiny.orders",
"columns": [
{"name": "orderkey", "expression": "o_orderkey", "type": "integer"},
{"name": "custkey", "expression": "o_custkey", "type": "integer"},
{"name": "orderstatus", "expression": "o_orderstatus", "type": "varchar"},
{"name": "totalprice", "expression": "o_totalprice", "type": "float"},
{"name": "orderdate", "expression": "o_orderdate", "type": "date"},
{"name": "order_cust_key", "expression": "concat(o_orderkey, '_', o_custkey)", "type": "varchar"},
{"name": "timestamp", "expression": "cast('2024-01-01T23:59:59' as timestamp)", "type": "timestamp"},
{"name": "timestamptz", "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", "type": "timestamp"}
],
"primaryKey": "orderkey"
},
{
"name": "Customer",
"refSql": "select * from tpch_tiny.customer",
"columns": [
{"name": "custkey", "expression": "c_custkey", "type": "integer"},
{"name": "name", "expression": "c_name", "type": "varchar"}
],
"primaryKey": "custkey"
}
]
}

@pytest.fixture()
def manifest_str(self) -> str:
import base64
import orjson

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"properties": {},
"refSql": "select * from tpch_tiny.orders",
"columns": [
{"name": "orderkey", "expression": "o_orderkey", "type": "integer"},
{"name": "custkey", "type": "integer", "expression": "o_custkey"}
],
"primaryKey": "orderkey"
},
{
"name": "Customer",
"refSql": "select * from tpch_tiny.customer",
"columns": [
{"name": "custkey", "expression": "c_custkey", "type": "integer"},
{"name": "name", "expression": "c_name", "type": "varchar"}
],
"primaryKey": "custkey"
}
]
}
return base64.b64encode(orjson.dumps(manifest)).decode('utf-8')
manifest_str = base64.b64encode(orjson.dumps(manifest)).decode('utf-8')

@staticmethod
def get_connection_info():
Expand All @@ -50,22 +53,63 @@ def get_connection_info():
"credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON")
}

def test_query(self, manifest_str: str):
def test_query(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1'
"manifestStr": self.manifest_str,
"sql": 'SELECT * FROM "Orders" ORDER BY orderkey LIMIT 1'
}
)
assert response.status_code == 200
result = response.json()
assert len(result['columns']) == len(self.manifest['models'][0]['columns'])
assert len(result['data']) == 1
assert result['data'][0] == [1, 370, 'O', 172799.49, 820540800000, '1_370', 1704153599000, 1704153599000]
assert result['dtypes'] == {
'orderkey': 'int64',
'custkey': 'int64',
'orderstatus': 'object',
'totalprice': 'float64',
'orderdate': 'object',
'order_cust_key': 'object',
'timestamp': 'datetime64[ns]',
'timestamptz': 'datetime64[ns, UTC]'
}

def test_query_with_column_dtypes(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"connectionInfo": connection_info,
"manifestStr": self.manifest_str,
"sql": 'SELECT * FROM "Orders" ORDER BY orderkey LIMIT 1',
"columnDtypes": {
"totalprice": "float",
"orderdate": "datetime64",
"timestamp": "datetime64",
"timestamptz": "datetime64"
}
}
)
assert response.status_code == 200
result = response.json()
assert len(result['columns']) == 2
assert len(result['columns']) == len(self.manifest['models'][0]['columns'])
assert len(result['data']) == 1
assert result['data'][0][0] is not None
assert result['dtypes'] is not None
assert result['data'][0] == [1, 370, 'O', 172799.49, '1996-01-02 00:00:00.000000', '1_370', '2024-01-01 23:59:59.000000', '2024-01-01 23:59:59.000000 UTC']
assert result['dtypes'] == {
'orderkey': 'int64',
'custkey': 'int64',
'orderstatus': 'object',
'totalprice': 'float64',
'orderdate': 'object',
'order_cust_key': 'object',
'timestamp': 'object',
'timestamptz': 'object'
}

def test_query_without_manifest(self):
connection_info = self.get_connection_info()
Expand All @@ -83,13 +127,13 @@ def test_query_without_manifest(self):
assert result['detail'][0]['loc'] == ['body', 'manifestStr']
assert result['detail'][0]['msg'] == 'Field required'

def test_query_without_sql(self, manifest_str: str):
def test_query_without_sql(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str
"manifestStr": self.manifest_str
}
)
assert response.status_code == 422
Expand All @@ -99,11 +143,11 @@ def test_query_without_sql(self, manifest_str: str):
assert result['detail'][0]['loc'] == ['body', 'sql']
assert result['detail'][0]['msg'] == 'Field required'

def test_query_without_connection_info(self, manifest_str: str):
def test_query_without_connection_info(self):
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"manifestStr": manifest_str,
"manifestStr": self.manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
Expand All @@ -114,27 +158,27 @@ def test_query_without_connection_info(self, manifest_str: str):
assert result['detail'][0]['loc'] == ['body', 'connectionInfo']
assert result['detail'][0]['msg'] == 'Field required'

def test_query_with_dry_run(self, manifest_str: str):
def test_query_with_dry_run(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"manifestStr": self.manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 204

def test_query_with_dry_run_and_invalid_sql(self, manifest_str: str):
def test_query_with_dry_run_and_invalid_sql(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"manifestStr": self.manifest_str,
"sql": 'SELECT * FROM X'
}
)
Expand Down
Loading

0 comments on commit 72f2a7d

Please sign in to comment.