feat(ibis): introduce S3 File connector #1038

Merged 5 commits on Jan 17, 2025
6 changes: 5 additions & 1 deletion .github/workflows/ibis-ci.yml
@@ -67,7 +67,11 @@ jobs:
       - name: Run tests
         env:
           WREN_ENGINE_ENDPOINT: http://localhost:8080
-        run: poetry run pytest -m "not bigquery and not snowflake and not canner"
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          AWS_REGION: ${{ secrets.AWS_REGION }}
+          AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }}
+        run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file"
       - name: Test bigquery if need
         if: contains(github.event.pull_request.labels.*.name, 'bigquery')
         env:
2 changes: 1 addition & 1 deletion ibis-server/app/mdl/rewriter.py
@@ -72,7 +72,7 @@ def _get_read_dialect(cls, experiment) -> str | None:
     def _get_write_dialect(cls, data_source: DataSource) -> str:
         if data_source == DataSource.canner:
             return "trino"
-        elif data_source == DataSource.local_file:
+        elif data_source in {DataSource.local_file, DataSource.s3_file}:
            return "duckdb"
         return data_source.name
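With this change, SQL generated for either file-based source is compiled to DuckDB's dialect. A minimal, self-contained sketch of the routing, with the enum trimmed to the members relevant here (the real enum lives in app/model/data_source.py):

    from enum import StrEnum, auto

    class DataSource(StrEnum):
        canner = auto()
        postgres = auto()
        local_file = auto()
        s3_file = auto()

    def write_dialect(data_source: DataSource) -> str:
        # canner is fronted by trino; both file connectors are served by duckdb
        if data_source == DataSource.canner:
            return "trino"
        elif data_source in {DataSource.local_file, DataSource.s3_file}:
            return "duckdb"
        return data_source.name  # e.g. "postgres"

    assert write_dialect(DataSource.s3_file) == "duckdb"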

16 changes: 16 additions & 0 deletions ibis-server/app/model/__init__.py
@@ -56,6 +56,10 @@ class QueryLocalFileDTO(QueryDTO):
     connection_info: LocalFileConnectionInfo = connection_info_field


+class QueryS3FileDTO(QueryDTO):
+    connection_info: S3FileConnectionInfo = connection_info_field
+
+
 class BigQueryConnectionInfo(BaseModel):
     project_id: SecretStr
     dataset_id: SecretStr
@@ -147,6 +151,17 @@ class LocalFileConnectionInfo(BaseModel):
     )


+class S3FileConnectionInfo(BaseModel):
+    url: SecretStr = Field(description="the root path of the s3 bucket", default="/")
+    format: str = Field(
+        description="File format", default="csv", examples=["csv", "parquet", "json"]
+    )
+    bucket: SecretStr
+    region: SecretStr
+    access_key: SecretStr
+    secret_key: SecretStr
+
+
 ConnectionInfo = (
     BigQueryConnectionInfo
     | CannerConnectionInfo
@@ -157,6 +172,7 @@ class LocalFileConnectionInfo(BaseModel):
     | SnowflakeConnectionInfo
     | TrinoConnectionInfo
     | LocalFileConnectionInfo
+    | S3FileConnectionInfo
 )
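For reference, a hypothetical connection_info payload for the new source (pydantic coerces the plain strings into SecretStr; all values below are placeholders):

    from app.model import S3FileConnectionInfo

    info = S3FileConnectionInfo(
        url="/data",               # root path inside the bucket; defaults to "/"
        format="parquet",          # defaults to "csv"
        bucket="my-bucket",
        region="us-east-1",
        access_key="AKIAXXXXXXXX",
        secret_key="xxxxxxxxxxxx",
    )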


25 changes: 21 additions & 4 deletions ibis-server/app/model/connector.py
@@ -10,13 +10,20 @@
 import ibis.formats
 import pandas as pd
 import sqlglot.expressions as sge
+from duckdb import HTTPException
 from google.cloud import bigquery
 from google.oauth2 import service_account
 from ibis import BaseBackend
 from ibis.backends.sql.compilers.postgres import compiler as postgres_compiler

-from app.model import ConnectionInfo, UnknownIbisError, UnprocessableEntityError
+from app.model import (
+    ConnectionInfo,
+    S3FileConnectionInfo,
+    UnknownIbisError,
+    UnprocessableEntityError,
+)
 from app.model.data_source import DataSource
+from app.model.utils import init_duckdb_s3

 # Override datatypes of ibis
 importlib.import_module("app.custom_ibis.backends.sql.datatypes")
@@ -32,6 +39,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
             self._connector = BigQueryConnector(connection_info)
         elif data_source == DataSource.local_file:
             self._connector = DuckDBConnector(connection_info)
+        elif data_source == DataSource.s3_file:
+            self._connector = DuckDBConnector(connection_info)
         else:
             self._connector = SimpleConnector(data_source, connection_info)
@@ -147,16 +156,24 @@ def query(self, sql: str, limit: int) -> pd.DataFrame:


 class DuckDBConnector:
-    def __init__(self, _connection_info: ConnectionInfo):
+    def __init__(self, connection_info: ConnectionInfo):
         import duckdb

         self.connection = duckdb.connect()
+        if isinstance(connection_info, S3FileConnectionInfo):
+            init_duckdb_s3(self.connection, connection_info)

     def query(self, sql: str, limit: int) -> pd.DataFrame:
-        return self.connection.execute(sql).fetch_df().head(limit)
+        try:
+            return self.connection.execute(sql).fetch_df().head(limit)
+        except HTTPException as e:
+            raise UnprocessableEntityError(f"Failed to execute query: {e!s}")

     def dry_run(self, sql: str) -> None:
-        self.connection.execute(sql)
+        try:
+            self.connection.execute(sql)
+        except HTTPException as e:
+            raise QueryDryRunError(f"Failed to execute query: {e!s}")
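A hedged end-to-end sketch of how the pieces above compose; the credentials and bucket are placeholders, the rest follows this diff:

    from app.model import S3FileConnectionInfo
    from app.model.connector import DuckDBConnector

    info = S3FileConnectionInfo(
        url="/data", format="csv", bucket="my-bucket", region="us-east-1",
        access_key="AKIAXXXXXXXX", secret_key="xxxxxxxxxxxx",
    )
    # __init__ sees an S3FileConnectionInfo, so it registers the wren_s3
    # secret on the fresh DuckDB connection via init_duckdb_s3
    connector = DuckDBConnector(info)
    df = connector.query("SELECT * FROM 's3://my-bucket/data/orders.csv'", limit=10)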


@cache
3 changes: 3 additions & 0 deletions ibis-server/app/model/data_source.py
@@ -26,6 +26,7 @@
     QueryMSSqlDTO,
     QueryMySqlDTO,
     QueryPostgresDTO,
+    QueryS3FileDTO,
     QuerySnowflakeDTO,
     QueryTrinoDTO,
     SnowflakeConnectionInfo,
@@ -44,6 +45,7 @@ class DataSource(StrEnum):
     snowflake = auto()
     trino = auto()
     local_file = auto()
+    s3_file = auto()

     def get_connection(self, info: ConnectionInfo) -> BaseBackend:
         try:
@@ -68,6 +70,7 @@ class DataSourceExtension(Enum):
     snowflake = QuerySnowflakeDTO
     trino = QueryTrinoDTO
     local_file = QueryLocalFileDTO
+    s3_file = QueryS3FileDTO

     def __init__(self, dto: QueryDTO):
         self.dto = dto
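Because the enum member names double as the wire values, a request naming "s3_file" as its data source resolves to the new DTO by plain member lookup, e.g. (sketch; the actual dispatch lives elsewhere in the server):

    dto_class = DataSourceExtension["s3_file"].dto  # QueryS3FileDTO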
3 changes: 2 additions & 1 deletion ibis-server/app/model/metadata/factory.py
@@ -5,7 +5,7 @@
 from app.model.metadata.metadata import Metadata
 from app.model.metadata.mssql import MSSQLMetadata
 from app.model.metadata.mysql import MySQLMetadata
-from app.model.metadata.object_storage import LocalFileMetadata
+from app.model.metadata.object_storage import LocalFileMetadata, S3FileMetadata
 from app.model.metadata.postgres import PostgresMetadata
 from app.model.metadata.snowflake import SnowflakeMetadata
 from app.model.metadata.trino import TrinoMetadata
@@ -20,6 +20,7 @@
     DataSource.trino: TrinoMetadata,
     DataSource.snowflake: SnowflakeMetadata,
     DataSource.local_file: LocalFileMetadata,
+    DataSource.s3_file: S3FileMetadata,
 }


139 changes: 93 additions & 46 deletions ibis-server/app/model/metadata/object_storage.py
@@ -4,70 +4,80 @@
 import opendal
 from loguru import logger

-from app.model import LocalFileConnectionInfo
+from app.model import (
+    LocalFileConnectionInfo,
+    S3FileConnectionInfo,
+    UnprocessableEntityError,
+)
 from app.model.metadata.dto import (
     Column,
     RustWrenEngineColumnType,
     Table,
     TableProperties,
 )
 from app.model.metadata.metadata import Metadata
+from app.model.utils import init_duckdb_s3


 class ObjectStorageMetadata(Metadata):
     def __init__(self, connection_info):
         super().__init__(connection_info)

     def get_table_list(self) -> list[Table]:
-        op = opendal.Operator("fs", root=self.connection_info.url.get_secret_value())
+        op = self._get_dal_operator()
         conn = self._get_connection()
         unique_tables = {}
-        for file in op.list("/"):
-            if file.path != "/":
-                stat = op.stat(file.path)
-                if stat.mode.is_dir():
-                    # if the file is a directory, use the directory name as the table name
-                    table_name = os.path.basename(os.path.normpath(file.path))
-                    full_path = f"{self.connection_info.url.get_secret_value()}/{table_name}/*.{self.connection_info.format}"
-                else:
-                    # if the file is a file, use the file name as the table name
-                    table_name = os.path.splitext(os.path.basename(file.path))[0]
-                    full_path = (
-                        f"{self.connection_info.url.get_secret_value()}/{file.path}"
-                    )
-
-                # read the file with the target format if unreadable, skip the file
-                df = self._read_df(conn, full_path)
-                if df is None:
-                    continue
-                columns = []
-                try:
-                    for col in df.columns:
-                        duckdb_type = df[col].dtypes[0]
-                        columns.append(
-                            Column(
-                                name=col,
-                                type=self._to_column_type(duckdb_type.__str__()),
-                                notNull=False,
-                            )
-                        )
-                except Exception as e:
-                    logger.debug(f"Failed to read column types: {e}")
-                    continue
-
-                unique_tables[table_name] = Table(
-                    name=table_name,
-                    description=None,
-                    columns=[],
-                    properties=TableProperties(
-                        table=table_name,
-                        schema=None,
-                        catalog=None,
-                        path=full_path,
-                    ),
-                    primaryKey=None,
-                )
-                unique_tables[table_name].columns = columns
+        try:
+            for file in op.list("/"):
+                if file.path != "/":
+                    stat = op.stat(file.path)
+                    if stat.mode.is_dir():
+                        # if the file is a directory, use the directory name as the table name
+                        table_name = os.path.basename(os.path.normpath(file.path))
+                        full_path = f"{self.connection_info.url.get_secret_value()}/{table_name}/*.{self.connection_info.format}"
+                    else:
+                        # if the file is a file, use the file name as the table name
+                        table_name = os.path.splitext(os.path.basename(file.path))[0]
+                        full_path = (
+                            f"{self.connection_info.url.get_secret_value()}/{file.path}"
+                        )
+
+                    # add required prefix for object storage
+                    full_path = self._get_full_path(full_path)
+                    # read the file with the target format if unreadable, skip the file
+                    df = self._read_df(conn, full_path)
+                    if df is None:
+                        continue
+                    columns = []
+                    try:
+                        for col in df.columns:
+                            duckdb_type = df[col].dtypes[0]
+                            columns.append(
+                                Column(
+                                    name=col,
+                                    type=self._to_column_type(duckdb_type.__str__()),
+                                    notNull=False,
+                                )
+                            )
+                    except Exception as e:
+                        logger.debug(f"Failed to read column types: {e}")
+                        continue
+
+                    unique_tables[table_name] = Table(
+                        name=table_name,
+                        description=None,
+                        columns=[],
+                        properties=TableProperties(
+                            table=table_name,
+                            schema=None,
+                            catalog=None,
+                            path=full_path,
+                        ),
+                        primaryKey=None,
+                    )
+                    unique_tables[table_name].columns = columns
+        except Exception as e:
+            raise UnprocessableEntityError(f"Failed to list files: {e!s}")

         return list(unique_tables.values())
@@ -147,10 +157,47 @@ def _to_column_type(self, col_type: str) -> RustWrenEngineColumnType:
     def _get_connection(self):
         return duckdb.connect()

+    def _get_dal_operator(self):
+        return opendal.Operator("fs", root=self.connection_info.url.get_secret_value())
+
+    def _get_full_path(self, path):
+        return path
+

 class LocalFileMetadata(ObjectStorageMetadata):
     def __init__(self, connection_info: LocalFileConnectionInfo):
         super().__init__(connection_info)

     def get_version(self):
         return "Local File System"
+
+
+class S3FileMetadata(ObjectStorageMetadata):
+    def __init__(self, connection_info: S3FileConnectionInfo):
+        super().__init__(connection_info)
+
+    def get_version(self):
+        return "S3"
+
+    def _get_connection(self):
+        conn = duckdb.connect()
+        init_duckdb_s3(conn, self.connection_info)
+        logger.debug("Initialized duckdb s3")
+        return conn
+
+    def _get_dal_operator(self):
+        info: S3FileConnectionInfo = self.connection_info
+        return opendal.Operator(
+            "s3",
+            root=info.url.get_secret_value(),
+            bucket=info.bucket.get_secret_value(),
+            region=info.region.get_secret_value(),
+            secret_access_key=info.secret_key.get_secret_value(),
+            access_key_id=info.access_key.get_secret_value(),
+        )
+
+    def _get_full_path(self, path):
+        if path.startswith("/"):
+            path = path[1:]
+
+        return f"s3://{self.connection_info.bucket.get_secret_value()}/{path}"
22 changes: 22 additions & 0 deletions ibis-server/app/model/utils.py
@@ -0,0 +1,22 @@
+from duckdb import DuckDBPyConnection, HTTPException
+
+from app.model import S3FileConnectionInfo
+
+
+def init_duckdb_s3(
+    connection: DuckDBPyConnection, connection_info: S3FileConnectionInfo
+):
+    create_secret = f"""
+    CREATE SECRET wren_s3 (
+        TYPE S3,
+        KEY_ID '{connection_info.access_key.get_secret_value()}',
+        SECRET '{connection_info.secret_key.get_secret_value()}',
+        REGION '{connection_info.region.get_secret_value()}'
+    )
+    """
+    try:
+        result = connection.execute(create_secret).fetchone()
+        if result is None or not result[0]:
+            raise Exception("Failed to create secret")
+    except HTTPException as e:
+        raise Exception("Failed to create secret", e)
1 change: 1 addition & 0 deletions ibis-server/pyproject.toml
@@ -64,6 +64,7 @@ markers = [
     "snowflake: mark a test as a snowflake test",
     "trino: mark a test as a trino test",
     "local_file: mark a test as a local file test",
+    "s3_file: mark a test as a s3 file test",
     "beta: mark a test as a test for beta versions of the engine",
 ]
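With the marker registered, the S3 tests can be selected or excluded through pytest's -m flag (for example, poetry run pytest -m s3_file), which is how the CI change above keeps them out of the default run.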

4 changes: 2 additions & 2 deletions ibis-server/tests/routers/v2/connector/test_local_file.py
@@ -243,8 +243,8 @@ async def test_unsupported_format(client):
         },
     )
-    assert response.status_code == 501
-    assert response.text == "Unsupported format: unsupported"
+    assert response.status_code == 422
+    assert response.text == "Failed to list files: Unsupported format: unsupported"
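The new expectations follow from the metadata change above: get_table_list now wraps listing failures in UnprocessableEntityError, so an unsupported format surfaces as HTTP 422 with the "Failed to list files:" prefix instead of a 501.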


async def test_list_parquet_files(client):