diff --git a/mindsdb_sdk/agents.py b/mindsdb_sdk/agents.py
index 934a12b..b357500 100644
--- a/mindsdb_sdk/agents.py
+++ b/mindsdb_sdk/agents.py
@@ -4,6 +4,7 @@
 from uuid import uuid4
 import datetime
 import json
+import pandas as pd
 
 from mindsdb_sdk.knowledge_bases import KnowledgeBase
 from mindsdb_sdk.models import Model
@@ -155,6 +156,45 @@ def add_webpage(
         """
         self.collection.add_webpage(self.name, url, description, knowledge_base=knowledge_base, crawl_depth=crawl_depth, filters=filters)
 
+    def add_dataframe(
+            self,
+            df: pd.DataFrame,
+            description: str,
+            knowledge_base: str = None
+    ):
+        """
+        Add a dataframe to the agent for retrieval.
+
+        The dataframe is inserted into a knowledge base, and a new retrieval skill
+        that uses that knowledge base as its source is attached to the agent.
+        Does nothing when ``df`` is None or empty.
+
+        :param df: Dataframe to insert into the knowledge base.
+        :param description: Description of the dataframe contents. Used by agent to know when to do retrieval.
+        :param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
+        """
+        if df is None or df.empty:
+            return
+
+        if knowledge_base is not None:
+            kb = self.collection.knowledge_bases.get(knowledge_base)
+        else:
+            kb_name = f'{self.name.lower()}_df_{uuid4().hex}_kb'
+            kb = self.collection._create_default_knowledge_base(self, kb_name)
+
+        # Insert the dataframe into the knowledge base.
+        kb.insert(df)
+
+        # Make sure skill name is unique.
+        skill_name = f'df_retrieval_skill_{uuid4().hex}'
+        retrieval_params = {
+            'source': kb.name,
+            'description': description,
+        }
+        dataframe_retrieval_skill = self.collection.skills.create(skill_name, 'retrieval', retrieval_params)
+        self.skills.append(dataframe_retrieval_skill)
+        self.collection.update(self.name, self)
+
     def add_database(self, database: str, tables: List[str], description: str):
         """
         Add a database to the agent for retrieval.
diff --git a/tests/test_sdk.py b/tests/test_sdk.py
index e82de56..985f5ff 100644
--- a/tests/test_sdk.py
+++ b/tests/test_sdk.py
@@ -1664,6 +1664,82 @@ def test_add_webpage(self, mock_post, mock_put, mock_get):
         }
         assert agent_update_json == expected_agent_json
 
+    @patch('requests.Session.get')
+    @patch('requests.Session.put')
+    @patch('requests.Session.post')
+    def test_add_dataframe(self, mock_post, mock_put, mock_get):
+        server = mindsdb_sdk.connect()
+        responses_mock(mock_get, [
+            # Existing agent get.
+            {
+                'name': 'test_agent',
+                'model_name': 'test_model',
+                'skills': [],
+                'params': {},
+                'created_at': None,
+                'updated_at': None,
+                'provider': 'mindsdb'
+            },
+            # Knowledge base get for the name passed to add_dataframe.
+            {
+                'id': 1,
+                'name': 'my_kb',
+                'project_id': 1,
+                'embedding_model': 'openai_emb',
+                'vector_database': 'pvec',
+                'vector_database_table': 'tbl1',
+                'updated_at': '2024-10-04 10:55:25.350799',
+                'created_at': '2024-10-04 10:55:25.350790',
+                'params': {}
+            },
+            # Skills get in Agent update to check if it exists.
+            {'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}},
+            # Existing agent get in Agent update.
+            {
+                'name':'test_agent',
+                'model_name':'test_model',
+                'skills':[],
+                'params':{},
+                'created_at':None,
+                'updated_at':None,
+                'provider':'mindsdb'
+            },
+        ])
+        responses_mock(mock_post, [
+            # Skill creation.
+            {'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}}
+        ])
+        responses_mock(mock_put, [
+            # Knowledge base PUT; presumably issued by kb.insert(df) - TODO confirm against the SDK.
+            {'name':'test_agent_docs_mdb_ai_kb'},
+            # Agent update with new skill.
+            {
+                'name':'test_agent',
+                'model_name':'test_model',
+                'skills':[{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}}],
+                'params':{},
+                'created_at':None,
+                'updated_at':None,
+                'provider':'mindsdb'
+            },
+        ])
+        server.agents.test_agent.add_dataframe(pd.DataFrame([{'content': 'doc'}]), 'Documentation for MindsDB', 'existing_kb')
+
+        # Check Agent was updated with a new skill.
+        agent_update_json = mock_put.call_args[-1]['json']
+        expected_agent_json = {
+            'agent':{
+                'name':'test_agent',
+                'model_name':'test_model',
+                # Skill name embeds a generated UUID, so compare against the value that was sent.
+                'skills_to_add':[agent_update_json['agent']['skills_to_add'][0]],
+                'skills_to_remove':[],
+                'params':{},
+                'provider': 'mindsdb'
+            }
+        }
+        assert agent_update_json == expected_agent_json
+
     @patch('requests.Session.get')
     @patch('requests.Session.put')
     @patch('requests.Session.post')