From a336e008e76c3f976ed58156eccb8b131a76b72d Mon Sep 17 00:00:00 2001
From: "ZhengYu, Xu" <zen-xu@outlook.com>
Date: Fri, 19 Apr 2024 10:00:05 +0800
Subject: [PATCH 1/2] fix: correct read sql return type annotation

---
 connectorx-python/connectorx/__init__.py    | 11 ++++-----
 connectorx-python/connectorx/connectorx.pyi | 27 ++++++++++++++-------
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index a6a5acc6d..456042077 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -1,16 +1,17 @@
 from __future__ import annotations
 
-
 import importlib
 from importlib.metadata import version
 
-from typing import Any, Literal, TYPE_CHECKING, overload
+from typing import Literal, TYPE_CHECKING, overload
 
 from .connectorx import (
     read_sql as _read_sql,
     partition_sql as _partition_sql,
     read_sql2 as _read_sql2,
     get_meta as _get_meta,
+    _DataframeInfos,
+    _ArrowInfos,
 )
 
 if TYPE_CHECKING:
@@ -394,9 +395,7 @@ def read_sql(
     return df
 
 
-def reconstruct_arrow(
-    result: tuple[list[str], list[list[tuple[int, int]]]],
-) -> pa.Table:
+def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa
 
     names, ptrs = result
@@ -412,7 +411,7 @@ def reconstruct_arrow(
     return pa.Table.from_batches(rbs)
 
 
-def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame:
+def reconstruct_pandas(df_infos: _DataframeInfos) -> pd.DataFrame:
     import pandas as pd
 
     data = df_infos["data"]
diff --git a/connectorx-python/connectorx/connectorx.pyi b/connectorx-python/connectorx/connectorx.pyi
index b556d918b..c1ccbfa53 100644
--- a/connectorx-python/connectorx/connectorx.pyi
+++ b/connectorx-python/connectorx/connectorx.pyi
@@ -1,11 +1,22 @@
 from __future__ import annotations
 
-from typing import overload, Literal, Any, TypeAlias
-import pandas as pd
+from typing import overload, Literal, Any, TypeAlias, TypedDict
+import numpy as np
 
 _ArrowArrayPtr: TypeAlias = int
 _ArrowSchemaPtr: TypeAlias = int
-_Column: TypeAlias = str
+_Header: TypeAlias = str
+
+class PandasBlockInfo:
+    cids: list[int]
+    dt: int
+
+class _DataframeInfos(TypedDict):
+    data: list[tuple[np.ndarray, ...] | np.ndarray]
+    headers: list[_Header]
+    block_infos: list[PandasBlockInfo]
+
+_ArrowInfos = tuple[list[_Header], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]
 
 @overload
 def read_sql(
@@ -14,7 +25,7 @@ def read_sql(
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
-) -> pd.DataFrame: ...
+) -> _DataframeInfos: ...
 @overload
 def read_sql(
     conn: str,
@@ -22,13 +33,11 @@ def read_sql(
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
-) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ...
+) -> _ArrowInfos: ...
 def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
-def read_sql2(
-    sql: str, db_map: dict[str, str]
-) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ...
+def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
 def get_meta(
     conn: str,
     protocol: Literal["csv", "binary", "cursor", "simple", "text"] | None,
     query: str,
-) -> dict[str, Any]: ...
+) -> _DataframeInfos: ...

From 7158e805ab7a6c0e045103baf638b4749e8d0419 Mon Sep 17 00:00:00 2001
From: "ZhengYu, Xu" <zen-xu@outlook.com>
Date: Fri, 19 Apr 2024 15:14:03 +0800
Subject: [PATCH 2/2] fix: put `_DataframeInfos` and `_ArrowInfos` under
 TYPE_CHECKING block

---
 connectorx-python/connectorx/__init__.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index 456042077..441c67394 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -10,8 +10,6 @@
     partition_sql as _partition_sql,
     read_sql2 as _read_sql2,
     get_meta as _get_meta,
-    _DataframeInfos,
-    _ArrowInfos,
 )
 
 if TYPE_CHECKING:
@@ -21,6 +19,10 @@
     import dask.dataframe as dd
     import pyarrow as pa
 
+    # only for typing hints
+    from .connectorx import  _DataframeInfos, _ArrowInfos
+
+
 __version__ = version(__name__)
 
 import os