-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: create a framework for testing functions (#922)
Co-authored-by: Jax Liu <[email protected]>
- Loading branch information
1 parent
b50b357
commit 367b3a1
Showing
6 changed files
with
183 additions
and
154 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,148 +1,37 @@ | ||
function_type,name,return_type,description | ||
scalar,unistr,varchar,"Postgres: Evaluate escaped Unicode characters in the argument" | ||
aggregate,array_agg,array,"Aggregates values into array" | ||
aggregate,avg,numeric,"Returns average of non-null values" | ||
aggregate,bit_and,same as input,"Bitwise AND of non-null values" | ||
aggregate,bit_or,same as input,"Bitwise OR of non-null values" | ||
aggregate,bool_and,boolean,"True if all values are true" | ||
aggregate,bool_or,boolean,"True if any value is true" | ||
aggregate,count,bigint,"Counts number of rows" | ||
aggregate,every,boolean,"Equivalent to bool_and" | ||
aggregate,json_agg,json,"Aggregates values into JSON array" | ||
aggregate,json_object_agg,json,"Aggregates pairs into JSON object" | ||
aggregate,jsonb_agg,jsonb,"Aggregates values into JSONB array" | ||
aggregate,jsonb_object_agg,jsonb,"Aggregates pairs into JSONB object" | ||
aggregate,max,same as input,"Maximum value" | ||
aggregate,min,same as input,"Minimum value" | ||
aggregate,string_agg,text,"Concatenates values with delimiter" | ||
aggregate,sum,numeric,"Sum of values" | ||
aggregate,xmlagg,xml,"Concatenates XML values" | ||
scalar,abs,numeric,"Absolute value" | ||
scalar,acos,double precision,"Arc cosine" | ||
scalar,age,interval,"Difference between timestamps" | ||
scalar,array_append,array,"Append element to array" | ||
scalar,array_cat,array,"Concatenate arrays" | ||
scalar,array_length,int,"Array length for specified dimension" | ||
scalar,array_lower,int,"Lower bound of array" | ||
scalar,array_position,int,"Position of value in array" | ||
scalar,array_prepend,array,"Prepend element to array" | ||
scalar,array_remove,array,"Remove element from array" | ||
scalar,array_replace,array,"Replace element in array" | ||
scalar,array_to_json,json,"Convert array to JSON" | ||
scalar,array_upper,int,"Upper bound of array" | ||
scalar,ascii,int,"ASCII code of first character" | ||
scalar,bit_length,int,"Number of bits in string" | ||
scalar,btrim,text,"Remove characters from both ends" | ||
scalar,cbrt,double precision,"Cube root" | ||
scalar,ceil,numeric,"Round up to nearest integer" | ||
scalar,char_length,int,"Number of characters in string" | ||
scalar,chr,text,"Character from ASCII code" | ||
scalar,coalesce,same as args,"First non-null value" | ||
scalar,concat,text,"Concatenate text strings" | ||
scalar,concat_ws,text,"Concatenate with separator" | ||
scalar,convert,bytea,"Convert string to encoding" | ||
scalar,convert_from,text,"Convert from encoding" | ||
scalar,convert_to,bytea,"Convert to encoding" | ||
scalar,cos,double precision,"Cosine" | ||
scalar,cot,double precision,"Cotangent" | ||
scalar,current_date,date,"Current date" | ||
scalar,current_time,time,"Current time" | ||
scalar,date_part,double precision,"Get subfield from date/time" | ||
scalar,date_trunc,timestamp,"Truncate to specified precision" | ||
scalar,decode,bytea,"Decode binary from text representation" | ||
scalar,degrees,double precision,"Radians to degrees" | ||
scalar,div,numeric,"Integer division" | ||
scalar,encode,text,"Encode binary data to text representation" | ||
scalar,exp,numeric,"Exponential" | ||
scalar,extract,numeric,"Get subfield from date/time" | ||
scalar,floor,numeric,"Round down to nearest integer" | ||
scalar,format,text,"Format string" | ||
scalar,generate_series,setof int,"Generate series of values" | ||
scalar,greatest,same as arg types,"Greatest of arguments" | ||
scalar,host,text,"Extract host from IP address" | ||
scalar,initcap,text,"Capitalize first letter of each word" | ||
scalar,isfinite,boolean,"Test for finite date/timestamp/interval" | ||
scalar,json_array_length,int,"Length of JSON array" | ||
scalar,json_extract_path,json,"Get JSON object at path" | ||
scalar,json_object_keys,setof text,"Get JSON object keys" | ||
scalar,json_to_record,record,"Get JSON object as record" | ||
scalar,jsonb_array_length,int,"Length of JSONB array" | ||
scalar,jsonb_extract_path,jsonb,"Get JSONB object at path" | ||
scalar,jsonb_object_keys,setof text,"Get JSONB object keys" | ||
scalar,jsonb_to_record,record,"Get JSONB object as record" | ||
scalar,least,same as arg types,"Least of arguments" | ||
scalar,left,text,"Extract leftmost characters" | ||
scalar,length,int,"String length" | ||
scalar,ln,numeric,"Natural logarithm" | ||
scalar,log,numeric,"Logarithm" | ||
scalar,lower,text,"Convert to lower case" | ||
scalar,lpad,text,"Pad string on left" | ||
scalar,ltrim,text,"Remove characters from start" | ||
scalar,md5,text,"Calculate MD5 hash" | ||
scalar,mod,numeric,"Modulo (remainder)" | ||
scalar,now,timestamp,"Current transaction timestamp" | ||
scalar,nullif,same as arg1,"Return null if equal" | ||
scalar,octet_length,int,"Number of bytes in string" | ||
scalar,overlay,text,"Replace substring" | ||
scalar,parse_ident,text[],"Parse qualified identifier" | ||
scalar,pg_client_encoding,name,"Current client encoding" | ||
scalar,pg_get_expr,text,"Decompile internal form of expression" | ||
scalar,pg_get_viewdef,text,"Get view definition" | ||
scalar,pg_typeof,regtype,"Get data type of any value" | ||
scalar,pi,double precision,"π constant" | ||
scalar,position,int,"Location of substring" | ||
scalar,power,numeric,"Power" | ||
scalar,quote_ident,text,"Quote identifier" | ||
scalar,quote_literal,text,"Quote literal" | ||
scalar,quote_nullable,text,"Quote nullable" | ||
scalar,radians,double precision,"Degrees to radians" | ||
scalar,random,double precision,"Random value" | ||
scalar,regexp_match,text[],"Match regular expression" | ||
scalar,regexp_matches,setof text[],"Match regular expression with flags" | ||
scalar,regexp_replace,text,"Replace matching text" | ||
scalar,regexp_split_to_array,text[],"Split string by pattern" | ||
scalar,regexp_split_to_table,setof text,"Split string by pattern" | ||
scalar,repeat,text,"Repeat string" | ||
scalar,replace,text,"Replace substring" | ||
scalar,reverse,text,"Reverse string" | ||
scalar,right,text,"Extract rightmost characters" | ||
scalar,round,numeric,"Round to nearest integer or decimal" | ||
scalar,rpad,text,"Pad string on right" | ||
scalar,rtrim,text,"Remove characters from end" | ||
scalar,set_config,text,"Set parameter" | ||
scalar,sha224,bytea,"SHA-224 hash" | ||
scalar,sha256,bytea,"SHA-256 hash" | ||
scalar,sha384,bytea,"SHA-384 hash" | ||
scalar,sha512,bytea,"SHA-512 hash" | ||
scalar,sign,numeric,"Sign of number" | ||
scalar,sin,double precision,"Sine" | ||
scalar,split_part,text,"Split string on delimiter" | ||
scalar,sqrt,numeric,"Square root" | ||
scalar,starts_with,boolean,"String starts with" | ||
scalar,string_to_array,text[],"Split string to array" | ||
scalar,substring,text,"Extract substring" | ||
scalar,tan,double precision,"Tangent" | ||
scalar,timezone,interval,"Timezone offset" | ||
scalar,to_char,text,"Convert to string" | ||
scalar,to_date,date,"Convert string to date" | ||
scalar,to_hex,text,"Convert number to hex" | ||
scalar,to_json,json,"Convert to JSON" | ||
scalar,to_jsonb,jsonb,"Convert to JSONB" | ||
scalar,to_number,numeric,"Convert string to number" | ||
scalar,to_timestamp,timestamp,"Convert string to timestamp" | ||
scalar,translate,text,"Replace characters" | ||
scalar,trim,text,"Remove characters" | ||
scalar,trunc,numeric,"Truncate" | ||
scalar,upper,text,"Convert to upper case" | ||
scalar,uuid_generate_v4,uuid,"Generate UUID v4" | ||
window,cume_dist,double precision,"Cumulative distribution" | ||
window,dense_rank,bigint,"Rank without gaps" | ||
window,first_value,same as input,"First value in window" | ||
window,lag,same as input,"Value from previous row" | ||
window,last_value,same as input,"Last value in window" | ||
window,lead,same as input,"Value from next row" | ||
window,nth_value,same as input,"nth value in window" | ||
window,ntile,integer,"Split rows into buckets" | ||
window,percent_rank,double precision,"Relative rank" | ||
window,rank,bigint,"Rank with gaps" | ||
window,row_number,bigint,"Sequential number" | ||
function_type,name,return_type,param_names,param_types,description | ||
aggregate,every,boolean,,boolean,"Equivalent to bool_and" | ||
aggregate,json_object_agg,json,,"text,any","Aggregates pairs into JSON object" | ||
aggregate,jsonb_object_agg,jsonb,,"text,any","Aggregates pairs into JSONB object" | ||
aggregate,xmlagg,xml,,xml,"Concatenates XML values" | ||
scalar,age,interval,,"timestamp,timestamp","Difference between timestamps" | ||
scalar,array_lower,int,,"array,integer","Lower bound of array" | ||
scalar,array_to_json,json,,array,"Convert array to JSON" | ||
scalar,array_upper,int,,"array,integer","Upper bound of array" | ||
scalar,convert_from,text,,"bytea,text","Convert from encoding" | ||
scalar,convert_to,bytea,,"text,text","Convert to encoding" | ||
scalar,extract,numeric,,"text,timestamp","Get subfield from date/time" | ||
scalar,format,text,,"text,array<any>","Format string" | ||
scalar,greatest,same as arg types,,array<any>,"Greatest of arguments" | ||
scalar,host,text,,inet,"Extract host from IP address" | ||
scalar,isfinite,boolean,,timestamp,"Test for finite date/timestamp/interval" | ||
scalar,json_array_length,int,,json,"Length of JSON array" | ||
scalar,json_extract_path,json,,"json,array<text>","Get JSON object at path" | ||
scalar,json_object_keys,setof text,,json,"Get JSON object keys" | ||
scalar,jsonb_array_length,int,,jsonb,"Length of JSONB array" | ||
scalar,jsonb_extract_path,jsonb,,"jsonb,array<text>","Get JSONB object at path" | ||
scalar,jsonb_object_keys,setof text,,jsonb,"Get JSONB object keys" | ||
scalar,least,same as arg types,,array<any>,"Least of arguments" | ||
scalar,mod,numeric,,"numeric,numeric","Modulo (remainder)" | ||
scalar,parse_ident,array<text>,,"text,boolean","Parse qualified identifier" | ||
scalar,pg_client_encoding,name,,,"Current client encoding" | ||
scalar,pg_get_expr,text,,"pg_node_tree,oid","Decompile internal form of expression" | ||
scalar,pg_get_viewdef,text,,"oid","Get view definition" | ||
scalar,quote_ident,text,,text,"Quote identifier" | ||
scalar,quote_literal,text,,any,"Quote literal" | ||
scalar,quote_nullable,text,,any,"Quote nullable" | ||
scalar,regexp_split_to_array,array<text>,,"text,text","Split string by pattern" | ||
scalar,regexp_split_to_table,setof text,,"text,text","Split string by pattern" | ||
scalar,sign,numeric,,numeric,"Sign of number" | ||
scalar,to_json,json,,boolean,"Convert to JSON" | ||
scalar,to_number,numeric,,"text,text","Convert string to number" | ||
scalar,unistr,varchar,,text,"Postgres: Evaluate escaped Unicode characters in the argument" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import csv | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class Function(BaseModel): | ||
function_type: str | ||
name: str | ||
return_type: str | ||
param_names: list[str] | None | ||
param_types: list[str] | None | ||
description: str | ||
|
||
|
||
class FunctionCsvParser: | ||
def __init__(self, file_path: str): | ||
self.file_path = file_path | ||
|
||
def parse(self) -> list[Function]: | ||
with open(self.file_path, encoding="utf-8") as csvfile: | ||
return [ | ||
Function( | ||
function_type=row["function_type"], | ||
name=row["name"], | ||
return_type=row["return_type"], | ||
param_names=self._split_param(row["param_names"]), | ||
param_types=self._split_param(row["param_types"]), | ||
description=row["description"], | ||
) | ||
for row in csv.DictReader(csvfile) | ||
] | ||
|
||
@staticmethod | ||
def _split_param(param: str) -> list[str]: | ||
return param.split(",") if param else [] | ||
|
||
|
||
class SqlTestGenerator: | ||
def __init__(self): | ||
pass | ||
|
||
def generate_sql(self, function: Function) -> Optional[str]: | ||
if function.function_type == "scalar": | ||
return self.generate_scalar_sql(function) | ||
elif function.function_type == "aggregate": | ||
return self.generate_aggregate_sql(function) | ||
elif function.function_type == "window": | ||
return self.generate_window_sql(function) | ||
else: | ||
raise NotImplementedError( | ||
f"Unsupported function type: {function.function_type}" | ||
) | ||
|
||
def generate_scalar_sql(self, function: Function) -> str: | ||
args = self.generate_sample_args(function.param_types) | ||
if function.name in ("convert_from", "convert_to"): | ||
args[1] = "'UTF8'" | ||
if function.name == "extract": | ||
args[0] = "'day'" | ||
if function.name in ("json_array_length", "jsonb_array_length"): | ||
args[0] = '\'[{"key": "value"}]\'' | ||
if function.name in ("json_extract_path", "jsonb_extract_path"): | ||
args[1] = "'key'" | ||
if function.name == "to_number": | ||
args = ["'123'", "'999'"] | ||
formatted_args = ", ".join(args) | ||
return f"SELECT {function.name}({formatted_args})" | ||
|
||
def generate_aggregate_sql(self, function: Function) -> str: | ||
sample_args = self.generate_sample_args(function.param_types) | ||
formatted_args = ", ".join(sample_args) | ||
sql = f"SELECT {function.name}({formatted_args})" | ||
return sql | ||
|
||
def generate_window_sql(self, function: Function) -> str: | ||
# TODO: Implement window function generation | ||
return "" | ||
|
||
def generate_sample_args(self, param_types: list[str]) -> list[str]: | ||
return [self.map_param_type_to_sample(p_type) for p_type in param_types] | ||
|
||
@staticmethod | ||
def map_param_type_to_sample(p_type: str) -> str: | ||
p_type = p_type.lower() | ||
if p_type in {"int", "integer", "bigint", "smallint"}: | ||
return "1" | ||
elif p_type in {"numeric", "decimal", "double precision", "float", "real"}: | ||
return "1.0" | ||
elif p_type in {"text", "varchar", "char"}: | ||
return "'test'" | ||
elif p_type in {"boolean", "bool"}: | ||
return "TRUE" | ||
elif p_type in {"date"}: | ||
return "cast('2024-01-01' as date)" | ||
elif p_type in { | ||
"timestamp", | ||
"timestamp without time zone", | ||
"timestamp with time zone", | ||
}: | ||
return "cast('2024-01-01 00:00:00' as timestamp)" | ||
elif p_type in {"time", "time without time zone", "time with time zone"}: | ||
return "cast('12:00:00' as time)" | ||
elif p_type == "uuid": | ||
return "'123e4567-e89b-12d3-a456-426614174000'" | ||
elif p_type.startswith("array"): | ||
inner_type = p_type[p_type.find("<") + 1 : p_type.find(">")].strip() | ||
sample_inner = SqlTestGenerator.map_param_type_to_sample(inner_type) | ||
return f"ARRAY[{sample_inner}, {sample_inner}, {sample_inner}]" | ||
elif p_type in {"json", "jsonb"}: | ||
return '\'{"key": "value"}\'' | ||
elif p_type == "bytea": | ||
return "'\\xc3a4'" | ||
elif p_type == "interval": | ||
return "'1 day'" | ||
else: | ||
return "NULL" |