1- import io
21from typing import Optional , Any
32from chdb import _chdb
43
1110 raise ImportError ("Failed to import pyarrow" ) from None
1211
1312
13+ _arrow_format = set ({"dataframe" , "arrowtable" })
14+ _process_result_format_funs = {
15+ "dataframe" : lambda x : to_df (x ),
16+ "arrowtable" : lambda x : to_arrowTable (x ),
17+ }
18+
19+
20+ # return pyarrow table
21+ def to_arrowTable (res ):
22+ """convert res to arrow table"""
23+ # try import pyarrow and pandas, if failed, raise ImportError with suggestion
24+ try :
25+ import pyarrow as pa # noqa
26+ import pandas as pd # noqa
27+ except ImportError as e :
28+ print (f"ImportError: { e } " )
29+ print ('Please install pyarrow and pandas via "pip install pyarrow pandas"' )
30+ raise ImportError ("Failed to import pyarrow or pandas" ) from None
31+ if len (res ) == 0 :
32+ return pa .Table .from_batches ([], schema = pa .schema ([]))
33+ return pa .RecordBatchFileReader (res .bytes ()).read_all ()
34+
35+
36+ # return pandas dataframe
37+ def to_df (r ):
38+ """convert arrow table to Dataframe"""
39+ t = to_arrowTable (r )
40+ return t .to_pandas (use_threads = True )
41+
42+
1443class Connection :
1544 def __init__ (self , connection_string : str ):
1645 # print("Connection", connection_string)
@@ -22,7 +51,13 @@ def cursor(self) -> "Cursor":
2251 return self ._cursor
2352
2453 def query (self , query : str , format : str = "CSV" ) -> Any :
25- return self ._conn .query (query , format )
54+ lower_output_format = format .lower ()
55+ result_func = _process_result_format_funs .get (lower_output_format , lambda x : x )
56+ if lower_output_format in _arrow_format :
57+ format = "Arrow"
58+
59+ result = self ._conn .query (query , format )
60+ return result_func (result )
2661
2762 def close (self ) -> None :
2863 # print("close")
@@ -41,17 +76,103 @@ def __init__(self, connection):
4176 def execute (self , query : str ) -> None :
4277 self ._cursor .execute (query )
4378 result_mv = self ._cursor .get_memview ()
44- # print("get_result", result_mv)
4579 if self ._cursor .has_error ():
4680 raise Exception (self ._cursor .error_message ())
4781 if self ._cursor .data_size () == 0 :
4882 self ._current_table = None
4983 self ._current_row = 0
84+ self ._column_names = []
85+ self ._column_types = []
5086 return
51- arrow_data = result_mv .tobytes ()
52- reader = pa .ipc .open_stream (io .BytesIO (arrow_data ))
53- self ._current_table = reader .read_all ()
54- self ._current_row = 0
87+
88+ # Parse JSON data
89+ json_data = result_mv .tobytes ().decode ("utf-8" )
90+ import json
91+
92+ try :
93+ # First line contains column names
94+ # Second line contains column types
95+ # Following lines contain data
96+ lines = json_data .strip ().split ("\n " )
97+ if len (lines ) < 2 :
98+ self ._current_table = None
99+ self ._current_row = 0
100+ self ._column_names = []
101+ self ._column_types = []
102+ return
103+
104+ self ._column_names = json .loads (lines [0 ])
105+ self ._column_types = json .loads (lines [1 ])
106+
107+ # Convert data rows
108+ rows = []
109+ for line in lines [2 :]:
110+ if not line .strip ():
111+ continue
112+ row_data = json .loads (line )
113+ converted_row = []
114+ for val , type_info in zip (row_data , self ._column_types ):
115+ # Handle NULL values first
116+ if val is None :
117+ converted_row .append (None )
118+ continue
119+
120+ # Basic type conversion
121+ try :
122+ if type_info .startswith ("Int" ) or type_info .startswith ("UInt" ):
123+ converted_row .append (int (val ))
124+ elif type_info .startswith ("Float" ):
125+ converted_row .append (float (val ))
126+ elif type_info == "Bool" :
127+ converted_row .append (bool (val ))
128+ elif type_info == "String" or type_info == "FixedString" :
129+ converted_row .append (str (val ))
130+ elif type_info .startswith ("DateTime" ):
131+ from datetime import datetime
132+
133+ # Check if the value is numeric (timestamp)
134+ val_str = str (val )
135+ if val_str .replace ("." , "" ).isdigit ():
136+ converted_row .append (datetime .fromtimestamp (float (val )))
137+ else :
138+ # Handle datetime string formats
139+ if "." in val_str : # Has microseconds
140+ converted_row .append (
141+ datetime .strptime (
142+ val_str , "%Y-%m-%d %H:%M:%S.%f"
143+ )
144+ )
145+ else : # No microseconds
146+ converted_row .append (
147+ datetime .strptime (val_str , "%Y-%m-%d %H:%M:%S" )
148+ )
149+ elif type_info .startswith ("Date" ):
150+ from datetime import date , datetime
151+
152+ # Check if the value is numeric (days since epoch)
153+ val_str = str (val )
154+ if val_str .isdigit ():
155+ converted_row .append (
156+ date .fromtimestamp (float (val ) * 86400 )
157+ )
158+ else :
159+ # Handle date string format
160+ converted_row .append (
161+ datetime .strptime (val_str , "%Y-%m-%d" ).date ()
162+ )
163+ else :
164+ # For unsupported types, keep as string
165+ converted_row .append (str (val ))
166+ except (ValueError , TypeError ):
167+ # If conversion fails, keep original value as string
168+ converted_row .append (str (val ))
169+ rows .append (tuple (converted_row ))
170+
171+ self ._current_table = rows
172+ self ._current_row = 0
173+
174+ except json .JSONDecodeError as e :
175+ raise Exception (f"Failed to parse JSON data: { e } " )
55176
56177 def commit (self ) -> None :
57178 self ._cursor .commit ()
@@ -60,12 +181,10 @@ def fetchone(self) -> Optional[tuple]:
60181 if not self ._current_table or self ._current_row >= len (self ._current_table ):
61182 return None
62183
63- row_dict = {
64- col : self ._current_table .column (col )[self ._current_row ].as_py ()
65- for col in self ._current_table .column_names
66- }
184+ # Now self._current_table is a list of row tuples
185+ row = self ._current_table [self ._current_row ]
67186 self ._current_row += 1
68- return tuple ( row_dict . values ())
187+ return row
69188
70189 def fetchmany (self , size : int = 1 ) -> tuple :
71190 if not self ._current_table :
@@ -99,6 +218,30 @@ def __next__(self) -> tuple:
99218 raise StopIteration
100219 return row
101220
221+ def column_names (self ) -> list :
222+ """Return a list of column names from the last executed query"""
223+ return self ._column_names if hasattr (self , "_column_names" ) else []
224+
225+ def column_types (self ) -> list :
226+ """Return a list of column types from the last executed query"""
227+ return self ._column_types if hasattr (self , "_column_types" ) else []
228+
229+ @property
230+ def description (self ) -> list :
231+ """
232+ Return a description of the columns as per DB-API 2.0
233+ Returns a list of 7-item tuples, each containing:
234+ (name, type_code, display_size, internal_size, precision, scale, null_ok)
235+ where only name and type_code are provided
236+ """
237+ if not hasattr (self , "_column_names" ) or not self ._column_names :
238+ return []
239+
240+ return [
241+ (name , type_info , None , None , None , None , None )
242+ for name , type_info in zip (self ._column_names , self ._column_types )
243+ ]
244+
102245
103246def connect (connection_string : str = ":memory:" ) -> Connection :
104247 """
0 commit comments