feat(ibis): Add Oracle connector #1067

Open · wants to merge 14 commits into main
12 changes: 12 additions & 0 deletions ibis-server/app/model/__init__.py
@@ -40,6 +40,10 @@ class QueryMySqlDTO(QueryDTO):
connection_info: ConnectionUrl | MySqlConnectionInfo = connection_info_field


class QueryOracleDTO(QueryDTO):
connection_info: ConnectionUrl | OracleConnectionInfo = connection_info_field


class QueryPostgresDTO(QueryDTO):
connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field

@@ -131,6 +135,14 @@ class PostgresConnectionInfo(BaseModel):
password: SecretStr | None = None


class OracleConnectionInfo(BaseModel):
host: SecretStr = Field(examples=["localhost"])
port: SecretStr = Field(examples=[1521])
database: SecretStr
user: SecretStr
password: SecretStr | None = None


class SnowflakeConnectionInfo(BaseModel):
user: SecretStr
password: SecretStr
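For reference, a connection-info payload satisfying the new `OracleConnectionInfo` model might look like the sketch below; every value is a placeholder, not something defined in this PR.

```python
# Hypothetical payload for the Oracle connector; all values are placeholders.
oracle_connection_info = {
    "host": "localhost",
    "port": "1521",          # SecretStr field, so a string is fine even for the port
    "database": "FREEPDB1",  # service/database name
    "user": "SYSTEM",
    "password": "Oracle123",
}
```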
14 changes: 14 additions & 0 deletions ibis-server/app/model/data_source.py
@@ -17,6 +17,7 @@
ConnectionInfo,
MSSqlConnectionInfo,
MySqlConnectionInfo,
OracleConnectionInfo,
PostgresConnectionInfo,
QueryBigQueryDTO,
QueryCannerDTO,
@@ -27,6 +28,7 @@
QueryMinioFileDTO,
QueryMSSqlDTO,
QueryMySqlDTO,
QueryOracleDTO,
QueryPostgresDTO,
QueryS3FileDTO,
QuerySnowflakeDTO,
@@ -43,6 +45,7 @@ class DataSource(StrEnum):
clickhouse = auto()
mssql = auto()
mysql = auto()
oracle = auto()
postgres = auto()
snowflake = auto()
trino = auto()
@@ -70,6 +73,7 @@ class DataSourceExtension(Enum):
clickhouse = QueryClickHouseDTO
mssql = QueryMSSqlDTO
mysql = QueryMySqlDTO
oracle = QueryOracleDTO
postgres = QueryPostgresDTO
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO
@@ -176,6 +180,16 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
password=(info.password and info.password.get_secret_value()),
)

@staticmethod
def get_oracle_connection(info: OracleConnectionInfo) -> BaseBackend:
return ibis.oracle.connect(
host=info.host.get_secret_value(),
port=int(info.port.get_secret_value()),
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=(info.password and info.password.get_secret_value()),
)

@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
return ibis.snowflake.connect(
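In effect, `get_oracle_connection` is a thin wrapper over the ibis Oracle backend. A standalone equivalent, assuming ibis-framework 9.5.0 with the `oracle` extra and placeholder credentials, would be roughly:

```python
import ibis

# Placeholder credentials; a real deployment supplies its own values.
con = ibis.oracle.connect(
    host="localhost",
    port=1521,
    database="FREEPDB1",  # service/database name, mirroring OracleConnectionInfo.database
    user="SYSTEM",
    password="Oracle123",
)
print(con.list_tables())  # quick sanity check that the connection works
```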
6 changes: 6 additions & 0 deletions ibis-server/app/util.py
@@ -13,6 +13,9 @@ def base64_to_dict(base64_str: str) -> dict:

def to_json(df: pd.DataFrame) -> dict:
for column in df.columns:
if df[column].dtype == object:
# Convert Oracle LOB objects to string
df[column] = df[column].apply(lambda x: str(x) if hasattr(x, "read") else x)
if is_datetime64_any_dtype(df[column].dtype):
df[column] = _to_datetime_and_format(df[column])
return _to_json_obj(df)
@@ -44,6 +47,9 @@ def default(obj):
return _date_offset_to_str(obj)
if isinstance(obj, datetime.timedelta):
return str(obj)
# Add handling for any remaining LOB objects
if hasattr(obj, "read"): # Check if object is LOB-like
return str(obj)
raise TypeError

json_obj = orjson.loads(
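The `hasattr(x, "read")` check treats anything file-like as a LOB and stringifies it. A minimal sketch of that behaviour, using a stand-in class rather than a real `oracledb.LOB`, could be:

```python
import pandas as pd


class FakeLOB:
    """Stand-in for an Oracle LOB: exposes read() and a string form."""

    def __init__(self, payload: str):
        self._payload = payload

    def read(self) -> str:
        return self._payload

    def __str__(self) -> str:
        return self._payload


df = pd.DataFrame({"blob_column": [FakeLOB("abc"), None]})
# Same conversion as in to_json(): LOB-like objects become strings, others pass through.
df["blob_column"] = df["blob_column"].apply(
    lambda x: str(x) if hasattr(x, "read") else x
)
print(df["blob_column"].tolist())  # ['abc', None]
```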
3 changes: 3 additions & 0 deletions ibis-server/pyproject.toml
@@ -15,6 +15,7 @@ ibis-framework = { version = "9.5.0", extras = [
"clickhouse",
"mssql",
"mysql",
"oracle",
"postgres",
"snowflake",
"trino",
@@ -33,6 +34,7 @@ gql = { extras = ["aiohttp"], version = "3.5.0" }
anyio = "4.8.0"
duckdb = "1.1.3"
opendal = ">=0.45"
oracledb = "2.5.1"

[tool.poetry.group.dev.dependencies]
pytest = "8.3.4"
@@ -61,6 +63,7 @@ markers = [
"functions: mark a test as a functions test",
"mssql: mark a test as a mssql test",
"mysql: mark a test as a mysql test",
"oracle: mark a test as a oracle test",
"postgres: mark a test as a postgres test",
"snowflake: mark a test as a snowflake test",
"trino: mark a test as a trino test",
134 changes: 134 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_oracle.py
@@ -0,0 +1,134 @@
import base64

import orjson
import pandas as pd
import pytest
import sqlalchemy
from sqlalchemy import text
from testcontainers.oracle import OracleDbContainer

from tests.conftest import file_path

pytestmark = pytest.mark.oracle

base_url = "/v2/connector/oracle"

manifest = {
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "Orders",
"tableReference": {
"schema": "SYSTEM",
"table": "ORDERS",
},
"columns": [
{"name": "orderkey", "expression": "O_ORDERKEY", "type": "number"},
{"name": "custkey", "expression": "O_CUSTKEY", "type": "number"},
{
"name": "orderstatus",
"expression": "O_ORDERSTATUS",
"type": "varchar2",
},
{
"name": "totalprice",
"expression": "O_TOTALPRICE",
"type": "number",
},
{"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"},
{
"name": "order_cust_key",
"expression": "O_ORDERKEY || '_' || O_CUSTKEY",
"type": "varchar2",
},
{
"name": "timestamp",
"expression": "CAST('2024-01-01 23:59:59' AS TIMESTAMP)",
"type": "timestamp",
},
{
"name": "test_null_time",
"expression": "CAST(NULL AS TIMESTAMP)",
"type": "timestamp",
},
{
"name": "blob_column",
"expression": "UTL_RAW.CAST_TO_RAW('abc')",
"type": "blob",
},
],
"primaryKey": "orderkey",
}
],
}


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


@pytest.fixture(scope="module")
def oracle(request) -> OracleDbContainer:
oracle = OracleDbContainer(
"gvenzl/oracle-free:23.6-slim-faststart", oracle_password="Oracle123"
)

oracle.start()

host = oracle.get_container_host_ip()
port = oracle.get_exposed_port(1521)
connection_url = (
f"oracle+oracledb://SYSTEM:Oracle123@{host}:{port}/?service_name=FREEPDB1"
)
engine = sqlalchemy.create_engine(connection_url, echo=True)

with engine.begin() as conn:
pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
"orders", engine, index=False
)
pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql(
"customer", engine, index=False
)
Comment on lines +92 to +97

🛠️ Refactor suggestion

Add error handling for data loading.

The data loading from the parquet files lacks error handling. Consider (a sketch follows the list):

  1. Validating file existence
  2. Adding error handling for file read/write operations
  3. Verifying data integrity after loading
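
One way the suggested error handling could look — a sketch only, not part of the PR; it reuses the test module's existing `pd`, `text`, and `file_path` imports, and the helper name is hypothetical:

```python
from pathlib import Path


def _load_parquet_to_table(path: str, table: str, engine) -> None:
    # 1. Validate file existence before reading.
    parquet = Path(file_path(path))
    if not parquet.is_file():
        raise FileNotFoundError(f"Missing test fixture: {parquet}")
    # 2. Surface read/write failures with context.
    try:
        df = pd.read_parquet(parquet)
        df.to_sql(table, engine, index=False)
    except Exception as exc:
        raise RuntimeError(f"Failed to load {parquet} into {table}") from exc
    # 3. Basic integrity check after loading.
    with engine.connect() as conn:
        count = conn.execute(text(f"SELECT COUNT(*) FROM {table}")).scalar()
    assert count == len(df), f"Row count mismatch for {table}"
```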

# Add table and column comments
conn.execute(text("COMMENT ON TABLE orders IS 'This is a table comment'"))
conn.execute(text("COMMENT ON COLUMN orders.o_comment IS 'This is a comment'"))

request.addfinalizer(oracle.stop)
return oracle


async def test_query_with_connection_url(
client, manifest_str, oracle: OracleDbContainer
):
connection_url = _to_connection_url(oracle)
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": {"connectionUrl": connection_url},
"manifestStr": manifest_str,
"sql": "SELECT * FROM SYSTEM.ORDERS LIMIT 1",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["columns"]) == len(manifest["models"][0]["columns"])
assert len(result["data"]) == 1
assert result["data"][0][0] == 1
assert result["dtypes"] is not None


def _to_connection_info(oracle: OracleDbContainer):
return {
"host": oracle.get_container_host_ip(),
"port": oracle.get_exposed_port(oracle.port),
"user": "SYSTEM",
"password": "Oracle123",
"service": "FREEPDB1",
}


def _to_connection_url(oracle: OracleDbContainer):
info = _to_connection_info(oracle)
return f"oracle://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['service']}"