Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sort/orderby for infinity-sdk and infinity-embedded-sdk #1944

Merged
merged 12 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/local_infinity/fulltext/fulltext_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ void BenchmarkQuery(SharedPtr<Infinity> infinity, const String &db_name, const S
output_columns->emplace_back(select_rowid_expr);
output_columns->emplace_back(select_score_expr);
}
infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns);
infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns, nullptr);
/*
auto result = infinity->Search(db_name, table_name, search_expr, nullptr, output_columns);
{
Expand Down
2 changes: 1 addition & 1 deletion benchmark/local_infinity/infinity_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ int main() {
output_columns->emplace_back(col2);

[[maybe_unused]] auto ignored =
infinity->Search("default_db", "benchmark_test", nullptr, nullptr, nullptr, nullptr, output_columns);
infinity->Search("default_db", "benchmark_test", nullptr, nullptr, nullptr, nullptr, output_columns, nullptr);
});
results.push_back(fmt::format("-> Select QPS: {}", total_times / tims_costing_second));
}
Expand Down
2 changes: 1 addition & 1 deletion benchmark/local_infinity/knn/knn_query_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ int main(int argc, char *argv[]) {
auto select_rowid_expr = new FunctionExpr();
select_rowid_expr->func_name_ = "row_id";
output_columns->emplace_back(select_rowid_expr);
auto result = infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns);
auto result = infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns, nullptr);
{
auto &cv = result.result_table_->GetDataBlockById(0)->column_vectors;
auto &column = *cv[0];
Expand Down
3 changes: 3 additions & 0 deletions python/infinity_embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class ConflictType(object):
Error = 1
Replace = 2

class SortType(object):
Asc = 0
Desc = 1

class InfinityException(Exception):
def __init__(self, error_code=0, error_message=None):
Expand Down
5 changes: 3 additions & 2 deletions python/infinity_embedded/local_infinity/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,13 @@ def export_data(self, db_name: str, table_name: str, file_name: str, export_opti
return self.convert_res(self.client.Export(db_name, table_name, columns, file_name, export_options))

def select(self, db_name: str, table_name: str, select_list: list[WrapParsedExpr], search_expr,
where_expr, limit_expr, offset_expr, group_by_list=None):
where_expr, limit_expr, offset_expr, order_by_list: list[WrapOrderByExpr], group_by_list=None):
if self.client is None:
raise Exception("Local infinity is not connected")
return self.convert_res(self.client.Search(db_name, table_name, select_list,
order_by_list=order_by_list,
wrap_search_expr=search_expr, where_expr=where_expr,
limit_expr=limit_expr, offset_expr=offset_expr),
limit_expr=limit_expr, offset_expr=offset_expr, ),
has_result_data=True)

def explain(self, db_name: str, table_name: str, explain_type, select_list, search_expr,
Expand Down
75 changes: 74 additions & 1 deletion python/infinity_embedded/local_infinity/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ def __init__(
filter: Optional[WrapParsedExpr],
limit: Optional[WrapParsedExpr],
offset: Optional[WrapParsedExpr],
sort: Optional[WrapOrderByExpr]
):
self.columns = columns
self.search = search
self.filter = filter
self.limit = limit
self.offset = offset
self.sort = sort


class ExplainQuery(Query):
Expand All @@ -44,7 +46,7 @@ def __init__(
offset: Optional[WrapParsedExpr],
explain_type: Optional[BaseExplainType],
):
super().__init__(columns, search, filter, limit, offset)
super().__init__(columns, search, filter, limit, offset, None)
self.explain_type = explain_type


Expand All @@ -56,13 +58,15 @@ def __init__(self, table):
self._filter = None
self._limit = None
self._offset = None
self._sort = []

def reset(self):
self._columns = None
self._search = None
self._filter = None
self._limit = None
self._offset = None
self._sort = []

def match_dense(
self,
Expand Down Expand Up @@ -434,13 +438,82 @@ def output(self, columns: Optional[list]) -> InfinityLocalQueryBuilder:
self._columns = select_list
return self

def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityLocalQueryBuilder:
sort_list: List[WrapOrderByExpr] = []
for order_by_expr in order_by_expr_list:
if isinstance(order_by_expr[0], str):
order_by_expr[0] = order_by_expr[0].lower()

match order_by_expr[0]:
case "*":
column_expr = WrapColumnExpr()
column_expr.star = True

parsed_expr = WrapParsedExpr(ParsedExprType.kColumn)
parsed_expr.column_expr = column_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_row_id":
func_expr = WrapFunctionExpr()
func_expr.func_name = "row_id"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_score":
func_expr = WrapFunctionExpr()
func_expr.func_name = "score"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_similarity":
func_expr = WrapFunctionExpr()
func_expr.func_name = "similarity"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_distance":
func_expr = WrapFunctionExpr()
func_expr.func_name = "distance"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case _:
parsed_expr = parse_expr(maybe_parse(order_by_expr[0]))
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)

self._sort = sort_list
return self

def to_result(self):
query = Query(
columns=self._columns,
search=self._search,
filter=self._filter,
limit=self._limit,
offset=self._offset,
sort=self._sort,
)
self.reset()
return self._table._execute_query(query)
Expand Down
18 changes: 16 additions & 2 deletions python/infinity_embedded/local_infinity/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from infinity_embedded.embedded_infinity_ext import ConflictType as LocalConflictType
from infinity_embedded.embedded_infinity_ext import WrapIndexInfo, ImportOptions, CopyFileType, WrapParsedExpr, \
ParsedExprType, WrapUpdateExpr, ExportOptions, WrapOptimizeOptions
from infinity_embedded.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN
from infinity_embedded.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN, SortType
from infinity_embedded.common import INSERT_DATA, VEC, SparseVector, InfinityException
from infinity_embedded.errors import ErrorCode
from infinity_embedded.index import IndexInfo
Expand Down Expand Up @@ -357,6 +357,19 @@ def limit(self, limit: Optional[int]):
def offset(self, offset: Optional[int]):
self.query_builder.offset(offset)
return self

def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
order_by_expr[1] = True
else :
order_by_expr[1] = False
self.query_builder.sort(order_by_expr_list)
return self

def to_df(self):
return self.query_builder.to_df()
Expand Down Expand Up @@ -398,7 +411,8 @@ def _execute_query(self, query: Query):
where_expr=query.filter,
group_by_list=None,
limit_expr=query.limit,
offset_expr=query.offset)
offset_expr=query.offset,
order_by_list=query.sort)

# process the results
if res.error_code == ErrorCode.OK:
Expand Down
3 changes: 3 additions & 0 deletions python/infinity_sdk/infinity/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class ConflictType(object):
Error = 1
Replace = 2

class SortType(object):
Asc = 0
Desc = 1

class InfinityException(Exception):
def __init__(self, error_code=0, error_message=None):
Expand Down
3 changes: 2 additions & 1 deletion python/infinity_sdk/infinity/remote_thrift/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def export_data(self, db_name: str, table_name: str, file_name: str, export_opti
export_option=export_options))

def select(self, db_name: str, table_name: str, select_list, search_expr,
where_expr, group_by_list, limit_expr, offset_expr):
where_expr, group_by_list, limit_expr, offset_expr, order_by_list):
return self.client.Select(SelectRequest(session_id=self.session_id,
db_name=db_name,
table_name=table_name,
Expand All @@ -209,6 +209,7 @@ def select(self, db_name: str, table_name: str, select_list, search_expr,
group_by_list=group_by_list,
limit_expr=limit_expr,
offset_expr=offset_expr,
order_by_list=order_by_list
))

def explain(self, db_name: str, table_name: str, select_list, search_expr,
Expand Down
54 changes: 52 additions & 2 deletions python/infinity_sdk/infinity/remote_thrift/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pyarrow import Table
from sqlglot import condition, maybe_parse

from infinity.common import VEC, SparseVector, InfinityException
from infinity.common import VEC, SparseVector, InfinityException, SortType
from infinity.errors import ErrorCode
from infinity.remote_thrift.infinity_thrift_rpc.ttypes import *
from infinity.remote_thrift.types import (
Expand All @@ -45,12 +45,14 @@ def __init__(
filter: Optional[ParsedExpr],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
sort: Optional[List[OrderByExpr]],
):
self.columns = columns
self.search = search
self.filter = filter
self.limit = limit
self.offset = offset
self.sort = sort


class ExplainQuery(Query):
Expand All @@ -61,9 +63,10 @@ def __init__(
filter: Optional[ParsedExpr],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
#sort: Optional[List[OrderByExpr]],
explain_type: Optional[ExplainType],
):
super().__init__(columns, search, filter, limit, offset)
super().__init__(columns, search, filter, limit, offset, None)
self.explain_type = explain_type


Expand All @@ -75,13 +78,15 @@ def __init__(self, table):
self._filter = None
self._limit = None
self._offset = None
self._sort = None

def reset(self):
self._columns = None
self._search = None
self._filter = None
self._limit = None
self._offset = None
self._sort = None

def match_dense(
self,
Expand Down Expand Up @@ -340,6 +345,50 @@ def output(self, columns: Optional[list]) -> InfinityThriftQueryBuilder:

self._columns = select_list
return self

def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityThriftQueryBuilder:
sort_list: List[OrderByExpr] = []
for order_by_expr in order_by_expr_list:
if isinstance(order_by_expr[0], str):
order_by_expr[0] = order_by_expr[0].lower()

match order_by_expr[0]:
case "*":
column_expr = ColumnExpr(star=True, column_name=[])
expr_type = ParsedExprType(column_expr=column_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_row_id":
func_expr = FunctionExpr(function_name="row_id", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_score":
func_expr = FunctionExpr(function_name="score", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_similarity":
func_expr = FunctionExpr(function_name="similarity", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_distance":
func_expr = FunctionExpr(function_name="distance", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case _:
parsed_expr = parse_expr(maybe_parse(order_by_expr[0]))
sort_list.append(OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]))

self._sort = sort_list
return self

def to_result(self) -> tuple[dict[str, list[Any]], dict[str, Any]]:
query = Query(
Expand All @@ -348,6 +397,7 @@ def to_result(self) -> tuple[dict[str, list[Any]], dict[str, Any]]:
filter=self._filter,
limit=self._limit,
offset=self._offset,
sort=self._sort,
)
self.reset()
return self._table._execute_query(query)
Expand Down
18 changes: 16 additions & 2 deletions python/infinity_sdk/infinity/remote_thrift/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
get_ordinary_info,
)
from infinity.table import ExplainType
from infinity.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN
from infinity.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN, SortType
from infinity.utils import deprecated_api


Expand Down Expand Up @@ -376,6 +376,19 @@ def limit(self, limit: Optional[int]):
def offset(self, offset: Optional[int]):
self.query_builder.offset(offset)
return self

def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
order_by_expr[1] = True
else :
order_by_expr[1] = False
self.query_builder.sort(order_by_expr_list)
return self

def to_result(self):
return self.query_builder.to_result()
Expand Down Expand Up @@ -421,7 +434,8 @@ def _execute_query(self, query: Query) -> tuple[dict[str, list[Any]], dict[str,
where_expr=query.filter,
group_by_list=None,
limit_expr=query.limit,
offset_expr=query.offset)
offset_expr=query.offset,
order_by_list=query.sort)

# process the results
if res.error_code == ErrorCode.OK:
Expand Down
Loading
Loading