From 431c3d510ff601c862303a35bfb1ffd451f42cd3 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Wed, 19 Jun 2024 14:27:56 +0200 Subject: [PATCH] chore(SemanticAgent): add samples in schema and support back-tick json load (#1241) * fix(SemanticAgent): join data to be fixed * fix(semantic_agent): json load to also look for json in backtick --- pandasai/ee/agents/semantic_agent/__init__.py | 3 +- .../semantic_agent/pipeline/llm_call.py | 4 +- .../prompts/generate_df_schema.py | 5 +- .../prompts/templates/generate_df_schema.tmpl | 167 ++++++++++-------- pandasai/ee/helpers/json_helper.py | 14 ++ 5 files changed, 113 insertions(+), 80 deletions(-) create mode 100644 pandasai/ee/helpers/json_helper.py diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py index dc8aee31b..d6f736372 100644 --- a/pandasai/ee/agents/semantic_agent/__init__.py +++ b/pandasai/ee/agents/semantic_agent/__init__.py @@ -15,6 +15,7 @@ from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import ( GenerateDFSchemaPrompt, ) +from pandasai.ee.helpers.json_helper import extract_json_from_json_str from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson from pandasai.helpers.cache import Cache from pandasai.helpers.memory import Memory @@ -186,7 +187,7 @@ def _create_schema(self): """ ) self._schema = result.replace("# SAMPLE SCHEMA", "") - schema_data = json.loads(result.replace("# SAMPLE SCHEMA", "")) + schema_data = extract_json_from_json_str(result.replace("# SAMPLE SCHEMA", "")) if isinstance(schema_data, dict): schema_data = [schema_data] diff --git a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py index e9946140d..af1bd2e18 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py @@ -1,6 +1,6 @@ -import json from typing import Any +from pandasai.ee.helpers.json_helper 
import extract_json_from_json_str from pandasai.helpers.logger import Logger from pandasai.pipelines.base_logic_unit import BaseLogicUnit from pandasai.pipelines.logic_unit_output import LogicUnitOutput @@ -42,7 +42,7 @@ def execute(self, input: Any, **kwargs) -> Any: ) try: # Validate is valid Json - response_json = json.loads(response) + response_json = extract_json_from_json_str(response) pipeline_context.add("llm_call", response) diff --git a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py index 80237c944..28390f8b7 100644 --- a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py +++ b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py @@ -3,6 +3,7 @@ from jinja2 import Environment, FileSystemLoader +from pandasai.ee.helpers.json_helper import extract_json_from_json_str from pandasai.prompts.base import BasePrompt @@ -30,7 +31,9 @@ def __init__(self, **kwargs): def validate(self, output: str) -> bool: try: - json_data = json.loads(output.replace("# SAMPLE SCHEMA", "")) + json_data = extract_json_from_json_str( + output.replace("# SAMPLE SCHEMA", "") + ) context = self.props["context"] if isinstance(json_data, dict): json_data = [json_data] diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl index 6a45e5fe1..edec51e2d 100644 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl +++ b/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl @@ -1,132 +1,147 @@ # SAMPLE SCHEMA [ { - "name":"Contracts", - "table":"contracts", - "measures":[ + "name": "Contracts", + "table": "contracts", + "measures": [ { - "name":"contract_count", - "type":"count", - "sql":"store_id" + "name": "contract_count", + "type": "count", + "sql": "store_id" }, { - "name":"contract_duration", - "type":"number", - 
"sql":"${contract_end_date} - ${contract_start_date}" + "name": "contract_duration", + "type": "number", + "sql": "${contract_end_date} - ${contract_start_date}" }, { - "name":"contract_avg_duration", - "type":"avg", - "sql":"${contract_duration}" + "name": "contract_avg_duration", + "type": "avg", + "sql": "${contract_duration}" } ], - "dimensions":[ + "dimensions": [ { - "name":"contract_code", - "type":"string", - "sql":"contract_code" + "name": "contract_code", + "type": "string", + "sql": "contract_code", + "samples": ["C12345", "C67890"] }, { - "name":"store_id", - "type":"string", - "sql":"store_id" + "name": "store_id", + "type": "string", + "sql": "store_id", + "samples": ["S12345", "S67890"] }, { - "name":"tenant_code", - "type":"string", - "sql":"tenant_code" + "name": "tenant_code", + "type": "string", + "sql": "tenant_code", + "samples": ["T12345", "T67890"] }, { - "name":"tenant_name", - "type":"string", - "sql":"tenant_name" + "name": "tenant_name", + "type": "string", + "sql": "tenant_name", + "samples": ["Tenant A", "Tenant B"] }, { - "name":"store_brand", - "type":"string", - "sql":"store_brand" + "name": "store_brand", + "type": "string", + "sql": "store_brand", + "samples": ["Brand X", "Brand Y"] }, { - "name":"branch_segment_1", - "type":"string", - "sql":"branch_segment_1" + "name": "branch_segment_1", + "type": "string", + "sql": "branch_segment_1", + "samples": ["Segment 1", "Segment 2"] }, { - "name":"branch_segment_2", - "type":"string", - "sql":"branch_segment_2" + "name": "branch_segment_2", + "type": "string", + "sql": "branch_segment_2", + "samples": ["Segment A", "Segment B"] }, { - "name":"contract_start_date", - "type":"date", - "sql":"contract_start_date" + "name": "contract_start_date", + "type": "date", + "sql": "contract_start_date", + "samples": ["2023-01-01", "2023-02-01"] }, { - "name":"contract_end_date", - "type":"date", - "sql":"contract_end_date" + "name": "contract_end_date", + "type": "date", + "sql": 
"contract_end_date", + "samples": ["2024-01-01", "2024-02-01"] } ], - "joins":[ + "joins": [ { - "name":"corrispettivi", - "join_type":"left", - "sql":"${Contracts.contract_code} = ${Fees.contract_id}" + "name": "Fee", + "join_type": "left", + "sql": "${Contracts.contract_code} = ${Fees.contract_id}" } ] }, { - "name":"Fees", - "table":"fees", - "measures":[ + "name": "Fees", + "table": "fees", + "measures": [ { - "name":"total_taxable", - "type":"sum", - "sql":"imponibile_tot" + "name": "total_taxable", + "type": "sum", + "sql": "imponibile_tot" }, { - "name":"total_revenue", - "type":"sum", - "sql":"totale_tot" + "name": "total_revenue", + "type": "sum", + "sql": "totale_tot" } ], - "dimensions":[ + "dimensions": [ { - "name":"contract_id", - "type":"string", - "sql":"contract_id" + "name": "contract_id", + "type": "string", + "sql": "contract_id", + "samples": ["C12345", "C67890"] }, { - "name":"code", - "type":"string", - "sql":"code" + "name": "code", + "type": "string", + "sql": "code", + "samples": ["F12345", "F67890"] }, { - "name":"station", - "type":"string", - "sql":"station" + "name": "station", + "type": "string", + "sql": "station", + "samples": ["Station X", "Station Y"] }, { - "name":"tenant_id", - "type":"string", - "sql":"tenant_id" + "name": "tenant_id", + "type": "string", + "sql": "tenant_id", + "samples": ["T12345", "T67890"] }, { - "name":"day", - "type":"date", - "sql":"day" + "name": "day", + "type": "date", + "sql": "day", + "samples": ["2023-01-01", "2023-02-01"] }, { - "name":"store_id", - "type":"string", - "sql":"store_id" + "name": "store_id", + "type": "string", + "sql": "store_id", + "samples": ["S12345", "S67890"] } ], - "joins":[ + "joins": [ { - "name":"contracts", - "join_type":"right", - "sql":"${Fees.contract_id} = ${Fees.contract_code}" + "name": "Contracts", + "join_type": "right", + "sql": "${Fees.contract_id} = ${Contracts.contract_code}" } ] } diff --git a/pandasai/ee/helpers/json_helper.py 
# pandasai/ee/helpers/json_helper.py (new file introduced by this patch)
import json


def extract_json_from_json_str(json_str):
    """Parse JSON from a string that may wrap it in a ```json fenced block.

    If the text contains a "```json" marker, only the content between that
    marker and the next "```" is parsed; surrounding text is ignored. If no
    marker is present, the whole string is parsed as JSON.

    Args:
        json_str: Raw text, either plain JSON or text containing a
            "```json ... ```" fenced block.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).

    Raises:
        json.JSONDecodeError: If the selected text is not valid JSON.
    """
    opening_fence = "```json"
    closing_fence = "```"

    start_index = json_str.find(opening_fence)
    if start_index == -1:
        # No fenced block: treat the entire input as JSON.
        return json.loads(json_str)

    content_start = start_index + len(opening_fence)

    # Search for the closing fence *after* the opening marker. Searching from
    # start_index would match the first three characters of "```json" itself
    # and always yield an empty payload.
    end_index = json_str.find(closing_fence, content_start)
    if end_index == -1:
        # Unterminated fence: use everything after the opening marker.
        end_index = len(json_str)

    return json.loads(json_str[content_start:end_index].strip())