Skip to content

Commit

Permalink
Fix neg function in python script (#1907)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Fix neg function in python script

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Test cases
- [x] Python SDK impacted, Need to update PyPI
  • Loading branch information
yangzq50 authored Sep 24, 2024
1 parent 6496244 commit c442e08
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 44 deletions.
21 changes: 6 additions & 15 deletions python/infinity_embedded/local_infinity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,22 +112,13 @@ def traverse_conditions(cons, fn=None):
return traverse_conditions(value)

elif isinstance(cons, exp.Neg):
func_expr = WrapFunctionExpr()
func_expr.func_name = '-'
func_expr.arguments = [parse_expr(cons.hashable_args[0])]
parsed_expr = WrapParsedExpr()
if isinstance(cons.hashable_args[0], exp.Literal):
constant_expr = WrapConstantExpr()
if cons.hashable_args[0].is_int:
constant_expr.literal_type = LiteralType.kInteger
constant_expr.i64_value = -int(cons.hashable_args[0].output_name)
elif cons.hashable_args[0].is_number:
constant_expr.literal_type = LiteralType.kDouble
constant_expr.f64_value = -float(cons.hashable_args[0].output_name)
else:
raise Exception(f"unknown literal type: {cons}")

parsed_expr.type = ParsedExprType.kConstant
parsed_expr.constant_expr = constant_expr

return parsed_expr
parsed_expr.type = ParsedExprType.kFunction
parsed_expr.function_expr = func_expr
return parsed_expr
elif isinstance(cons, exp.Anonymous):
arguments = []
for arg in cons.args['expressions']:
Expand Down
24 changes: 7 additions & 17 deletions python/infinity_sdk/infinity/remote_thrift/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,13 @@ def traverse_conditions(cons, fn=None) -> ttypes.ParsedExpr:
for value in cons.hashable_args:
return traverse_conditions(value)
elif isinstance(cons, exp.Neg):
parsed_expr = ttypes.ParsedExpr()
if isinstance(cons.hashable_args[0], exp.Literal):
constant_expr = ttypes.ConstantExpr()
if cons.hashable_args[0].is_int:
constant_expr.literal_type = ttypes.LiteralType.Int64
constant_expr.i64_value = -int(cons.hashable_args[0].output_name)
elif cons.hashable_args[0].is_number:
constant_expr.literal_type = ttypes.LiteralType.Double
constant_expr.f64_value = -float(cons.hashable_args[0].output_name)
else:
raise InfinityException(ErrorCode.INVALID_EXPRESSION, f"unknown literal type: {cons}")

parser_expr_type = ttypes.ParsedExprType()
parser_expr_type.constant_expr = constant_expr
parsed_expr.type = parser_expr_type

return parsed_expr
func_expr = ttypes.FunctionExpr(
function_name='-',
arguments=[parse_expr(cons.hashable_args[0])]
)
expr_type = ttypes.ParsedExprType(function_expr=func_expr)
parsed_expr = ttypes.ParsedExpr(type=expr_type)
return parsed_expr
elif isinstance(cons, exp.Anonymous):
arguments = []
for arg in cons.args['expressions']:
Expand Down
10 changes: 10 additions & 0 deletions python/test_pysdk/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,13 @@ def test_func():
test_func()
res = db_obj.drop_table("test_filter_fulltext" + suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK

def test_neg_func(self, suffix):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table("test_neg_func" + suffix, ConflictType.Ignore)
table_obj = db_obj.create_table("test_neg_func" + suffix, {"num": {"type": "float64"}}, ConflictType.Error)
table_obj.insert([{"num": 1.0}, {"num": 2.0}, {"num": 3.0}])
pd.testing.assert_frame_equal(table_obj.output(["-abs(num) - 1"]).filter("-abs(num) >= -2").to_df(),
pd.DataFrame({"(-(ABS(num)) - 1)": (-2.0, -3.0)}))
res = db_obj.drop_table("test_neg_func" + suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK
21 changes: 9 additions & 12 deletions src/storage/data_block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,18 @@ void DataBlock::Finalize() {
if (finalized) {
return;
}
bool first_flat_column_vector = false;
bool have_flat_column_vector = false;
SizeT row_count = 0;
for (SizeT idx = 0; idx < column_count_; ++idx) {
if (column_vectors[idx]->vector_type() == ColumnVectorType::kConstant) {
continue;
} else {
if (first_flat_column_vector) {
if (row_count != column_vectors[idx]->Size()) {
String error_message = "Column vectors in same data block have different size.";
UnrecoverableError(error_message);
}
} else {
first_flat_column_vector = true;
if (column_vectors[idx]->vector_type() != ColumnVectorType::kConstant) {
const SizeT current_row_count = column_vectors[idx]->Size();
if (have_flat_column_vector && row_count != current_row_count) {
UnrecoverableError("Column vectors in same data block have different size.");
}
row_count = column_vectors[idx]->Size();
have_flat_column_vector = true;
row_count = current_row_count;
} else if (!have_flat_column_vector) {
row_count = 1;
}
}
row_count_ = row_count;
Expand Down

0 comments on commit c442e08

Please sign in to comment.