diff --git a/packages/openassistants/openassistants/contrib/sqlalchemy_query.py b/packages/openassistants/openassistants/contrib/sqlalchemy_query.py index f7221d3..db92ddc 100644 --- a/packages/openassistants/openassistants/contrib/sqlalchemy_query.py +++ b/packages/openassistants/openassistants/contrib/sqlalchemy_query.py @@ -1,6 +1,6 @@ import abc import asyncio -from typing import Annotated, Any, List, Literal, Sequence +from typing import Annotated, Any, List, Literal, Optional, Sequence import jsonschema import pandas as pd @@ -77,6 +77,7 @@ class QueryFunction(BaseFunction, abc.ABC): sqls: List[str] visualizations: List[str] summarization: str + data_table_output: Optional[bool] = True suggested_follow_ups: Annotated[List[SuggestedPrompt], Field(default_factory=list)] @abc.abstractmethod @@ -159,14 +160,15 @@ async def execute( results: List[FunctionOutput] = [] dataframes = await self._execute_sqls(deps) - results.extend( - [ - DataFrameOutput(dataframe=SerializedDataFrame.from_pd(df)) - for df in dataframes - ] - ) + if self.data_table_output: + results.extend( + [ + DataFrameOutput(dataframe=SerializedDataFrame.from_pd(df)) + for df in dataframes + ] + ) - yield results + yield results visualizations = await self._execute_visualizations(dataframes, deps)