From df06bec426033bfd0ac46bda8eb184d15da949e1 Mon Sep 17 00:00:00 2001
From: cFireworks
Date: Wed, 16 Oct 2024 16:32:26 +0800
Subject: [PATCH] fix: support chinese in prompt generation (#1317)

* Update dataframe_serializer.py

fix #1168
Support Chinese characters in prompt generation stage

* fix: Support Chinese characters in prompt generation stage (#1168)

Update dataframe_serializer.py
add test case to #1168
---
 pandasai/helpers/dataframe_serializer.py |  2 +-
 .../helpers/test_dataframe_serializer.py | 30 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit_tests/helpers/test_dataframe_serializer.py

diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py
index 8b3664de1..cfaffc9b4 100644
--- a/pandasai/helpers/dataframe_serializer.py
+++ b/pandasai/helpers/dataframe_serializer.py
@@ -160,7 +160,7 @@ def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str:
     def convert_df_to_yml(self, df: pd.DataFrame, extras: dict) -> str:
         json_df = self.convert_df_to_json(df, extras)
 
-        yml_str = yaml.dump(json_df, sort_keys=False)
+        yml_str = yaml.dump(json_df, sort_keys=False, allow_unicode=True)
         if "is_direct_sql" in extras and extras["is_direct_sql"]:
             return f"\n{yml_str}\n\n"
         return yml_str
diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py
new file mode 100644
index 000000000..3cf64b7df
--- /dev/null
+++ b/tests/unit_tests/helpers/test_dataframe_serializer.py
@@ -0,0 +1,30 @@
+import unittest
+
+import pandas as pd
+
+from pandasai.connectors import PandasConnector
+from pandasai.helpers.dataframe_serializer import (
+    DataframeSerializer,
+    DataframeSerializerType,
+)
+
+
+class TestDataframeSerializer(unittest.TestCase):
+    def setUp(self):
+        self.serializer = DataframeSerializer()
+
+    def test_convert_df_to_yml(self):
+        # Test convert df to yml
+        data = {"name": ["en_name", "中文_名称"]}
+        connector = PandasConnector(
+            {"original_df": pd.DataFrame(data)},
+            name="en_table_name",
+            description="中文_描述",
+            field_descriptions={k: k for k in data},
+        )
+        result = self.serializer.serialize(
+            connector,
+            type_=DataframeSerializerType.YML,
+            extras={"index": 0, "type": "pd.Dataframe"},
+        )
+        self.assertIn("中文_描述", result)