From a336e008e76c3f976ed58156eccb8b131a76b72d Mon Sep 17 00:00:00 2001 From: "ZhengYu, Xu" <zen-xu@outlook.com> Date: Fri, 19 Apr 2024 10:00:05 +0800 Subject: [PATCH 1/2] fix: correct read sql return type annotation --- connectorx-python/connectorx/__init__.py | 11 ++++----- connectorx-python/connectorx/connectorx.pyi | 27 ++++++++++++++------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index a6a5acc6d..456042077 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -1,16 +1,17 @@ from __future__ import annotations - import importlib from importlib.metadata import version -from typing import Any, Literal, TYPE_CHECKING, overload +from typing import Literal, TYPE_CHECKING, overload from .connectorx import ( read_sql as _read_sql, partition_sql as _partition_sql, read_sql2 as _read_sql2, get_meta as _get_meta, + _DataframeInfos, + _ArrowInfos, ) if TYPE_CHECKING: @@ -394,9 +395,7 @@ def read_sql( return df -def reconstruct_arrow( - result: tuple[list[str], list[list[tuple[int, int]]]], -) -> pa.Table: +def reconstruct_arrow(result: _ArrowInfos) -> pa.Table: import pyarrow as pa names, ptrs = result @@ -412,7 +411,7 @@ def reconstruct_arrow( return pa.Table.from_batches(rbs) -def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame: +def reconstruct_pandas(df_infos: _DataframeInfos) -> pd.DataFrame: import pandas as pd data = df_infos["data"] diff --git a/connectorx-python/connectorx/connectorx.pyi b/connectorx-python/connectorx/connectorx.pyi index b556d918b..c1ccbfa53 100644 --- a/connectorx-python/connectorx/connectorx.pyi +++ b/connectorx-python/connectorx/connectorx.pyi @@ -1,11 +1,22 @@ from __future__ import annotations -from typing import overload, Literal, Any, TypeAlias -import pandas as pd +from typing import overload, Literal, Any, TypeAlias, TypedDict +import numpy as np _ArrowArrayPtr: 
TypeAlias = int _ArrowSchemaPtr: TypeAlias = int -_Column: TypeAlias = str +_Header: TypeAlias = str + +class PandasBlockInfo: + cids: list[int] + dt: int + +class _DataframeInfos(TypedDict): + data: list[tuple[np.ndarray, ...] | np.ndarray] + headers: list[_Header] + block_infos: list[PandasBlockInfo] + +_ArrowInfos = tuple[list[_Header], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]] @overload def read_sql( @@ -14,7 +25,7 @@ def read_sql( protocol: str | None, queries: list[str] | None, partition_query: dict[str, Any] | None, -) -> pd.DataFrame: ... +) -> _DataframeInfos: ... @overload def read_sql( conn: str, @@ -22,13 +33,11 @@ def read_sql( protocol: str | None, queries: list[str] | None, partition_query: dict[str, Any] | None, -) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ... +) -> _ArrowInfos: ... def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ... -def read_sql2( - sql: str, db_map: dict[str, str] -) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ... +def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ... def get_meta( conn: str, protocol: Literal["csv", "binary", "cursor", "simple", "text"] | None, query: str, -) -> dict[str, Any]: ... +) -> _DataframeInfos: ... 
From 7158e805ab7a6c0e045103baf638b4749e8d0419 Mon Sep 17 00:00:00 2001 From: "ZhengYu, Xu" <zen-xu@outlook.com> Date: Fri, 19 Apr 2024 15:14:03 +0800 Subject: [PATCH 2/2] fix: put `_DataframeInfos` and `_ArrowInfos` under TYPE_CHECKING block --- connectorx-python/connectorx/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index 456042077..441c67394 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -10,8 +10,6 @@ partition_sql as _partition_sql, read_sql2 as _read_sql2, get_meta as _get_meta, - _DataframeInfos, - _ArrowInfos, ) if TYPE_CHECKING: @@ -21,6 +19,10 @@ import dask.dataframe as dd import pyarrow as pa + # only for type hints + from .connectorx import _DataframeInfos, _ArrowInfos + + __version__ = version(__name__) import os