Skip to content

Commit

Permalink
test: create a framework for testing functions (#922)
Browse files Browse the repository at this point in the history
Co-authored-by: Jax Liu <[email protected]>
  • Loading branch information
grieve54706 and goldmedal authored Nov 20, 2024
1 parent b50b357 commit 367b3a1
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 154 deletions.
3 changes: 3 additions & 0 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import decimal

import orjson
Expand Down Expand Up @@ -34,6 +35,8 @@ def default(obj):
return obj.hex()
if isinstance(obj, pd.tseries.offsets.DateOffset):
return _date_offset_to_str(obj)
if isinstance(obj, datetime.timedelta):
return str(obj)
raise TypeError

json_obj = orjson.loads(
Expand Down
2 changes: 1 addition & 1 deletion ibis-server/justfile
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dev:

# run the pytest tests for the given marker
test MARKER:
poetry run pytest -m {{ MARKER }}
poetry run pytest -m '{{ MARKER }}'

docker-build:
# alias for `docker-build`
Expand Down
1 change: 1 addition & 0 deletions ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ markers = [
"bigquery: mark a test as a bigquery test",
"canner: mark a test as a canner test",
"clickhouse: mark a test as a clickhouse test",
"functions: mark a test as a functions test",
"mssql: mark a test as a mssql test",
"mysql: mark a test as a mysql test",
"postgres: mark a test as a postgres test",
Expand Down
185 changes: 37 additions & 148 deletions ibis-server/resources/function_list/postgres.csv
Original file line number Diff line number Diff line change
@@ -1,148 +1,37 @@
function_type,name,return_type,description
scalar,unistr,varchar,"Postgres: Evaluate escaped Unicode characters in the argument"
aggregate,array_agg,array,"Aggregates values into array"
aggregate,avg,numeric,"Returns average of non-null values"
aggregate,bit_and,same as input,"Bitwise AND of non-null values"
aggregate,bit_or,same as input,"Bitwise OR of non-null values"
aggregate,bool_and,boolean,"True if all values are true"
aggregate,bool_or,boolean,"True if any value is true"
aggregate,count,bigint,"Counts number of rows"
aggregate,every,boolean,"Equivalent to bool_and"
aggregate,json_agg,json,"Aggregates values into JSON array"
aggregate,json_object_agg,json,"Aggregates pairs into JSON object"
aggregate,jsonb_agg,jsonb,"Aggregates values into JSONB array"
aggregate,jsonb_object_agg,jsonb,"Aggregates pairs into JSONB object"
aggregate,max,same as input,"Maximum value"
aggregate,min,same as input,"Minimum value"
aggregate,string_agg,text,"Concatenates values with delimiter"
aggregate,sum,numeric,"Sum of values"
aggregate,xmlagg,xml,"Concatenates XML values"
scalar,abs,numeric,"Absolute value"
scalar,acos,double precision,"Arc cosine"
scalar,age,interval,"Difference between timestamps"
scalar,array_append,array,"Append element to array"
scalar,array_cat,array,"Concatenate arrays"
scalar,array_length,int,"Array length for specified dimension"
scalar,array_lower,int,"Lower bound of array"
scalar,array_position,int,"Position of value in array"
scalar,array_prepend,array,"Prepend element to array"
scalar,array_remove,array,"Remove element from array"
scalar,array_replace,array,"Replace element in array"
scalar,array_to_json,json,"Convert array to JSON"
scalar,array_upper,int,"Upper bound of array"
scalar,ascii,int,"ASCII code of first character"
scalar,bit_length,int,"Number of bits in string"
scalar,btrim,text,"Remove characters from both ends"
scalar,cbrt,double precision,"Cube root"
scalar,ceil,numeric,"Round up to nearest integer"
scalar,char_length,int,"Number of characters in string"
scalar,chr,text,"Character from ASCII code"
scalar,coalesce,same as args,"First non-null value"
scalar,concat,text,"Concatenate text strings"
scalar,concat_ws,text,"Concatenate with separator"
scalar,convert,bytea,"Convert string to encoding"
scalar,convert_from,text,"Convert from encoding"
scalar,convert_to,bytea,"Convert to encoding"
scalar,cos,double precision,"Cosine"
scalar,cot,double precision,"Cotangent"
scalar,current_date,date,"Current date"
scalar,current_time,time,"Current time"
scalar,date_part,double precision,"Get subfield from date/time"
scalar,date_trunc,timestamp,"Truncate to specified precision"
scalar,decode,bytea,"Decode binary from text representation"
scalar,degrees,double precision,"Radians to degrees"
scalar,div,numeric,"Integer division"
scalar,encode,text,"Encode binary data to text representation"
scalar,exp,numeric,"Exponential"
scalar,extract,numeric,"Get subfield from date/time"
scalar,floor,numeric,"Round down to nearest integer"
scalar,format,text,"Format string"
scalar,generate_series,setof int,"Generate series of values"
scalar,greatest,same as arg types,"Greatest of arguments"
scalar,host,text,"Extract host from IP address"
scalar,initcap,text,"Capitalize first letter of each word"
scalar,isfinite,boolean,"Test for finite date/timestamp/interval"
scalar,json_array_length,int,"Length of JSON array"
scalar,json_extract_path,json,"Get JSON object at path"
scalar,json_object_keys,setof text,"Get JSON object keys"
scalar,json_to_record,record,"Get JSON object as record"
scalar,jsonb_array_length,int,"Length of JSONB array"
scalar,jsonb_extract_path,jsonb,"Get JSONB object at path"
scalar,jsonb_object_keys,setof text,"Get JSONB object keys"
scalar,jsonb_to_record,record,"Get JSONB object as record"
scalar,least,same as arg types,"Least of arguments"
scalar,left,text,"Extract leftmost characters"
scalar,length,int,"String length"
scalar,ln,numeric,"Natural logarithm"
scalar,log,numeric,"Logarithm"
scalar,lower,text,"Convert to lower case"
scalar,lpad,text,"Pad string on left"
scalar,ltrim,text,"Remove characters from start"
scalar,md5,text,"Calculate MD5 hash"
scalar,mod,numeric,"Modulo (remainder)"
scalar,now,timestamp,"Current transaction timestamp"
scalar,nullif,same as arg1,"Return null if equal"
scalar,octet_length,int,"Number of bytes in string"
scalar,overlay,text,"Replace substring"
scalar,parse_ident,text[],"Parse qualified identifier"
scalar,pg_client_encoding,name,"Current client encoding"
scalar,pg_get_expr,text,"Decompile internal form of expression"
scalar,pg_get_viewdef,text,"Get view definition"
scalar,pg_typeof,regtype,"Get data type of any value"
scalar,pi,double precision,"π constant"
scalar,position,int,"Location of substring"
scalar,power,numeric,"Power"
scalar,quote_ident,text,"Quote identifier"
scalar,quote_literal,text,"Quote literal"
scalar,quote_nullable,text,"Quote nullable"
scalar,radians,double precision,"Degrees to radians"
scalar,random,double precision,"Random value"
scalar,regexp_match,text[],"Match regular expression"
scalar,regexp_matches,setof text[],"Match regular expression with flags"
scalar,regexp_replace,text,"Replace matching text"
scalar,regexp_split_to_array,text[],"Split string by pattern"
scalar,regexp_split_to_table,setof text,"Split string by pattern"
scalar,repeat,text,"Repeat string"
scalar,replace,text,"Replace substring"
scalar,reverse,text,"Reverse string"
scalar,right,text,"Extract rightmost characters"
scalar,round,numeric,"Round to nearest integer or decimal"
scalar,rpad,text,"Pad string on right"
scalar,rtrim,text,"Remove characters from end"
scalar,set_config,text,"Set parameter"
scalar,sha224,bytea,"SHA-224 hash"
scalar,sha256,bytea,"SHA-256 hash"
scalar,sha384,bytea,"SHA-384 hash"
scalar,sha512,bytea,"SHA-512 hash"
scalar,sign,numeric,"Sign of number"
scalar,sin,double precision,"Sine"
scalar,split_part,text,"Split string on delimiter"
scalar,sqrt,numeric,"Square root"
scalar,starts_with,boolean,"String starts with"
scalar,string_to_array,text[],"Split string to array"
scalar,substring,text,"Extract substring"
scalar,tan,double precision,"Tangent"
scalar,timezone,interval,"Timezone offset"
scalar,to_char,text,"Convert to string"
scalar,to_date,date,"Convert string to date"
scalar,to_hex,text,"Convert number to hex"
scalar,to_json,json,"Convert to JSON"
scalar,to_jsonb,jsonb,"Convert to JSONB"
scalar,to_number,numeric,"Convert string to number"
scalar,to_timestamp,timestamp,"Convert string to timestamp"
scalar,translate,text,"Replace characters"
scalar,trim,text,"Remove characters"
scalar,trunc,numeric,"Truncate"
scalar,upper,text,"Convert to upper case"
scalar,uuid_generate_v4,uuid,"Generate UUID v4"
window,cume_dist,double precision,"Cumulative distribution"
window,dense_rank,bigint,"Rank without gaps"
window,first_value,same as input,"First value in window"
window,lag,same as input,"Value from previous row"
window,last_value,same as input,"Last value in window"
window,lead,same as input,"Value from next row"
window,nth_value,same as input,"nth value in window"
window,ntile,integer,"Split rows into buckets"
window,percent_rank,double precision,"Relative rank"
window,rank,bigint,"Rank with gaps"
window,row_number,bigint,"Sequential number"
function_type,name,return_type,param_names,param_types,description
aggregate,every,boolean,,boolean,"Equivalent to bool_and"
aggregate,json_object_agg,json,,"text,any","Aggregates pairs into JSON object"
aggregate,jsonb_object_agg,jsonb,,"text,any","Aggregates pairs into JSONB object"
aggregate,xmlagg,xml,,xml,"Concatenates XML values"
scalar,age,interval,,"timestamp,timestamp","Difference between timestamps"
scalar,array_lower,int,,"array,integer","Lower bound of array"
scalar,array_to_json,json,,array,"Convert array to JSON"
scalar,array_upper,int,,"array,integer","Upper bound of array"
scalar,convert_from,text,,"bytea,text","Convert from encoding"
scalar,convert_to,bytea,,"text,text","Convert to encoding"
scalar,extract,numeric,,"text,timestamp","Get subfield from date/time"
scalar,format,text,,"text,array<any>","Format string"
scalar,greatest,same as arg types,,array<any>,"Greatest of arguments"
scalar,host,text,,inet,"Extract host from IP address"
scalar,isfinite,boolean,,timestamp,"Test for finite date/timestamp/interval"
scalar,json_array_length,int,,json,"Length of JSON array"
scalar,json_extract_path,json,,"json,array<text>","Get JSON object at path"
scalar,json_object_keys,setof text,,json,"Get JSON object keys"
scalar,jsonb_array_length,int,,jsonb,"Length of JSONB array"
scalar,jsonb_extract_path,jsonb,,"jsonb,array<text>","Get JSONB object at path"
scalar,jsonb_object_keys,setof text,,jsonb,"Get JSONB object keys"
scalar,least,same as arg types,,array<any>,"Least of arguments"
scalar,mod,numeric,,"numeric,numeric","Modulo (remainder)"
scalar,parse_ident,array<text>,,"text,boolean","Parse qualified identifier"
scalar,pg_client_encoding,name,,,"Current client encoding"
scalar,pg_get_expr,text,,"pg_node_tree,oid","Decompile internal form of expression"
scalar,pg_get_viewdef,text,,"oid","Get view definition"
scalar,quote_ident,text,,text,"Quote identifier"
scalar,quote_literal,text,,any,"Quote literal"
scalar,quote_nullable,text,,any,"Quote nullable"
scalar,regexp_split_to_array,array<text>,,"text,text","Split string by pattern"
scalar,regexp_split_to_table,setof text,,"text,text","Split string by pattern"
scalar,sign,numeric,,numeric,"Sign of number"
scalar,to_json,json,,boolean,"Convert to JSON"
scalar,to_number,numeric,,"text,text","Convert string to number"
scalar,unistr,varchar,,text,"Postgres: Evaluate escaped Unicode characters in the argument"
29 changes: 24 additions & 5 deletions ibis-server/tests/routers/v3/connector/postgres/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import os

import orjson
import pytest
Expand All @@ -8,6 +9,9 @@
from app.main import app
from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path
from tests.routers.v3.connector.postgres.conftest import base_url
from tests.util import FunctionCsvParser, SqlTestGenerator

pytestmark = pytest.mark.functions

manifest = {
"catalog": "my_catalog",
Expand Down Expand Up @@ -57,14 +61,14 @@ def test_function_list():
response = client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == DATAFUSION_FUNCTION_COUNT + 49
the_func = next(filter(lambda x: x["name"] == "abs", result))
assert len(result) == DATAFUSION_FUNCTION_COUNT + 36
the_func = next(filter(lambda x: x["name"] == "extract", result))
assert the_func == {
"name": "abs",
"description": "Absolute value",
"name": "extract",
"description": "Get subfield from date/time",
"function_type": "scalar",
"param_names": None,
"param_types": None,
"param_types": "text,timestamp",
"return_type": "numeric",
}

Expand Down Expand Up @@ -107,3 +111,18 @@ def test_aggregate_function(manifest_str: str, connection_info):
"data": [[1]],
"dtypes": {"col": "int64"},
}

def test_functions(manifest_str: str, connection_info):
csv_parser = FunctionCsvParser(os.path.join(function_list_path, "postgres.csv"))
sql_generator = SqlTestGenerator()
for function in csv_parser.parse():
sql = sql_generator.generate_sql(function)
response = client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": sql,
},
)
assert response.status_code == 200
117 changes: 117 additions & 0 deletions ibis-server/tests/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import csv
from typing import Optional

from pydantic import BaseModel


class Function(BaseModel):
function_type: str
name: str
return_type: str
param_names: list[str] | None
param_types: list[str] | None
description: str


class FunctionCsvParser:
def __init__(self, file_path: str):
self.file_path = file_path

def parse(self) -> list[Function]:
with open(self.file_path, encoding="utf-8") as csvfile:
return [
Function(
function_type=row["function_type"],
name=row["name"],
return_type=row["return_type"],
param_names=self._split_param(row["param_names"]),
param_types=self._split_param(row["param_types"]),
description=row["description"],
)
for row in csv.DictReader(csvfile)
]

@staticmethod
def _split_param(param: str) -> list[str]:
return param.split(",") if param else []


class SqlTestGenerator:
def __init__(self):
pass

def generate_sql(self, function: Function) -> Optional[str]:
if function.function_type == "scalar":
return self.generate_scalar_sql(function)
elif function.function_type == "aggregate":
return self.generate_aggregate_sql(function)
elif function.function_type == "window":
return self.generate_window_sql(function)
else:
raise NotImplementedError(
f"Unsupported function type: {function.function_type}"
)

def generate_scalar_sql(self, function: Function) -> str:
args = self.generate_sample_args(function.param_types)
if function.name in ("convert_from", "convert_to"):
args[1] = "'UTF8'"
if function.name == "extract":
args[0] = "'day'"
if function.name in ("json_array_length", "jsonb_array_length"):
args[0] = '\'[{"key": "value"}]\''
if function.name in ("json_extract_path", "jsonb_extract_path"):
args[1] = "'key'"
if function.name == "to_number":
args = ["'123'", "'999'"]
formatted_args = ", ".join(args)
return f"SELECT {function.name}({formatted_args})"

def generate_aggregate_sql(self, function: Function) -> str:
sample_args = self.generate_sample_args(function.param_types)
formatted_args = ", ".join(sample_args)
sql = f"SELECT {function.name}({formatted_args})"
return sql

def generate_window_sql(self, function: Function) -> str:
# TODO: Implement window function generation
return ""

def generate_sample_args(self, param_types: list[str]) -> list[str]:
return [self.map_param_type_to_sample(p_type) for p_type in param_types]

@staticmethod
def map_param_type_to_sample(p_type: str) -> str:
p_type = p_type.lower()
if p_type in {"int", "integer", "bigint", "smallint"}:
return "1"
elif p_type in {"numeric", "decimal", "double precision", "float", "real"}:
return "1.0"
elif p_type in {"text", "varchar", "char"}:
return "'test'"
elif p_type in {"boolean", "bool"}:
return "TRUE"
elif p_type in {"date"}:
return "cast('2024-01-01' as date)"
elif p_type in {
"timestamp",
"timestamp without time zone",
"timestamp with time zone",
}:
return "cast('2024-01-01 00:00:00' as timestamp)"
elif p_type in {"time", "time without time zone", "time with time zone"}:
return "cast('12:00:00' as time)"
elif p_type == "uuid":
return "'123e4567-e89b-12d3-a456-426614174000'"
elif p_type.startswith("array"):
inner_type = p_type[p_type.find("<") + 1 : p_type.find(">")].strip()
sample_inner = SqlTestGenerator.map_param_type_to_sample(inner_type)
return f"ARRAY[{sample_inner}, {sample_inner}, {sample_inner}]"
elif p_type in {"json", "jsonb"}:
return '\'{"key": "value"}\''
elif p_type == "bytea":
return "'\\xc3a4'"
elif p_type == "interval":
return "'1 day'"
else:
return "NULL"

0 comments on commit 367b3a1

Please sign in to comment.