diff --git a/ibis-server/resources/function_list/bigquery.csv b/ibis-server/resources/function_list/bigquery.csv index 11e33a675..859529471 100644 --- a/ibis-server/resources/function_list/bigquery.csv +++ b/ibis-server/resources/function_list/bigquery.csv @@ -1,104 +1,65 @@ -function_type,name,return_type,description -aggregate,countif,INT64,"Counts the rows where a condition is true." -aggregate,corr,FLOAT64,"Calculates the correlation coefficient of a set of numbers." -aggregate,covar_pop,FLOAT64,"Calculates the population covariance of a set of numbers." -aggregate,covar_samp,FLOAT64,"Calculates the sample covariance of a set of numbers." -aggregate,stddev_pop,FLOAT64,"Calculates the population standard deviation of a set of numbers." -aggregate,stddev_samp,FLOAT64,"Calculates the sample standard deviation of a set of numbers." -aggregate,var_pop,FLOAT64,"Calculates the population variance of a set of numbers." -aggregate,var_samp,FLOAT64,"Calculates the sample variance of a set of numbers." -aggregate,bit_and,INT64,"Returns the bitwise AND of non-NULL input values." -aggregate,bit_or,INT64,"Returns the bitwise OR of non-NULL input values." -aggregate,bit_xor,INT64,"Returns the bitwise XOR of non-NULL input values." -aggregate,any_value,ANY,"Returns any arbitrary value from the input values." -aggregate,array_agg,ARRAY,"Aggregates values into an array." -aggregate,string_agg,STRING,"Aggregates string values with a delimiter." -aggregate,count,INT64,"Counts the number of rows." -aggregate,max,ANY,"Returns the maximum value." -aggregate,min,ANY,"Returns the minimum value." -aggregate,sum,ANY,"Returns the sum of values." -aggregate,avg,FLOAT64,"Returns the average of values." -scalar,acos,FLOAT64,"Returns the arccosine of a number." -scalar,asin,FLOAT64,"Returns the arcsine of a number." -scalar,atan,FLOAT64,"Returns the arctangent of a number." -scalar,atan2,FLOAT64,"Returns the arctangent of two numbers." -scalar,cos,FLOAT64,"Returns the cosine of a number." -scalar,cosh,FLOAT64,"Returns the hyperbolic cosine of a number." -scalar,sin,FLOAT64,"Returns the sine of a number." -scalar,sinh,FLOAT64,"Returns the hyperbolic sine of a number." -scalar,tan,FLOAT64,"Returns the tangent of a number." -scalar,tanh,FLOAT64,"Returns the hyperbolic tangent of a number." -scalar,greatest,ANY,"Returns the greatest value in a list of expressions." -scalar,least,ANY,"Returns the least value in a list of expressions." -scalar,nullifzero,FLOAT64,"Returns NULL if the input is zero." -scalar,zeroifnull,FLOAT64,"Returns zero if the input is NULL." -scalar,format,STRING,"Formats values into a string." -scalar,lpad,STRING,"Pads a string on the left to a certain length." -scalar,rpad,STRING,"Pads a string on the right to a certain length." -scalar,left,STRING,"Returns a substring from the beginning of a string." -scalar,right,STRING,"Returns a substring from the end of a string." -scalar,starts_with,BOOL,"Returns TRUE if the string starts with the specified prefix." -scalar,ends_with,BOOL,"Returns TRUE if the string ends with the specified suffix." -scalar,array_length,INT64,"Returns the length of an array." -scalar,array_reverse,ARRAY,"Reverses the elements in an array." -scalar,array_concat,ARRAY,"Concatenates multiple arrays into one." -scalar,array_to_string,STRING,"Converts an array to a single string." -scalar,generate_array,ARRAY,"Generates an array of values in a range." -scalar,generate_date_array,ARRAY,"Generates an array of dates in a range." -scalar,parse_timestamp,TIMESTAMPTZ,"Parses a timestamp from a string." -scalar,string_to_array,ARRAY,"Splits a string into an array of substrings." -scalar,safe_divide,FLOAT64,"Divides two numbers, returning NULL if the divisor is zero." -scalar,safe_multiply,FLOAT64,"Multiplies two numbers, returning NULL if an overflow occurs." -scalar,safe_add,FLOAT64,"Adds two numbers, returning NULL if an overflow occurs." -scalar,safe_subtract,FLOAT64,"Subtracts two numbers, returning NULL if an overflow occurs." -scalar,abs,FLOAT64,"Returns the absolute value of a number." -scalar,ceil,INT64,"Rounds up to the nearest integer." -scalar,floor,INT64,"Rounds down to the nearest integer." -scalar,round,FLOAT64,"Rounds to the specified number of decimal places." -scalar,trunc,FLOAT64,"Truncates to the specified number of decimal places." -scalar,pow,FLOAT64,"Returns a number raised to a power." -scalar,sqrt,FLOAT64,"Returns the square root of a number." -scalar,log,FLOAT64,"Returns the natural logarithm of a number." -scalar,log10,FLOAT64,"Returns the base-10 logarithm of a number." -scalar,concat,STRING,"Concatenates two or more strings." -scalar,lower,STRING,"Converts a string to lowercase." -scalar,upper,STRING,"Converts a string to uppercase." -scalar,trim,STRING,"Removes leading and trailing whitespace." -scalar,ltrim,STRING,"Removes leading whitespace." -scalar,rtrim,STRING,"Removes trailing whitespace." -scalar,length,INT64,"Returns the length of a string." -scalar,regexp_contains,BOOL,"Returns TRUE if the string contains a match for the regex." -scalar,regexp_extract,STRING,"Extracts the first match of the regex from the string." -scalar,regexp_replace,STRING,"Replaces all matches of the regex with a replacement string." -scalar,substr,STRING,"Returns a substring." -scalar,cast,ANY,"Converts a value to a different data type." -scalar,safe_cast,ANY,"Converts a value to a different data type, returning NULL on error." -scalar,current_date,DATE,"Returns the current date." -scalar,current_datetime,TIMESTAMP,"Returns the current date." -scalar,date_add,DATE,"Adds a specified interval to a date." -scalar,date_sub,DATE,"Subtracts a specified interval from a date." -scalar,date_diff,INT64,"Returns the difference between two dates." -scalar,date_trunc,DATE,"Truncates a date to a specified granularity." -scalar,timestamp_add,TIMESTAMPTZ,"Adds a specified interval to a timestamp." -scalar,timestamp_sub,TIMESTAMPTZ,"Subtracts a specified interval from a timestamp." -scalar,timestamp_diff,INT64,"Returns the difference between two timestamps." -scalar,timestamp_trunc,TIMESTAMPTZ,"Truncates a timestamp to a specified granularity." -scalar,timestamp_micros,TIMESTAMPTZ,"Converts the number of microseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." -scalar,timestamp_millis,TIMESTAMPTZ,"Converts the number of milliseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." -scalar,timestamp_seconds,TIMESTAMPTZ,"Converts the number of seconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." -scalar,format_date,STRING,"Formats a date according to the specified format string." -scalar,format_timestamp,STRING,"Formats a timestamp according to the specified format string." -scalar,parse_date,DATE,"Parses a date from a string." -window,ntile,INT64,"Divides rows into a specified number of buckets and assigns a bucket number." -window,percent_rank,FLOAT64,"Calculates the percent rank of a value in a partition." -window,cume_dist,FLOAT64,"Calculates the cumulative distribution of a value in a partition." -window,nth_value,ANY,"Returns the nth value in a window partition." -window,percentile_cont,FLOAT64,"Computes a continuous percentile of a value." -window,percentile_disc,ANY,"Computes a discrete percentile of a value." -window,row_number,INT64,"Returns the sequential row number." -window,rank,INT64,"Returns the rank with gaps." -window,dense_rank,INT64,"Returns the rank without gaps." -window,lag,ANY,"Returns the value from a previous row." -window,lead,ANY,"Returns the value from a subsequent row." -window,first_value,ANY,"Returns the first value in the window frame." -window,last_value,ANY,"Returns the last value in the window frame." +function_type,name,return_type,param_names,param_types,description +aggregate,countif,int,,boolean,"Counts the rows where a condition is true." +aggregate,corr,float,,"float,float","Calculates the correlation coefficient of a set of numbers." +aggregate,covar_pop,float,,"float,float","Calculates the population covariance of a set of numbers." +aggregate,covar_samp,float,,"float,float","Calculates the sample covariance of a set of numbers." +aggregate,stddev_pop,float,,"float","Calculates the population standard deviation of a set of numbers." +aggregate,stddev_samp,float,,"float","Calculates the sample standard deviation of a set of numbers." +aggregate,var_pop,float,,"float","Calculates the population variance of a set of numbers." +aggregate,bit_and,int,,"int","Returns the bitwise AND of non-NULL input values." +aggregate,bit_or,int,,"int","Returns the bitwise OR of non-NULL input values." +aggregate,bit_xor,int,,"int","Returns the bitwise XOR of non-NULL input values." +aggregate,any_value,any,,"any","Returns any arbitrary value from the input values." +aggregate,array_agg,array,,"any","Aggregates values into an array." +aggregate,string_agg,text,,"text,text","Aggregates string values with a delimiter." +aggregate,count,int,,"any","Counts the number of rows." +aggregate,max,any,,"any","Returns the maximum value." +aggregate,min,any,,"any","Returns the minimum value." +aggregate,sum,any,,"any","Returns the sum of values." +aggregate,avg,float,,"int","Returns the average of values." +scalar,acos,float,,"float","Returns the arccosine of a number." +scalar,asin,float,,"float","Returns the arcsine of a number." +scalar,atan,float,,"float","Returns the arctangent of a number." +scalar,atan2,float,,"float,float","Returns the arctangent of two numbers." +scalar,cos,float,,"float","Returns the cosine of a number." +scalar,cosh,float,,"float","Returns the hyperbolic cosine of a number." +scalar,sin,float,,"float","Returns the sine of a number." +scalar,sinh,float,,"float","Returns the hyperbolic sine of a number." +scalar,tan,float,,"float","Returns the tangent of a number." +scalar,tanh,float,,"float","Returns the hyperbolic tangent of a number." +scalar,greatest,any,,"any","Returns the greatest value in a list of expressions." +scalar,least,any,,"any","Returns the least value in a list of expressions." +scalar,format,text,,"text","Formats values into a string." +scalar,lpad,text,,"text,int","Pads a string on the left to a certain length." +scalar,rpad,text,,"text,int","Pads a string on the right to a certain length." +scalar,left,text,,"text,int","Returns a substring from the beginning of a string." +scalar,right,text,,"text,int","Returns a substring from the end of a string." +scalar,array_length,int,,"array","Returns the length of an array." +scalar,array_reverse,array,,"array","Reverses the elements in an array." +scalar,array_concat,array,,"array,array","Concatenates multiple arrays into one." +scalar,array_to_string,text,,"array,text","Converts an array to a single string." +scalar,safe_divide,float,,"float,float","Divides two numbers, returning NULL if the divisor is zero." +scalar,safe_multiply,float,,"float,float","Multiplies two numbers, returning NULL if an overflow occurs." +scalar,safe_add,float,,"float,float","Adds two numbers, returning NULL if an overflow occurs." +scalar,safe_subtract,float,,"float,float","Subtracts two numbers, returning NULL if an overflow occurs." +scalar,abs,float,,"float","Returns the absolute value of a number." +scalar,floor,int,,"float","Rounds down to the nearest integer." +scalar,current_date,date,,"","Returns the current date." +scalar,current_datetime,timestamp,,"","Returns current date and time." +scalar,json_query,text,,"json,text","Extracts a JSON value from a JSON string." +scalar,json_value,text,,"json,text","Extracts a scalar JSON value as a string." +scalar,json_query_array,array,,"json,text","Extracts a JSON array from a JSON string." +scalar,json_value_array,array,,"json,text","Extracts an array of scalar JSON values as strings." +scalar,lax_bool,boolean,,"any","Converts a value to boolean with relaxed type checking." +scalar,lax_float64,float,,"any","Converts a value to float with relaxed type checking." +scalar,lax_int64,int,,"any","Converts a value to int with relaxed type checking." +scalar,lax_string,text,,"any","Converts a value to text with relaxed type checking." +scalar,bool,boolean,,"any","Converts a JSON value to SQL boolean type." +scalar,float64,float,,"any","Converts a JSON value to SQL float type." +scalar,int64,int,,"any","Converts a JSON value to SQL int type." +scalar,string,text,,"any","Converts a JSON value to SQL text type." +window,cume_dist,float,,"","Gets the cumulative distribution (relative position (0,1]) of each row within a window." +window,dense_rank,int,,"","Gets the dense rank (1-based, no gaps) of each row within a window." +window,percent_rank,float,,"","Gets the percentile rank (from 0 to 1) of each row within a window." +window,rank,int,,"","Gets the rank (1-based) of each row within a window." +window,row_number,int,,"","Gets the sequential row number (1-based) of each row within a window." diff --git a/ibis-server/tests/model/__init__.py b/ibis-server/tests/model/__init__.py new file mode 100644 index 000000000..258c7eaa8 --- /dev/null +++ b/ibis-server/tests/model/__init__.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class Function(BaseModel): + function_type: str + name: str + return_type: str + param_names: list[str] | None + param_types: list[str] | None + description: str diff --git a/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py b/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py index 477a8b8c5..1dd9c8115 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py @@ -1,4 +1,5 @@ import base64 +import os import orjson import pytest @@ -8,6 +9,9 @@ from app.main import app from tests.conftest import DATAFUSION_FUNCTION_COUNT from tests.routers.v3.connector.bigquery.conftest import base_url, function_list_path +from tests.util import FunctionCsvParser, SqlTestGenerator + +pytestmark = pytest.mark.functions manifest = { "catalog": "my_catalog", @@ -47,15 +51,15 @@ def test_function_list(): response = client.get(url=f"{base_url}/functions") assert response.status_code == 200 result = response.json() - assert len(result) == DATAFUSION_FUNCTION_COUNT + 34 - the_func = next(filter(lambda x: x["name"] == "abs", result)) + assert len(result) == DATAFUSION_FUNCTION_COUNT + 22 + the_func = next(filter(lambda x: x["name"] == "string_agg", result)) assert the_func == { - "name": "abs", - "description": "Returns the absolute value of a number.", - "function_type": "scalar", + "name": "string_agg", + "description": "Aggregates string values with a delimiter.", + "function_type": "aggregate", "param_names": None, - "param_types": None, - "return_type": "FLOAT64", + "param_types": "text,text", + "return_type": "text", } config.set_remote_function_list_path(None) @@ -97,3 +101,18 @@ def test_aggregate_function(manifest_str: str, connection_info): "data": [[1]], "dtypes": {"col": "int64"}, } + + def test_functions(manifest_str: str, connection_info): + csv_parser = FunctionCsvParser(os.path.join(function_list_path, "bigquery.csv")) + sql_generator = SqlTestGenerator("bigquery") + for function in csv_parser.parse(): + sql = sql_generator.generate_sql(function) + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": sql, + }, + ) + assert response.status_code == 200 diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py index 833619343..df1ce40d9 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py @@ -114,7 +114,7 @@ def test_aggregate_function(manifest_str: str, connection_info): def test_functions(manifest_str: str, connection_info): csv_parser = FunctionCsvParser(os.path.join(function_list_path, "postgres.csv")) - sql_generator = SqlTestGenerator() + sql_generator = SqlTestGenerator("postgres") for function in csv_parser.parse(): sql = sql_generator.generate_sql(function) response = client.post( diff --git a/ibis-server/tests/util.py b/ibis-server/tests/util.py deleted file mode 100644 index 08606fe98..000000000 --- a/ibis-server/tests/util.py +++ /dev/null @@ -1,117 +0,0 @@ -import csv -from typing import Optional - -from pydantic import BaseModel - - -class Function(BaseModel): - function_type: str - name: str - return_type: str - param_names: list[str] | None - param_types: list[str] | None - description: str - - -class FunctionCsvParser: - def __init__(self, file_path: str): - self.file_path = file_path - - def parse(self) -> list[Function]: - with open(self.file_path, encoding="utf-8") as csvfile: - return [ - Function( - function_type=row["function_type"], - name=row["name"], - return_type=row["return_type"], - param_names=self._split_param(row["param_names"]), - param_types=self._split_param(row["param_types"]), - description=row["description"], - ) - for row in csv.DictReader(csvfile) - ] - - @staticmethod - def _split_param(param: str) -> list[str]: - return param.split(",") if param else [] - - -class SqlTestGenerator: - def __init__(self): - pass - - def generate_sql(self, function: Function) -> Optional[str]: - if function.function_type == "scalar": - return self.generate_scalar_sql(function) - elif function.function_type == "aggregate": - return self.generate_aggregate_sql(function) - elif function.function_type == "window": - return self.generate_window_sql(function) - else: - raise NotImplementedError( - f"Unsupported function type: {function.function_type}" - ) - - def generate_scalar_sql(self, function: Function) -> str: - args = self.generate_sample_args(function.param_types) - if function.name in ("convert_from", "convert_to"): - args[1] = "'UTF8'" - if function.name == "extract": - args[0] = "'day'" - if function.name in ("json_array_length", "jsonb_array_length"): - args[0] = '\'[{"key": "value"}]\'' - if function.name in ("json_extract_path", "jsonb_extract_path"): - args[1] = "'key'" - if function.name == "to_number": - args = ["'123'", "'999'"] - formatted_args = ", ".join(args) - return f"SELECT {function.name}({formatted_args})" - - def generate_aggregate_sql(self, function: Function) -> str: - sample_args = self.generate_sample_args(function.param_types) - formatted_args = ", ".join(sample_args) - sql = f"SELECT {function.name}({formatted_args})" - return sql - - def generate_window_sql(self, function: Function) -> str: - # TODO: Implement window function generation - return "" - - def generate_sample_args(self, param_types: list[str]) -> list[str]: - return [self.map_param_type_to_sample(p_type) for p_type in param_types] - - @staticmethod - def map_param_type_to_sample(p_type: str) -> str: - p_type = p_type.lower() - if p_type in {"int", "integer", "bigint", "smallint"}: - return "1" - elif p_type in {"numeric", "decimal", "double precision", "float", "real"}: - return "1.0" - elif p_type in {"text", "varchar", "char"}: - return "'test'" - elif p_type in {"boolean", "bool"}: - return "TRUE" - elif p_type in {"date"}: - return "cast('2024-01-01' as date)" - elif p_type in { - "timestamp", - "timestamp without time zone", - "timestamp with time zone", - }: - return "cast('2024-01-01 00:00:00' as timestamp)" - elif p_type in {"time", "time without time zone", "time with time zone"}: - return "cast('12:00:00' as time)" - elif p_type == "uuid": - return "'123e4567-e89b-12d3-a456-426614174000'" - elif p_type.startswith("array"): - inner_type = p_type[p_type.find("<") + 1 : p_type.find(">")].strip() - sample_inner = SqlTestGenerator.map_param_type_to_sample(inner_type) - return f"ARRAY[{sample_inner}, {sample_inner}, {sample_inner}]" - elif p_type in {"json", "jsonb"}: - return '\'{"key": "value"}\'' - elif p_type == "bytea": - return "'\\xc3a4'" - elif p_type == "interval": - return "'1 day'" - else: - return "NULL" diff --git a/ibis-server/tests/util/__init__.py b/ibis-server/tests/util/__init__.py new file mode 100644 index 000000000..c75f2bf11 --- /dev/null +++ b/ibis-server/tests/util/__init__.py @@ -0,0 +1,7 @@ +from tests.util.csv_parser import FunctionCsvParser +from tests.util.sql_generator import SqlTestGenerator + +__all__ = [ + "FunctionCsvParser", + "SqlTestGenerator", +] diff --git a/ibis-server/tests/util/csv_parser.py b/ibis-server/tests/util/csv_parser.py new file mode 100644 index 000000000..b8014c1aa --- /dev/null +++ b/ibis-server/tests/util/csv_parser.py @@ -0,0 +1,26 @@ +import csv + +from tests.model import Function + + +class FunctionCsvParser: + def __init__(self, file_path: str): + self.file_path = file_path + + def parse(self) -> list[Function]: + with open(self.file_path, encoding="utf-8") as csvfile: + return [ + Function( + function_type=row["function_type"], + name=row["name"], + return_type=row["return_type"], + param_names=self._split_param(row["param_names"]), + param_types=self._split_param(row["param_types"]), + description=row["description"], + ) + for row in csv.DictReader(csvfile) + ] + + @staticmethod + def _split_param(param: str) -> list[str]: + return param.split(",") if param else [] diff --git a/ibis-server/tests/util/sql_generator.py b/ibis-server/tests/util/sql_generator.py new file mode 100644 index 000000000..4ccd1ab29 --- /dev/null +++ b/ibis-server/tests/util/sql_generator.py @@ -0,0 +1,145 @@ +import re +from abc import ABC +from typing import Optional + +from tests.model import Function + + +class SqlTestGenerator: + def __init__(self, dialect: str): + self.dialect = dialect + self._generator = self._get_generator() + + def generate_sql(self, function: Function) -> Optional[str]: + if function.function_type == "aggregate": + return self._generator.generate_aggregate_sql(function) + if function.function_type == "scalar": + return self._generator.generate_scalar_sql(function) + if function.function_type == "window": + return self._generator.generate_window_sql(function) + raise NotImplementedError( + f"Unsupported function type: {function.function_type}" + ) + + def _get_generator(self): + if self.dialect == "bigquery": + return BigQuerySqlGenerator() + if self.dialect == "postgres": + return PostgresSqlGenerator() + raise NotImplementedError(f"Unsupported dialect: {self.dialect}") + + @staticmethod + def map_sample_parm_by_type(p_type: str) -> str: + p_type = p_type.lower() + if p_type in {"int", "integer", "bigint", "smallint"}: + return "1" + elif p_type in {"numeric", "decimal", "double precision", "float", "real"}: + return "1.0" + elif p_type in {"text", "varchar", "char", "string"}: + return "'test'" + elif p_type in {"boolean", "bool"}: + return "TRUE" + elif p_type in {"date"}: + return "cast('2024-01-01' as date)" + elif p_type in { + "timestamp", + "timestamp without time zone", + "timestamp with time zone", + }: + return "cast('2024-01-01 00:00:00' as timestamp)" + elif p_type in {"time", "time without time zone", "time with time zone"}: + return "cast('12:00:00' as time)" + elif p_type == "uuid": + return "'123e4567-e89b-12d3-a456-426614174000'" + elif p_type.startswith("array"): + inner_type = ( + re.match(r"array<(.+)>", p_type).group(1) + if p_type.startswith("array<") and p_type.endswith(">") + else "int" + ) + element = SqlTestGenerator.map_sample_parm_by_type(inner_type) + return f"ARRAY[{element}, {element}]" + elif p_type in {"json", "jsonb"}: + return '\'{"key": "value"}\'' + elif p_type == "bytea": + return "'\\xc3a4'" + elif p_type == "interval": + return "'1 day'" + else: + return "NULL" + + +class SqlGenerator(ABC): + def generate_aggregate_sql(self, function: Function) -> str: + raise NotImplementedError + + def generate_scalar_sql(self, function: Function) -> str: + raise NotImplementedError + + def generate_window_sql(self, function: Function) -> str: + raise NotImplementedError + + @staticmethod + def generate_sample_args(param_types: list[str]) -> list[str]: + return [ + SqlTestGenerator.map_sample_parm_by_type(p_type) for p_type in param_types + ] + + +class PostgresSqlGenerator(SqlGenerator): + def generate_aggregate_sql(self, function: Function) -> str: + sample_args = self.generate_sample_args(function.param_types) + formatted_args = ", ".join(sample_args) + return f"SELECT {function.name}({formatted_args})" + + def generate_scalar_sql(self, function: Function) -> str: + args = self.generate_sample_args(function.param_types) + if function.name in ("convert_from", "convert_to"): + args[1] = "'UTF8'" + if function.name == "extract": + args[0] = "'day'" + if function.name in ("json_array_length", "jsonb_array_length"): + args[0] = '\'[{"key": "value"}]\'' + if function.name in ("json_extract_path", "jsonb_extract_path"): + args[1] = "'key'" + if function.name == "to_number": + args = ["'123'", "'999'"] + formatted_args = ", ".join(args) + return f"SELECT {function.name}({formatted_args})" + + +class BigQuerySqlGenerator(SqlGenerator): + def generate_aggregate_sql(self, function: Function) -> str: + args = self.generate_sample_args(function.param_types) + if function.name == "array_agg": + args[0] = "x" + table = "(SELECT 1 AS x UNION ALL SELECT 2)" + else: + table = "(SELECT 1)" + formatted_args = ", ".join(args) + return f"SELECT {function.name}({formatted_args}) FROM {table} AS t(x)" + + def generate_scalar_sql(self, function: Function) -> str: + args = self.generate_sample_args(function.param_types) + if function.name in ( + "json_query", + "json_value", + "json_query_array", + "json_value_array", + ): + args[1] = "'$'" + formatted_args = ", ".join(args) + return f"SELECT {function.name}({formatted_args})" + + def generate_window_sql(self, function: Function) -> str: + return f""" + SELECT + {function.name}() OVER (ORDER BY id) AS {function.name.lower()} + FROM ( + SELECT 1 AS id, 'A' AS category UNION ALL + SELECT 2 AS id, 'B' AS category UNION ALL + SELECT 3 AS id, 'A' AS category UNION ALL + SELECT 4 AS id, 'B' AS category UNION ALL + SELECT 5 AS id, 'A' AS category + ) AS t + """