-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* fix(connectors): make sqlalchemy non-optional and check yfinance imports (#561) * feat: use duckdb for the cache * Release v1.2.2 * feat: `output_type` parameter (#519) * (feat): update prompt template in `GeneratePythonCodePrompt`, add `output_type_hint` variable to be interpolated * (feat): update `.chat()` method for `SmartDataframe` and `SmartDatalake`, add optional `output_type` parameter * (feat): add `get_output_type_hint()` in `GeneratePythonCodePrompt` class * (feat): add "output_type_hint" to `default_values` when forming prompt's template context * (tests): update tests in `TestGeneratePythonCodePrompt` * (tests): add tests for checking `output_type` interpotaion to a prompt * refactor: `output_type` parameter (#519) * (refactor): update setting value for `{output_type_hint}` in prompt class * fix: `output_type` parameter (#519) * (tests): fix error in `TestGeneratePythonCodePrompt` with confused actual prompt's content and excepted prompt's content (which led to tests being failed) * (refactor): update test method `test_str_with_args()` with `parametrize` decorator, remove duplication of code (DRY) * tests: `output_type` parameter (#519) * (tests): parametrizing `test_run_passing_output_type()` with different output types. * refactor: `output_type` parameter (#519) * (refactor): move output types and their hints to a separate python module * (feat): validation for inappropriate output type and value * (tests): update tests accordingly, add a new test method to check if the message about incorrect output type is added to logs * chore: `output_type` parameter (#519) * (chore): remove unused lines * refactor: `output_type` parameter (#519) * (refactor): pass `output_type_hint` in the `default_values`, don't bother with setting this template variable in prompt's `__init__()` * (fix): correct templates examples, add them as a part of `output_type_hint` * (refactor): use `df_type()` (utility functions) to get rid from imports of third packages in _output_types.py * (refactor): add logging to the factory-function `output_type_factory()` to enhance verbosity for behaviour inspection ---------
- Loading branch information
1 parent
b489c9c
commit 69e2b14
Showing
14 changed files
with
939 additions
and
499 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import logging | ||
from typing import Union, Optional | ||
|
||
from ._output_types import ( | ||
NumberOutputType, | ||
DataFrameOutputType, | ||
PlotOutputType, | ||
StringOutputType, | ||
DefaultOutputType, | ||
) | ||
from .. import Logger | ||
|
||
|
||
output_types_map = { | ||
"number": NumberOutputType, | ||
"dataframe": DataFrameOutputType, | ||
"plot": PlotOutputType, | ||
"string": StringOutputType, | ||
} | ||
|
||
|
||
def output_type_factory( | ||
output_type: str = None, logger: Optional[Logger] = None | ||
) -> Union[ | ||
NumberOutputType, | ||
DataFrameOutputType, | ||
PlotOutputType, | ||
StringOutputType, | ||
DefaultOutputType, | ||
]: | ||
""" | ||
Factory function to get appropriate instance for output type. | ||
Uses `output_types_map` to determine the output type class. | ||
Args: | ||
output_type (Optional[str]): A name of the output type. | ||
Defaults to None, an instance of `DefaultOutputType` will be | ||
returned. | ||
logger (Optional[str]): If passed, collects logs about correctness | ||
of the `output_type` argument and what kind of OutputType | ||
is created. | ||
Returns: | ||
(Union[ | ||
NumberOutputType, | ||
DataFrameOutputType, | ||
PlotOutputType, | ||
StringOutputType, | ||
DefaultOutputType | ||
]): An instance of the output type. | ||
""" | ||
if output_type is not None and output_type not in output_types_map and logger: | ||
possible_types_msg = ", ".join(f"'{type_}'" for type_ in output_types_map) | ||
logger.log( | ||
f"Unknown value for the parameter `output_type`: '{output_type}'." | ||
f"Possible values are: {possible_types_msg} and None for default " | ||
f"output type (miscellaneous).", | ||
level=logging.WARNING, | ||
) | ||
|
||
output_type_helper = output_types_map.get(output_type, DefaultOutputType)() | ||
|
||
if logger: | ||
logger.log( | ||
f"{output_type_helper.__class__} is going to be used.", level=logging.DEBUG | ||
) | ||
|
||
return output_type_helper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import re | ||
from decimal import Decimal | ||
from abc import abstractmethod, ABC | ||
from typing import Any, Iterable | ||
|
||
from ..df_info import df_type | ||
|
||
|
||
class BaseOutputType(ABC): | ||
@property | ||
@abstractmethod | ||
def template_hint(self) -> str: | ||
... | ||
|
||
@property | ||
@abstractmethod | ||
def name(self) -> str: | ||
... | ||
|
||
def _validate_type(self, actual_type: str) -> bool: | ||
if actual_type != self.name: | ||
return False | ||
return True | ||
|
||
@abstractmethod | ||
def _validate_value(self, actual_value): | ||
... | ||
|
||
def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]: | ||
""" | ||
Validate 'type' and 'value' from the result dict. | ||
Args: | ||
result (dict[str, Any]): The result of code execution in | ||
dict representation. Should have the following schema: | ||
{ | ||
"type": <output_type_name>, | ||
"value": <generated_value> | ||
} | ||
Returns: | ||
(tuple(bool, Iterable(str)): | ||
Boolean value whether the result matches output type | ||
and collection of logs containing messages about | ||
'type' or 'value' mismatches. | ||
""" | ||
validation_logs = [] | ||
actual_type, actual_value = result.get("type"), result.get("value") | ||
|
||
type_ok = self._validate_type(actual_type) | ||
if not type_ok: | ||
validation_logs.append( | ||
f"The result dict contains inappropriate 'type'. " | ||
f"Expected '{self.name}', actual '{actual_type}'." | ||
) | ||
value_ok = self._validate_value(actual_value) | ||
if not value_ok: | ||
validation_logs.append( | ||
f"Actual value {repr(actual_value)} seems to be inappropriate " | ||
f"for the type '{self.name}'." | ||
) | ||
|
||
return all((type_ok, value_ok)), validation_logs | ||
|
||
|
||
class NumberOutputType(BaseOutputType): | ||
@property | ||
def template_hint(self): | ||
return """- type (must be "number") | ||
- value (must be a number) | ||
Example output: { "type": "number", "value": 125 }""" | ||
|
||
@property | ||
def name(self): | ||
return "number" | ||
|
||
def _validate_value(self, actual_value: Any) -> bool: | ||
if isinstance(actual_value, (int, float, Decimal)): | ||
return True | ||
return False | ||
|
||
|
||
class DataFrameOutputType(BaseOutputType): | ||
@property | ||
def template_hint(self): | ||
return """- type (must be "dataframe") | ||
- value (must be a pandas dataframe) | ||
Example output: { "type": "dataframe", "value": pd.DataFrame({...}) }""" | ||
|
||
@property | ||
def name(self): | ||
return "dataframe" | ||
|
||
def _validate_value(self, actual_value: Any) -> bool: | ||
return bool(df_type(actual_value)) | ||
|
||
|
||
class PlotOutputType(BaseOutputType): | ||
@property | ||
def template_hint(self): | ||
return """- type (must be "plot") | ||
- value (must be a string containing the path of the plot image) | ||
Example output: { "type": "plot", "value": "export/charts/temp.png" }""" | ||
|
||
@property | ||
def name(self): | ||
return "plot" | ||
|
||
def _validate_value(self, actual_value: Any) -> bool: | ||
if not isinstance(actual_value, str): | ||
return False | ||
|
||
path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$" | ||
if re.match(path_to_plot_pattern, actual_value): | ||
return True | ||
|
||
return False | ||
|
||
|
||
class StringOutputType(BaseOutputType): | ||
@property | ||
def template_hint(self): | ||
return """- type (must be "string") | ||
- value (must be a conversational answer, as a string) | ||
Example output: { "type": "string", "value": "The highest salary is $9,000." }""" | ||
|
||
@property | ||
def name(self): | ||
return "string" | ||
|
||
def _validate_value(self, actual_value: Any) -> bool: | ||
if isinstance(actual_value, str): | ||
return True | ||
return False | ||
|
||
|
||
class DefaultOutputType(BaseOutputType): | ||
@property | ||
def template_hint(self): | ||
return """- type (possible values "string", "number", "dataframe", "plot") | ||
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary) | ||
Examples: | ||
{ "type": "string", "value": "The highest salary is $9,000." } | ||
or | ||
{ "type": "number", "value": 125 } | ||
or | ||
{ "type": "dataframe", "value": pd.DataFrame({...}) } | ||
or | ||
{ "type": "plot", "value": "export/charts/temp.png" }""" # noqa E501 | ||
|
||
@property | ||
def name(self): | ||
return "default" | ||
|
||
def _validate_type(self, actual_type: str) -> bool: | ||
return True | ||
|
||
def _validate_value(self, actual_value: Any) -> bool: | ||
return True | ||
|
||
def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable]: | ||
""" | ||
Validate 'type' and 'value' from the result dict. | ||
Returns: | ||
(bool): True since the `DefaultOutputType` | ||
is supposed to have no validation | ||
""" | ||
return True, () |
Oops, something went wrong.