Skip to content

Commit

Permalink
test(trino): test all functions (#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 authored Nov 29, 2024
1 parent ccda553 commit 2aba5dc
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 92 deletions.
108 changes: 22 additions & 86 deletions ibis-server/resources/function_list/trino.csv
Original file line number Diff line number Diff line change
@@ -1,86 +1,22 @@
function_type,name,return_type,description
aggregate,approx_percentile,same as input,"Approximates percentile"
aggregate,approximate_distinct,bigint,"Approximates count of distinct values"
aggregate,array_agg,array,"Aggregates values into array"
aggregate,avg,double,"Returns average of values"
aggregate,corr,double,"Returns correlation coefficient"
aggregate,count,bigint,"Counts number of rows"
aggregate,covar_samp,double,"Returns sample covariance"
aggregate,map_agg,map,"Aggregates key/value pairs into map"
aggregate,max,same as input,"Returns maximum value"
aggregate,min,same as input,"Returns minimum value"
aggregate,regr_intercept,double,"Returns linear regression intercept"
aggregate,regr_slope,double,"Returns linear regression slope"
aggregate,string_agg,varchar,"Concatenates strings with delimiter"
aggregate,sum,numeric,"Returns sum of values"
scalar,abs,double,"Returns absolute value of the argument"
scalar,array_distinct,array,"Removes duplicate values from array"
scalar,array_join,varchar,"Joins array elements with delimiter"
scalar,array_sort,array,"Sorts array elements"
scalar,cardinality,bigint,"Returns size of array or map"
scalar,cbrt,double,"Returns cube root of the argument"
scalar,ceil,double/decimal,"Rounds up to nearest integer"
scalar,concat,varchar,"Concatenates given strings"
scalar,contains,boolean,"Checks if array contains element"
scalar,current_date,date,"Returns current date"
scalar,date_add,date,"Adds interval to date"
scalar,date_sub,date,"Subtracts interval from date"
scalar,date_trunc,timestamp,"Truncates timestamp to specified precision"
scalar,element_at,any,"Returns element at specified position in array"
scalar,filter,array,"Filters array using lambda"
scalar,floor,double/decimal,"Rounds down to nearest integer"
scalar,format_datetime,varchar,"Formats datetime according to format string"
scalar,from_base64,varbinary,"Converts base64 to binary"
scalar,from_hex,varbinary,"Converts hex string to binary"
scalar,from_unixtime,timestamp,"Converts unix timestamp to timestamp"
scalar,hamming_distance,bigint,"Calculates Hamming distance"
scalar,is_finite,boolean,"Tests if value is finite"
scalar,is_infinite,boolean,"Tests if value is infinite"
scalar,is_nan,boolean,"Tests if value is NaN"
scalar,json_extract,json,"Extracts JSON by JSONPath"
scalar,json_format,varchar,"Pretty prints JSON"
scalar,json_parse,json,"Parses string as JSON"
scalar,length,bigint,"Returns length of string"
scalar,levenshtein_distance,bigint,"Calculates Levenshtein distance"
scalar,lower,varchar,"Converts string to lowercase"
scalar,map_concat,map,"Concatenates two maps"
scalar,map_keys,array,"Returns array of map keys"
scalar,map_values,array,"Returns array of map values"
scalar,md5,varchar,"Computes MD5 hash"
scalar,parse_datetime,timestamp,"Parses string to datetime using format"
scalar,reduce,any,"Reduces array to single value using lambda"
scalar,regexp_extract,varchar,"Extracts substring using regex"
scalar,regexp_like,boolean,"Tests if string matches regex"
scalar,regexp_replace,varchar,"Replaces substring using regex"
scalar,replace,varchar,"Replaces substring in string"
scalar,round,double/decimal,"Rounds to nearest integer or decimal places"
scalar,sha256,varchar,"Computes SHA256 hash"
scalar,split,array,"Splits string by delimiter into array"
scalar,split_part,varchar,"Returns specific part from split string"
scalar,strpos,bigint,"Returns position of substring"
scalar,substr,varchar,"Extracts substring from string"
scalar,to_base64,varchar,"Converts binary to base64"
scalar,to_hex,varchar,"Converts number to hex string"
scalar,to_unixtime,double,"Converts timestamp to unix timestamp"
scalar,transform,array,"Applies lambda to each element"
scalar,trim,varchar,"Removes leading and trailing whitespace"
scalar,try,same as input,"Returns null if evaluation fails"
scalar,upper,varchar,"Converts string to uppercase"
scalar,url_decode,varchar,"Decodes URL encoded string"
scalar,url_encode,varchar,"URL encodes string"
scalar,uuid,varchar,"Generates random UUID"
scalar,word_stem,varchar,"Returns word stem (English only)"
scalar,xxhash64,bigint,"Computes xxHash64 hash"
scalar,zip_with,array,"Combines two arrays using lambda"
window,cume_dist,double,"Returns cumulative distribution"
window,dense_rank,bigint,"Returns rank without gaps"
window,first_value,any,"Returns first value in window"
window,lag,any,"Returns value from previous row"
window,last_value,any,"Returns last value in window"
window,lead,any,"Returns value from following row"
window,nth_value,any,"Returns nth value in window"
window,nth_value,any,"Returns value at specified row"
window,ntile,bigint,"Divides rows into buckets"
window,percent_rank,double,"Returns percent rank of row"
window,rank,bigint,"Returns rank with gaps"
window,row_number,bigint,"Returns sequential row number"
function_type,name,return_type,param_names,param_types,description
scalar,array_distinct,array,,array,Removes duplicate values from array
scalar,array_sort,array,,array,Sorts array elements
scalar,cardinality,bigint,,array or map,Returns size of array or map
scalar,cbrt,double,,double,Returns cube root of the argument
scalar,ceil,double,,double or decimal,Rounds up to nearest integer
scalar,floor,double,,double or decimal,Rounds down to nearest integer
scalar,from_base64,varbinary,,varchar,Converts base64 to binary
scalar,from_unixtime,timestamp,,double,Converts unix timestamp to timestamp
scalar,is_finite,boolean,,double or decimal,Tests if value is finite
scalar,is_infinite,boolean,,double or decimal,Tests if value is infinite
scalar,is_nan,boolean,,double or decimal,Tests if value is NaN
scalar,map_keys,array,,map,Returns array of map keys
scalar,map_values,array,,map,Returns array of map values
scalar,round,double,,"double or decimal,integer",Rounds to nearest integer or decimal places
scalar,to_base64,varchar,,varbinary,Converts binary to base64
scalar,to_unixtime,double,,timestamp,Converts timestamp to unix timestamp
scalar,try,same as input,,any,Returns null if evaluation fails
scalar,upper,varchar,,varchar,Converts string to uppercase
scalar,url_decode,varchar,,varchar,Decodes URL encoded string
scalar,url_encode,varchar,,varchar,URL encodes string
scalar,word_stem,varchar,,varchar,Returns word stem (English only)
29 changes: 23 additions & 6 deletions ibis-server/tests/routers/v3/connector/trino/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import os

import orjson
import pytest
Expand All @@ -8,6 +9,7 @@
from app.main import app
from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path
from tests.routers.v3.connector.trino.conftest import base_url
from tests.util import FunctionCsvParser, SqlTestGenerator

manifest = {
"catalog": "my_catalog",
Expand Down Expand Up @@ -57,15 +59,15 @@ def test_function_list():
response = client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == DATAFUSION_FUNCTION_COUNT + 30
the_func = next(filter(lambda x: x["name"] == "abs", result))
assert len(result) == DATAFUSION_FUNCTION_COUNT + 9
the_func = next(filter(lambda x: x["name"] == "array_distinct", result))
assert the_func == {
"name": "abs",
"description": "Returns absolute value of the argument",
"name": "array_distinct",
"description": "Removes duplicate values from array",
"function_type": "scalar",
"param_names": None,
"param_types": None,
"return_type": "double",
"param_types": "array",
"return_type": "array",
}

config.set_remote_function_list_path(None)
Expand Down Expand Up @@ -107,3 +109,18 @@ def test_aggregate_function(manifest_str: str, connection_info):
"data": [[1]],
"dtypes": {"col": "int64"},
}

def test_functions(manifest_str: str, connection_info):
    """Run one generated query per function listed in ``trino.csv``.

    Each row parsed from the CSV is turned into a SQL statement by the
    ``trino`` SQL test generator and posted to the ``/query`` endpoint; every
    request is expected to answer with HTTP 200.

    Args:
        manifest_str: Manifest payload sent as ``manifestStr``
            (presumably a pytest fixture; confirm in conftest).
        connection_info: Connection settings sent as ``connectionInfo``
            (presumably a pytest fixture; confirm in conftest).
    """
    csv_parser = FunctionCsvParser(os.path.join(function_list_path, "trino.csv"))
    sql_generator = SqlTestGenerator("trino")
    for function in csv_parser.parse():
        sql = sql_generator.generate_sql(function)
        response = client.post(
            url=f"{base_url}/query",
            json={
                "connectionInfo": connection_info,
                "manifestStr": manifest_str,
                "sql": sql,
            },
        )
        # Many functions share this single assert, so name the offending
        # function and SQL; otherwise a failure only reports the status code.
        assert response.status_code == 200, (
            f"function {function.name!r} returned status "
            f"{response.status_code} for SQL {sql!r}: {response.text}"
        )
27 changes: 27 additions & 0 deletions ibis-server/tests/util/sql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def _get_generator(self):
return MSSqlGenerator()
if self.dialect == "clickhouse":
return ClickhouseSqlGenerator()
if self.dialect == "trino":
return TrinoSqlGenerator()
raise NotImplementedError(f"Unsupported dialect: {self.dialect}")

@staticmethod
Expand Down Expand Up @@ -218,3 +220,28 @@ def generate_window_sql(self, function: Function) -> str:
SELECT 5 AS id, 'A' AS category
) AS t
"""


class TrinoSqlGenerator(SqlGenerator):
    """SQL generator for the ``trino`` dialect.

    Aggregate and scalar functions are rendered as a bare
    ``SELECT name(args...)``; window functions are rendered as an
    argument-less call over a five-row inline table ordered by ``id``.
    """

    def _select_call(self, function: Function) -> str:
        # Shared by the aggregate and scalar generators, which emit the same
        # statement shape. Sample arguments come from generate_sample_args,
        # fed with the function's declared parameter types.
        rendered_args = ", ".join(self.generate_sample_args(function.param_types))
        return f"SELECT {function.name}({rendered_args})"

    def generate_aggregate_sql(self, function: Function) -> str:
        """Return ``SELECT <name>(<sample args>)`` for an aggregate function."""
        return self._select_call(function)

    def generate_scalar_sql(self, function: Function) -> str:
        """Return ``SELECT <name>(<sample args>)`` for a scalar function."""
        return self._select_call(function)

    def generate_window_sql(self, function: Function) -> str:
        """Return a query applying the function, with no arguments, as a window
        over ``ORDER BY id``, aliased to the lower-cased function name."""
        name = function.name
        alias = name.lower()
        return f"""
SELECT
{name}() OVER (ORDER BY id) AS {alias}
FROM (
SELECT 1 AS id, 'A' AS category UNION ALL
SELECT 2 AS id, 'B' AS category UNION ALL
SELECT 3 AS id, 'A' AS category UNION ALL
SELECT 4 AS id, 'B' AS category UNION ALL
SELECT 5 AS id, 'A' AS category
) AS t
"""

0 comments on commit 2aba5dc

Please sign in to comment.