Skip to content

Commit

Permalink
Use fuzzy matching when searching dbgpts (#2110)
Browse files Browse the repository at this point in the history
Co-authored-by: jiaoqiyuan <[email protected]>
  • Loading branch information
jiaoqiyuan and jiaoqiyuan authored Nov 5, 2024
1 parent 52062fd commit b4ce217
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
44 changes: 42 additions & 2 deletions dbgpt/serve/dbgpts/hub/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc

from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util.pagination_utils import PaginationResult

from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
Expand Down Expand Up @@ -109,3 +110,42 @@ def to_response(self, entity: ServeEntity) -> ServerResponse:
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

def dbgpts_list(
self,
query_request: ServeRequest,
page: int,
page_size: int,
desc_order_column: Optional[str] = None,
) -> PaginationResult[ServerResponse]:
"""Get a page of dbgpts.
Args:
query_request (ServeRequest): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
desc_order_column(Optional[str]): The column for descending order.
Returns:
PaginationResult: The pagination result.
"""
session = self.get_raw_session()
try:
query = session.query(ServeEntity)
if query_request.name:
query = query.filter(ServeEntity.name.like(f"%{query_request.name}%"))
if desc_order_column:
query = query.order_by(desc(getattr(ServeEntity, desc_order_column)))
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
finally:
session.close()

return PaginationResult(
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
4 changes: 2 additions & 2 deletions dbgpt/serve/dbgpts/hub/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def init_app(self, system_app: SystemApp) -> None:
self._system_app = system_app

@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao

Expand Down Expand Up @@ -130,7 +130,7 @@ def get_list_by_page(
installed=request.installed,
)

return self.dao.get_list_page(query_request, page, page_size)
return self.dao.dbgpts_list(query_request, page, page_size)

def refresh_hub_from_git(
self,
Expand Down

0 comments on commit b4ce217

Please sign in to comment.