From 38fedb4b917b16c39bfb533ffc86135372fbe00d Mon Sep 17 00:00:00 2001 From: Vladimir Rudnykh Date: Tue, 3 Dec 2024 19:47:38 +0700 Subject: [PATCH] Update base Func class and tests (#641) --- src/datachain/func/__init__.py | 5 + src/datachain/func/func.py | 335 ++++++++------ src/datachain/func/numeric.py | 162 +++++++ src/datachain/sql/functions/array.py | 4 + src/datachain/sql/functions/numeric.py | 43 ++ src/datachain/sql/sqlite/base.py | 69 ++- tests/unit/test_func.py | 616 +++++++++++++++++++------ 7 files changed, 945 insertions(+), 289 deletions(-) create mode 100644 src/datachain/func/numeric.py create mode 100644 src/datachain/sql/functions/numeric.py diff --git a/src/datachain/func/__init__.py b/src/datachain/func/__init__.py index 214654124..cfbbedea0 100644 --- a/src/datachain/func/__init__.py +++ b/src/datachain/func/__init__.py @@ -17,6 +17,7 @@ ) from .array import cosine_distance, euclidean_distance, length, sip_hash_64 from .conditional import greatest, least +from .numeric import bit_and, bit_or, bit_xor, int_hash_64 from .random import rand from .window import window @@ -24,6 +25,9 @@ "any_value", "array", "avg", + "bit_and", + "bit_or", + "bit_xor", "case", "collect", "concat", @@ -33,6 +37,7 @@ "euclidean_distance", "first", "greatest", + "int_hash_64", "least", "length", "literal", diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py index 747cfaf47..278a640c6 100644 --- a/src/datachain/func/func.py +++ b/src/datachain/func/func.py @@ -2,13 +2,15 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from sqlalchemy import BindParameter, Case, ColumnElement, desc +from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc from sqlalchemy.ext.hybrid import Comparator +from sqlalchemy.sql import func as sa_func from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.convert.sql_to_python import sql_to_python from datachain.lib.utils 
import DataChainColumnError, DataChainParamsError from datachain.query.schema import Column, ColumnMeta +from datachain.sql.functions import numeric from .base import Function @@ -98,94 +100,232 @@ def _db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]: return list[col_type] if self.is_array else col_type # type: ignore[valid-type] def __add__(self, other: Union[ColT, float]) -> "Func": - return math_add(self, other) + if isinstance(other, (int, float)): + return Func("add", lambda a: a + other, [self]) + return Func("add", lambda a1, a2: a1 + a2, [self, other]) def __radd__(self, other: Union[ColT, float]) -> "Func": - return math_add(other, self) + if isinstance(other, (int, float)): + return Func("add", lambda a: other + a, [self]) + return Func("add", lambda a1, a2: a1 + a2, [other, self]) def __sub__(self, other: Union[ColT, float]) -> "Func": - return math_sub(self, other) + if isinstance(other, (int, float)): + return Func("sub", lambda a: a - other, [self]) + return Func("sub", lambda a1, a2: a1 - a2, [self, other]) def __rsub__(self, other: Union[ColT, float]) -> "Func": - return math_sub(other, self) + if isinstance(other, (int, float)): + return Func("sub", lambda a: other - a, [self]) + return Func("sub", lambda a1, a2: a1 - a2, [other, self]) def __mul__(self, other: Union[ColT, float]) -> "Func": - return math_mul(self, other) + if isinstance(other, (int, float)): + return Func("mul", lambda a: a * other, [self]) + return Func("mul", lambda a1, a2: a1 * a2, [self, other]) def __rmul__(self, other: Union[ColT, float]) -> "Func": - return math_mul(other, self) + if isinstance(other, (int, float)): + return Func("mul", lambda a: other * a, [self]) + return Func("mul", lambda a1, a2: a1 * a2, [other, self]) def __truediv__(self, other: Union[ColT, float]) -> "Func": - return math_truediv(self, other) + if isinstance(other, (int, float)): + return Func("div", lambda a: _truediv(a, other), [self], result_type=float) + return Func( + 
"div", lambda a1, a2: _truediv(a1, a2), [self, other], result_type=float + ) def __rtruediv__(self, other: Union[ColT, float]) -> "Func": - return math_truediv(other, self) + if isinstance(other, (int, float)): + return Func("div", lambda a: _truediv(other, a), [self], result_type=float) + return Func( + "div", lambda a1, a2: _truediv(a1, a2), [other, self], result_type=float + ) def __floordiv__(self, other: Union[ColT, float]) -> "Func": - return math_floordiv(self, other) + if isinstance(other, (int, float)): + return Func( + "floordiv", lambda a: _floordiv(a, other), [self], result_type=int + ) + return Func( + "floordiv", lambda a1, a2: _floordiv(a1, a2), [self, other], result_type=int + ) def __rfloordiv__(self, other: Union[ColT, float]) -> "Func": - return math_floordiv(other, self) + if isinstance(other, (int, float)): + return Func( + "floordiv", lambda a: _floordiv(other, a), [self], result_type=int + ) + return Func( + "floordiv", lambda a1, a2: _floordiv(a1, a2), [other, self], result_type=int + ) def __mod__(self, other: Union[ColT, float]) -> "Func": - return math_mod(self, other) + if isinstance(other, (int, float)): + return Func("mod", lambda a: a % other, [self], result_type=int) + return Func("mod", lambda a1, a2: a1 % a2, [self, other], result_type=int) def __rmod__(self, other: Union[ColT, float]) -> "Func": - return math_mod(other, self) - - def __pow__(self, other: Union[ColT, float]) -> "Func": - return math_pow(self, other) - - def __rpow__(self, other: Union[ColT, float]) -> "Func": - return math_pow(other, self) - - def __lshift__(self, other: Union[ColT, float]) -> "Func": - return math_lshift(self, other) - - def __rlshift__(self, other: Union[ColT, float]) -> "Func": - return math_lshift(other, self) - - def __rshift__(self, other: Union[ColT, float]) -> "Func": - return math_rshift(self, other) - - def __rrshift__(self, other: Union[ColT, float]) -> "Func": - return math_rshift(other, self) + if isinstance(other, (int, float)): + 
return Func("mod", lambda a: other % a, [self], result_type=int) + return Func("mod", lambda a1, a2: a1 % a2, [other, self], result_type=int) def __and__(self, other: Union[ColT, float]) -> "Func": - return math_and(self, other) + if isinstance(other, (int, float)): + return Func( + "and", lambda a: numeric.bit_and(a, other), [self], result_type=int + ) + return Func( + "and", + lambda a1, a2: numeric.bit_and(a1, a2), + [self, other], + result_type=int, + ) def __rand__(self, other: Union[ColT, float]) -> "Func": - return math_and(other, self) + if isinstance(other, (int, float)): + return Func( + "and", lambda a: numeric.bit_and(other, a), [self], result_type=int + ) + return Func( + "and", + lambda a1, a2: numeric.bit_and(a1, a2), + [other, self], + result_type=int, + ) def __or__(self, other: Union[ColT, float]) -> "Func": - return math_or(self, other) + if isinstance(other, (int, float)): + return Func( + "or", lambda a: numeric.bit_or(a, other), [self], result_type=int + ) + return Func( + "or", lambda a1, a2: numeric.bit_or(a1, a2), [self, other], result_type=int + ) def __ror__(self, other: Union[ColT, float]) -> "Func": - return math_or(other, self) + if isinstance(other, (int, float)): + return Func( + "or", lambda a: numeric.bit_or(other, a), [self], result_type=int + ) + return Func( + "or", lambda a1, a2: numeric.bit_or(a1, a2), [other, self], result_type=int + ) def __xor__(self, other: Union[ColT, float]) -> "Func": - return math_xor(self, other) + if isinstance(other, (int, float)): + return Func( + "xor", lambda a: numeric.bit_xor(a, other), [self], result_type=int + ) + return Func( + "xor", + lambda a1, a2: numeric.bit_xor(a1, a2), + [self, other], + result_type=int, + ) def __rxor__(self, other: Union[ColT, float]) -> "Func": - return math_xor(other, self) + if isinstance(other, (int, float)): + return Func( + "xor", lambda a: numeric.bit_xor(other, a), [self], result_type=int + ) + return Func( + "xor", + lambda a1, a2: numeric.bit_xor(a1, a2), 
+ [other, self], + result_type=int, + ) + + def __rshift__(self, other: Union[ColT, float]) -> "Func": + if isinstance(other, (int, float)): + return Func( + "rshift", + lambda a: numeric.bit_rshift(a, other), + [self], + result_type=int, + ) + return Func( + "rshift", + lambda a1, a2: numeric.bit_rshift(a1, a2), + [self, other], + result_type=int, + ) + + def __rrshift__(self, other: Union[ColT, float]) -> "Func": + if isinstance(other, (int, float)): + return Func( + "rshift", + lambda a: numeric.bit_rshift(other, a), + [self], + result_type=int, + ) + return Func( + "rshift", + lambda a1, a2: numeric.bit_rshift(a1, a2), + [other, self], + result_type=int, + ) + + def __lshift__(self, other: Union[ColT, float]) -> "Func": + if isinstance(other, (int, float)): + return Func( + "lshift", + lambda a: numeric.bit_lshift(a, other), + [self], + result_type=int, + ) + return Func( + "lshift", + lambda a1, a2: numeric.bit_lshift(a1, a2), + [self, other], + result_type=int, + ) + + def __rlshift__(self, other: Union[ColT, float]) -> "Func": + if isinstance(other, (int, float)): + return Func( + "lshift", + lambda a: numeric.bit_lshift(other, a), + [self], + result_type=int, + ) + return Func( + "lshift", + lambda a1, a2: numeric.bit_lshift(a1, a2), + [other, self], + result_type=int, + ) def __lt__(self, other: Union[ColT, float]) -> "Func": - return math_lt(self, other) + if isinstance(other, (int, float)): + return Func("lt", lambda a: a < other, [self], result_type=bool) + return Func("lt", lambda a1, a2: a1 < a2, [self, other], result_type=bool) def __le__(self, other: Union[ColT, float]) -> "Func": - return math_le(self, other) + if isinstance(other, (int, float)): + return Func("le", lambda a: a <= other, [self], result_type=bool) + return Func("le", lambda a1, a2: a1 <= a2, [self, other], result_type=bool) def __eq__(self, other): - return math_eq(self, other) + if isinstance(other, (int, float)): + return Func("eq", lambda a: a == other, [self], result_type=bool) 
+ return Func("eq", lambda a1, a2: a1 == a2, [self, other], result_type=bool) def __ne__(self, other): - return math_ne(self, other) + if isinstance(other, (int, float)): + return Func("ne", lambda a: a != other, [self], result_type=bool) + return Func("ne", lambda a1, a2: a1 != a2, [self, other], result_type=bool) def __gt__(self, other: Union[ColT, float]) -> "Func": - return math_gt(self, other) + if isinstance(other, (int, float)): + return Func("gt", lambda a: a > other, [self], result_type=bool) + return Func("gt", lambda a1, a2: a1 > a2, [self, other], result_type=bool) def __ge__(self, other: Union[ColT, float]) -> "Func": - return math_ge(self, other) + if isinstance(other, (int, float)): + return Func("ge", lambda a: a >= other, [self], result_type=bool) + return Func("ge", lambda a1, a2: a1 >= a2, [self, other], result_type=bool) def label(self, label: str) -> "Func": return Func( @@ -283,107 +423,12 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": ) -def math_func( - name: str, - inner: Callable, - params: Sequence[Union[ColT, float]], - result_type: Optional["DataType"] = None, -) -> Func: - """Returns math function from the columns.""" - cols, args = [], [] - for arg in params: - if isinstance(arg, (int, float)): - args.append(arg) - else: - cols.append(arg) - return Func(name, inner, cols=cols, args=args, result_type=result_type) - - -def math_add(*args: Union[ColT, float]) -> Func: - """Computes the sum of the column.""" - return math_func("add", lambda a1, a2: a1 + a2, args) - - -def math_sub(*args: Union[ColT, float]) -> Func: - """Computes the diff of the column.""" - return math_func("sub", lambda a1, a2: a1 - a2, args) - - -def math_mul(*args: Union[ColT, float]) -> Func: - """Computes the product of the column.""" - return math_func("mul", lambda a1, a2: a1 * a2, args) - - -def math_truediv(*args: Union[ColT, float]) -> Func: - """Computes the division of the column.""" - return math_func("div", lambda a1, a2: 
a1 / a2, args, result_type=float) - - -def math_floordiv(*args: Union[ColT, float]) -> Func: - """Computes the floor division of the column.""" - return math_func("floordiv", lambda a1, a2: a1 // a2, args, result_type=float) - - -def math_mod(*args: Union[ColT, float]) -> Func: - """Computes the modulo of the column.""" - return math_func("mod", lambda a1, a2: a1 % a2, args, result_type=float) - - -def math_pow(*args: Union[ColT, float]) -> Func: - """Computes the power of the column.""" - return math_func("pow", lambda a1, a2: a1**a2, args, result_type=float) - - -def math_lshift(*args: Union[ColT, float]) -> Func: - """Computes the left shift of the column.""" - return math_func("lshift", lambda a1, a2: a1 << a2, args, result_type=int) - - -def math_rshift(*args: Union[ColT, float]) -> Func: - """Computes the right shift of the column.""" - return math_func("rshift", lambda a1, a2: a1 >> a2, args, result_type=int) - - -def math_and(*args: Union[ColT, float]) -> Func: - """Computes the logical AND of the column.""" - return math_func("and", lambda a1, a2: a1 & a2, args, result_type=bool) - - -def math_or(*args: Union[ColT, float]) -> Func: - """Computes the logical OR of the column.""" - return math_func("or", lambda a1, a2: a1 | a2, args, result_type=bool) - - -def math_xor(*args: Union[ColT, float]) -> Func: - """Computes the logical XOR of the column.""" - return math_func("xor", lambda a1, a2: a1 ^ a2, args, result_type=bool) - - -def math_lt(*args: Union[ColT, float]) -> Func: - """Computes the less than comparison of the column.""" - return math_func("lt", lambda a1, a2: a1 < a2, args, result_type=bool) - - -def math_le(*args: Union[ColT, float]) -> Func: - """Computes the less than or equal comparison of the column.""" - return math_func("le", lambda a1, a2: a1 <= a2, args, result_type=bool) - - -def math_eq(*args: Union[ColT, float]) -> Func: - """Computes the equality comparison of the column.""" - return math_func("eq", lambda a1, a2: a1 == a2, args, 
result_type=bool) - - -def math_ne(*args: Union[ColT, float]) -> Func: - """Computes the inequality comparison of the column.""" - return math_func("ne", lambda a1, a2: a1 != a2, args, result_type=bool) - - -def math_gt(*args: Union[ColT, float]) -> Func: - """Computes the greater than comparison of the column.""" - return math_func("gt", lambda a1, a2: a1 > a2, args, result_type=bool) +def _truediv(a, b): + # Using sqlalchemy.sql.func.divide here instead of / operator + # because of a bug in ClickHouse SQLAlchemy dialect + # See https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/335 + return sa_func.divide(a, b) -def math_ge(*args: Union[ColT, float]) -> Func: - """Computes the greater than or equal comparison of the column.""" - return math_func("ge", lambda a1, a2: a1 >= a2, args, result_type=bool) +def _floordiv(a, b): + return cast(_truediv(a, b), Integer) diff --git a/src/datachain/func/numeric.py b/src/datachain/func/numeric.py new file mode 100644 index 000000000..a26e7ff51 --- /dev/null +++ b/src/datachain/func/numeric.py @@ -0,0 +1,162 @@ +from typing import Union + +from datachain.sql.functions import numeric + +from .func import ColT, Func + + +def bit_and(*args: Union[ColT, int]) -> Func: + """ + Computes the bitwise AND operation between two values. + + Args: + args (str | int): Two values to compute the bitwise AND operation between. + If a string is provided, it is assumed to be the name of the column vector. + If an integer is provided, it is assumed to be a constant value. + + Returns: + Func: A Func object that represents the bitwise AND function. + + Example: + ```py + dc.mutate( + and1=func.bit_and("signal.values", 0x0F), + ) + ``` + + Notes: + - Result column will always be of type int. 
+ """ + cols, func_args = [], [] + for arg in args: + if isinstance(arg, int): + func_args.append(arg) + else: + cols.append(arg) + + if len(cols) + len(func_args) != 2: + raise ValueError("bit_and() requires exactly two arguments") + + return Func( + "bit_and", + inner=numeric.bit_and, + cols=cols, + args=func_args, + result_type=int, + ) + + +def bit_or(*args: Union[ColT, int]) -> Func: + """ + Computes the bitwise OR operation between two values. + + Args: + args (str | int): Two values to compute the bitwise OR operation between. + If a string is provided, it is assumed to be the name of the column vector. + If an integer is provided, it is assumed to be a constant value. + + Returns: + Func: A Func object that represents the bitwise OR function. + + Example: + ```py + dc.mutate( + or1=func.bit_or("signal.values", 0x0F), + ) + ``` + + Notes: + - Result column will always be of type int. 
+ """ + cols, func_args = [], [] + for arg in args: + if isinstance(arg, int): + func_args.append(arg) + else: + cols.append(arg) + + if len(cols) + len(func_args) != 2: + raise ValueError("bit_xor() requires exactly two arguments") + + return Func( + "bit_xor", + inner=numeric.bit_xor, + cols=cols, + args=func_args, + result_type=int, + ) + + +def int_hash_64(col: Union[ColT, int]) -> Func: + """ + Returns the 64-bit hash of an integer. + + Args: + col (str | int): Integer to compute the hash of. + If a string is provided, it is assumed to be the name of the column. + If an int is provided, it is assumed to be an int literal. + If a Func is provided, it is assumed to be a function returning an int. + + Returns: + Func: A Func object that represents the 64-bit hash function. + + Example: + ```py + dc.mutate( + val_hash=func.int_hash_64("val"), + ) + ``` + + Note: + - Result column will always be of type int. + """ + cols, args = [], [] + if isinstance(col, int): + args.append(col) + else: + cols.append(col) + + return Func( + "int_hash_64", inner=numeric.int_hash_64, cols=cols, args=args, result_type=int + ) diff --git a/src/datachain/sql/functions/array.py b/src/datachain/sql/functions/array.py index 567162fe6..ab7cfc8ad 100644 --- a/src/datachain/sql/functions/array.py +++ b/src/datachain/sql/functions/array.py @@ -38,6 +38,10 @@ class length(GenericFunction): # noqa: N801 class sip_hash_64(GenericFunction): # noqa: N801 + """ + Computes the SipHash-64 hash of the array. 
+ """ + type = Int64() package = "hash" name = "sip_hash_64" diff --git a/src/datachain/sql/functions/numeric.py b/src/datachain/sql/functions/numeric.py new file mode 100644 index 000000000..2a7a82d6c --- /dev/null +++ b/src/datachain/sql/functions/numeric.py @@ -0,0 +1,43 @@ +from sqlalchemy.sql.functions import GenericFunction, ReturnTypeFromArgs + +from datachain.sql.types import Int64 +from datachain.sql.utils import compiler_not_implemented + + +class bit_and(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True + + +class bit_or(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True + + +class bit_xor(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True + + +class bit_rshift(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True + + +class bit_lshift(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True + + +class int_hash_64(GenericFunction): # noqa: N801 + """ + Computes the 64-bit hash of an integer. + """ + + type = Int64() + package = "hash" + name = "int_hash_64" + inherit_cache = True + + +compiler_not_implemented(bit_and) +compiler_not_implemented(bit_or) +compiler_not_implemented(bit_xor) +compiler_not_implemented(bit_rshift) +compiler_not_implemented(bit_lshift) +compiler_not_implemented(int_hash_64) diff --git a/src/datachain/sql/sqlite/base.py b/src/datachain/sql/sqlite/base.py index 4599ed10c..d5589cf85 100644 --- a/src/datachain/sql/sqlite/base.py +++ b/src/datachain/sql/sqlite/base.py @@ -15,7 +15,14 @@ from sqlalchemy.sql.expression import case from sqlalchemy.sql.functions import func -from datachain.sql.functions import aggregate, array, conditional, random, string +from datachain.sql.functions import ( + aggregate, + array, + conditional, + numeric, + random, + string, +) from datachain.sql.functions import path as sql_path from datachain.sql.selectable import Values, base_values_compiler from datachain.sql.sqlite.types import ( @@ -47,6 +54,8 @@ empty_str = literal("") dot = literal(".") +MAX_INT64 = 2**64 - 1 + def setup(): 
global setup_is_complete # noqa: PLW0603 @@ -89,6 +98,12 @@ def setup(): compiles(aggregate.group_concat, "sqlite")(compile_group_concat) compiles(aggregate.any_value, "sqlite")(compile_any_value) compiles(aggregate.collect, "sqlite")(compile_collect) + compiles(numeric.bit_and, "sqlite")(compile_bitwise_and) + compiles(numeric.bit_or, "sqlite")(compile_bitwise_or) + compiles(numeric.bit_xor, "sqlite")(compile_bitwise_xor) + compiles(numeric.bit_rshift, "sqlite")(compile_bitwise_rshift) + compiles(numeric.bit_lshift, "sqlite")(compile_bitwise_lshift) + compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64) if load_usearch_extension(sqlite3.connect(":memory:")): compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext) @@ -163,6 +178,19 @@ def sqlite_string_split(string: str, sep: str, maxsplit: int = -1) -> str: return orjson.dumps(string.split(sep, maxsplit)).decode("utf-8") +def sqlite_int_hash_64(x: int) -> int: + """IntHash64 implementation from ClickHouse.""" + x ^= 0x4CF2D2BAAE6DA887 + x ^= x >> 33 + x = (x * 0xFF51AFD7ED558CCD) & MAX_INT64 + x ^= x >> 33 + x = (x * 0xC4CEB9FE1A85EC53) & MAX_INT64 + x ^= x >> 33 + # SQLite does not support unsigned 64-bit integers, + # so we need to convert to signed 64-bit + return x if x < 1 << 63 else (x & MAX_INT64) - (1 << 64) + + def register_user_defined_sql_functions() -> None: # Register optional functions if we have the necessary dependencies # and otherwise register functions that will raise an exception with @@ -185,6 +213,21 @@ def create_vector_functions(conn): _registered_function_creators["vector_functions"] = create_vector_functions + def create_numeric_functions(conn): + conn.create_function("divide", 2, lambda a, b: a / b, deterministic=True) + conn.create_function("bitwise_and", 2, lambda a, b: a & b, deterministic=True) + conn.create_function("bitwise_or", 2, lambda a, b: a | b, deterministic=True) + conn.create_function("bitwise_xor", 2, lambda a, b: a ^ b, deterministic=True) + 
conn.create_function( + "bitwise_rshift", 2, lambda a, b: a >> b, deterministic=True + ) + conn.create_function( + "bitwise_lshift", 2, lambda a, b: a << b, deterministic=True + ) + conn.create_function("int_hash_64", 1, sqlite_int_hash_64, deterministic=True) + + _registered_function_creators["numeric_functions"] = create_numeric_functions + def sqlite_regexp_replace(string: str, pattern: str, replacement: str) -> str: return re.sub(pattern, replacement, string) @@ -316,6 +359,30 @@ def compile_euclidean_distance(element, compiler, **kwargs): return f"euclidean_distance({compiler.process(element.clauses, **kwargs)})" +def compile_bitwise_and(element, compiler, **kwargs): + return compiler.process(func.bitwise_and(*element.clauses.clauses), **kwargs) + + +def compile_bitwise_or(element, compiler, **kwargs): + return compiler.process(func.bitwise_or(*element.clauses.clauses), **kwargs) + + +def compile_bitwise_xor(element, compiler, **kwargs): + return compiler.process(func.bitwise_xor(*element.clauses.clauses), **kwargs) + + +def compile_bitwise_rshift(element, compiler, **kwargs): + return compiler.process(func.bitwise_rshift(*element.clauses.clauses), **kwargs) + + +def compile_bitwise_lshift(element, compiler, **kwargs): + return compiler.process(func.bitwise_lshift(*element.clauses.clauses), **kwargs) + + +def compile_int_hash_64(element, compiler, **kwargs): + return compiler.process(func.int_hash_64(*element.clauses.clauses), **kwargs) + + def py_json_array_length(arr): return len(orjson.loads(arr)) diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index 3ca009e12..7ca4d3dde 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -1,21 +1,30 @@ import pytest from sqlalchemy import Label -from datachain import func +from datachain import DataChain +from datachain.func import int_hash_64 +from datachain.func.random import rand +from datachain.func.string import length as strlen from datachain.lib.signal_schema import SignalSchema 
+from datachain.sql.sqlite.base import sqlite_int_hash_64 -@pytest.fixture -def rnd(): - return func.random.rand() +@pytest.fixture() +def dc(): + return DataChain.from_values( + num=list(range(1, 6)), + val=["x" * i for i in range(1, 6)], + ) -def test_db_cols(rnd): +def test_db_cols(): + rnd = rand() assert rnd._db_cols == [] assert rnd._db_col_type(SignalSchema({})) is None -def test_label(rnd): +def test_label(): + rnd = rand() assert rnd.col_label is None assert rnd.label("test2") == "test2" @@ -24,233 +33,554 @@ def test_label(rnd): assert f.label("test2") == "test2" -def test_col_name(rnd): +def test_col_name(): + rnd = rand() assert rnd.get_col_name() == "rand" assert rnd.label("test").get_col_name() == "test" assert rnd.get_col_name("test2") == "test2" -def test_result_type(rnd): +def test_result_type(): + rnd = rand() assert rnd.get_result_type(SignalSchema({})) is int -def test_get_column(rnd): +def test_get_column(): + rnd = rand() col = rnd.get_column(SignalSchema({})) assert isinstance(col, Label) assert col.name == "rand" -def test_add(rnd): - f = rnd + 1 +def test_add(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 + 1 + assert str(f) == "add()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 1 + rnd2 assert str(f) == "add()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] - f = 1 + rnd + f = rnd1 + rnd2 assert str(f) == "add()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_add_mutate(dc): + res = dc.mutate(test=strlen("val") + 1).order_by("num").collect("test") + assert list(res) == [2, 3, 4, 5, 6] + + res = dc.mutate(test=1 + strlen("val")).order_by("num").collect("test") + assert list(res) == [2, 3, 4, 5, 6] + + res = dc.mutate(test=strlen("val") + strlen("val")).order_by("num").collect("test") + assert list(res) == [2, 4, 6, 8, 10] + +def test_sub(): + rnd1, rnd2 = rand(), rand() -def test_sub(rnd): - f = rnd - 1 + f = rnd1 
- 1 assert str(f) == "sub()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = 1 - rnd + f = 1 - rnd2 assert str(f) == "sub()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + f = rnd1 - rnd2 + assert str(f) == "sub()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_sub_mutate(dc): + res = dc.mutate(test=strlen("val") - 1).order_by("num").collect("test") + assert list(res) == [0, 1, 2, 3, 4] + + res = dc.mutate(test=5 - strlen("val")).order_by("num").collect("test") + assert list(res) == [4, 3, 2, 1, 0] -def test_mul(rnd): - f = rnd * 1 + res = dc.mutate(test=strlen("val") - strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] + + +def test_mul(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 * 2 assert str(f) == "mul()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = 1 * rnd + f = 2 * rnd2 assert str(f) == "mul()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + f = rnd1 * rnd2 + assert str(f) == "mul()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] -def test_realdiv(rnd): - f = rnd / 1 + +def test_mul_mutate(dc): + res = dc.mutate(test=strlen("val") * 2).order_by("num").collect("test") + assert list(res) == [2, 4, 6, 8, 10] + + res = dc.mutate(test=3 * strlen("val")).order_by("num").collect("test") + assert list(res) == [3, 6, 9, 12, 15] + + res = dc.mutate(test=strlen("val") * strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 4, 9, 16, 25] + + +def test_truediv(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 / 2 + assert str(f) == "div()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 1 / rnd2 assert str(f) == "div()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] - f = 1 / rnd + f = rnd1 / rnd2 assert str(f) == "div()" - assert 
f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1, rnd2] + assert f.args == [] -def test_floordiv(rnd): - f = rnd // 1 +def test_truediv_mutate(dc): + res = dc.mutate(test=strlen("val") / 2).order_by("num").collect("test") + assert list(res) == [0.5, 1.0, 1.5, 2.0, 2.5] + + res = dc.mutate(test=10 / strlen("val")).order_by("num").collect("test") + assert list(res) == [10.0, 5.0, 10 / 3, 2.5, 2.0] + + res = dc.mutate(test=strlen("val") / strlen("val")).order_by("num").collect("test") + assert list(res) == [1.0, 1.0, 1.0, 1.0, 1.0] + + +def test_floordiv(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 // 2 assert str(f) == "floordiv()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = 1 // rnd + f = 1 // rnd2 assert str(f) == "floordiv()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + f = rnd1 // rnd2 + assert str(f) == "floordiv()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] -def test_mod(rnd): - f = rnd % 1 - assert str(f) == "mod()" - assert f.cols == [rnd] - assert f.args == [1] - f = 1 % rnd - assert str(f) == "mod()" - assert f.cols == [rnd] - assert f.args == [1] +def test_floordiv_mutate(dc): + res = dc.mutate(test=strlen("val") // 2).order_by("num").collect("test") + assert list(res) == [0, 1, 1, 2, 2] + res = dc.mutate(test=10 // strlen("val")).order_by("num").collect("test") + assert list(res) == [10, 5, 3, 2, 2] -def test_pow(rnd): - f = rnd**1 - assert str(f) == "pow()" - assert f.cols == [rnd] - assert f.args == [1] + res = dc.mutate(test=strlen("val") // strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 1, 1] - f = 1**rnd - assert str(f) == "pow()" - assert f.cols == [rnd] - assert f.args == [1] +def test_mod(): + rnd1, rnd2 = rand(), rand() -def test_lshift(rnd): - f = rnd << 1 - assert str(f) == "lshift()" - assert f.cols == [rnd] - assert f.args == [1] + f = rnd1 % 2 + assert str(f) == "mod()" + 
assert f.cols == [rnd1] + assert f.args == [] - f = 1 << rnd - assert str(f) == "lshift()" - assert f.cols == [rnd] - assert f.args == [1] + f = 10 % rnd2 + assert str(f) == "mod()" + assert f.cols == [rnd2] + assert f.args == [] + f = rnd1 % rnd2 + assert str(f) == "mod()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] -def test_rshift(rnd): - f = rnd >> 1 - assert str(f) == "rshift()" - assert f.cols == [rnd] - assert f.args == [1] - f = 1 >> rnd - assert str(f) == "rshift()" - assert f.cols == [rnd] - assert f.args == [1] +def test_mod_mutate(dc): + res = dc.mutate(test=strlen("val") % 2).order_by("num").collect("test") + assert list(res) == [1, 0, 1, 0, 1] + res = dc.mutate(test=10 % strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 1, 2, 0] -def test_and(rnd): - f = rnd & 1 + res = dc.mutate(test=strlen("val") % strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] + + +def test_and(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 & 2 + assert str(f) == "and()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 2 & rnd2 assert str(f) == "and()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] - f = 1 & rnd + f = rnd1 & rnd2 assert str(f) == "and()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_and_mutate(dc): + res = dc.mutate(test=strlen("val") & 2).order_by("num").collect("test") + assert list(res) == [0, 2, 2, 0, 0] + + res = dc.mutate(test=2 & strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 2, 2, 0, 0] + res = dc.mutate(test=strlen("val") & strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 2, 3, 4, 5] -def test_or(rnd): - f = rnd | 1 + +def test_or(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 | 2 + assert str(f) == "or()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 2 | rnd2 assert str(f) == "or()" - assert f.cols == 
[rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] - f = 1 | rnd + f = rnd1 | rnd2 assert str(f) == "or()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_or_mutate(dc): + res = dc.mutate(test=strlen("val") | 2).order_by("num").collect("test") + assert list(res) == [3, 2, 3, 6, 7] + + res = dc.mutate(test=2 | strlen("val")).order_by("num").collect("test") + assert list(res) == [3, 2, 3, 6, 7] + + res = dc.mutate(test=strlen("val") | strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 2, 3, 4, 5] -def test_xor(rnd): - f = rnd ^ 1 +def test_xor(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 ^ 2 + assert str(f) == "xor()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 2 ^ rnd2 assert str(f) == "xor()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] - f = 1 ^ rnd + f = rnd1 ^ rnd2 assert str(f) == "xor()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_xor_mutate(dc): + res = dc.mutate(test=strlen("val") ^ 2).order_by("num").collect("test") + assert list(res) == [3, 0, 1, 6, 7] + + res = dc.mutate(test=2 ^ strlen("val")).order_by("num").collect("test") + assert list(res) == [3, 0, 1, 6, 7] + + res = dc.mutate(test=strlen("val") ^ strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] -def test_lt(rnd): - f = rnd < 1 +def test_rshift(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 >> 2 + assert str(f) == "rshift()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 2 >> rnd2 + assert str(f) == "rshift()" + assert f.cols == [rnd2] + assert f.args == [] + + f = rnd1 >> rnd2 + assert str(f) == "rshift()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_rshift_mutate(dc): + res = dc.mutate(test=strlen("val") >> 2).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 1, 1] + + 
res = dc.mutate(test=2 >> strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 0, 0, 0, 0] + + res = dc.mutate(test=strlen("val") >> strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] + + +def test_lshift(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 << 2 + assert str(f) == "lshift()" + assert f.cols == [rnd1] + assert f.args == [] + + f = 2 << rnd2 + assert str(f) == "lshift()" + assert f.cols == [rnd2] + assert f.args == [] + + f = rnd1 << rnd2 + assert str(f) == "lshift()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_lshift_mutate(dc): + res = dc.mutate(test=strlen("val") << 2).order_by("num").collect("test") + assert list(res) == [4, 8, 12, 16, 20] + + res = dc.mutate(test=2 << strlen("val")).order_by("num").collect("test") + assert list(res) == [4, 8, 16, 32, 64] + + res = dc.mutate(test=strlen("val") << strlen("val")).order_by("num").collect("test") + assert list(res) == [2, 8, 24, 64, 160] + + +def test_lt(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 < 1 assert str(f) == "lt()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = rnd > 1 + f = rnd2 > 1 assert str(f) == "gt()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + + f = rnd1 < rnd2 + assert str(f) == "lt()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_lt_mutate(dc): + res = dc.mutate(test=strlen("val") < 3).order_by("num").collect("test") + assert list(res) == [1, 1, 0, 0, 0] + + res = dc.mutate(test=strlen("val") > 3).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 1, 1] + + res = dc.mutate(test=strlen("val") < strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] -def test_le(rnd): - f = rnd <= 1 +def test_le(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 <= 1 assert str(f) == "le()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] 
+ assert f.args == [] - f = rnd >= 1 + f = rnd2 >= 1 assert str(f) == "ge()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + + f = rnd1 <= rnd2 + assert str(f) == "le()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_le_mutate(dc): + res = dc.mutate(test=strlen("val") <= 3).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 0, 0] + + res = dc.mutate(test=strlen("val") >= 3).order_by("num").collect("test") + assert list(res) == [0, 0, 1, 1, 1] + + res = dc.mutate(test=strlen("val") <= strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 1, 1] + +def test_eq(): + rnd1, rnd2 = rand(), rand() -def test_eq(rnd): - f = rnd == 1 + f = rnd1 == 1 assert str(f) == "eq()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = rnd == 1 + f = rnd2 == 1 assert str(f) == "eq()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + f = rnd1 == rnd2 + assert str(f) == "eq()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_eq_mutate(dc): + res = dc.mutate(test=strlen("val") == 2).order_by("num").collect("test") + assert list(res) == [0, 1, 0, 0, 0] + + res = dc.mutate(test=strlen("val") == 4).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 1, 0] -def test_ne(rnd): - f = rnd != 1 + res = dc.mutate(test=strlen("val") == strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 1, 1] + + +def test_ne(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 != 1 assert str(f) == "ne()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = rnd != 1 + f = rnd2 != 1 assert str(f) == "ne()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + f = rnd1 != rnd2 + assert str(f) == "ne()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] -def test_gt(rnd): 
- f = rnd > 1 + +def test_ne_mutate(dc): + res = dc.mutate(test=strlen("val") != 2).order_by("num").collect("test") + assert list(res) == [1, 0, 1, 1, 1] + + res = dc.mutate(test=strlen("val") != 4).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 0, 1] + + res = dc.mutate(test=strlen("val") != strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] + + +def test_gt(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 > 1 assert str(f) == "gt()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = rnd < 1 + f = rnd2 < 1 assert str(f) == "lt()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + + f = rnd1 > rnd2 + assert str(f) == "gt()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] -def test_ge(rnd): - f = rnd >= 1 +def test_gt_mutate(dc): + res = dc.mutate(test=strlen("val") > 2).order_by("num").collect("test") + assert list(res) == [0, 0, 1, 1, 1] + + res = dc.mutate(test=strlen("val") < 4).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 0, 0] + + res = dc.mutate(test=strlen("val") > strlen("val")).order_by("num").collect("test") + assert list(res) == [0, 0, 0, 0, 0] + + +def test_ge(): + rnd1, rnd2 = rand(), rand() + + f = rnd1 >= 1 assert str(f) == "ge()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd1] + assert f.args == [] - f = rnd <= 1 + f = rnd2 <= 1 assert str(f) == "le()" - assert f.cols == [rnd] - assert f.args == [1] + assert f.cols == [rnd2] + assert f.args == [] + + f = rnd1 >= rnd2 + assert str(f) == "ge()" + assert f.cols == [rnd1, rnd2] + assert f.args == [] + + +def test_ge_mutate(dc): + res = dc.mutate(test=strlen("val") >= 2).order_by("num").collect("test") + assert list(res) == [0, 1, 1, 1, 1] + + res = dc.mutate(test=strlen("val") <= 4).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 1, 0] + + res = dc.mutate(test=strlen("val") >= 
strlen("val")).order_by("num").collect("test") + assert list(res) == [1, 1, 1, 1, 1] + + +@pytest.mark.parametrize( + "value,inthash", + [ + [0, 4761183170873013810], + [1, 10577349846663553072 - (1 << 64)], + [5, 15228578409069794350 - (1 << 64)], + [123456, 13379111408315310133 - (1 << 64)], + ], +) +def test_sqlite_int_hash_64(value, inthash): + assert sqlite_int_hash_64(value) == inthash + + +def test_int_hash_64_mutate(dc): + res = dc.mutate(test=int_hash_64(strlen("val"))).order_by("num").collect("test") + assert [x & (2**64 - 1) for x in res] == [ + 10577349846663553072, + 18198135717204167749, + 9624464864560415994, + 7766709361750702608, + 15228578409069794350, + ]