Skip to content

Commit

Permalink
add dataframe to agent
Browse files Browse the repository at this point in the history
  • Loading branch information
ea-rus committed Dec 27, 2024
1 parent 633527f commit 5139504
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
36 changes: 36 additions & 0 deletions mindsdb_sdk/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from uuid import uuid4
import datetime
import json
import pandas as pd

from mindsdb_sdk.knowledge_bases import KnowledgeBase
from mindsdb_sdk.models import Model
Expand Down Expand Up @@ -155,6 +156,41 @@ def add_webpage(
"""
self.collection.add_webpage(self.name, url, description, knowledge_base=knowledge_base, crawl_depth=crawl_depth, filters=filters)

def add_dataframe(
self,
df: pd.DataFrame,
description: str,
knowledge_base: str = None
):
"""
Add a list of webpages to the agent for retrieval.
:param df: dataframe to be added.
:param description: Description of the webpages. Used by agent to know when to do retrieval.
:param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
"""
if df is None or df.empty:
return

if knowledge_base is not None:
kb = self.collection.knowledge_bases.get(knowledge_base)
else:
kb_name = f'{self.name.lower()}_df_{uuid4().hex}_kb'
kb = self.collection._create_default_knowledge_base(self, kb_name)

# Insert crawled webpage.
kb.insert(df)

# Make sure skill name is unique.
skill_name = f'df_retrieval_skill_{uuid4().hex}'
retrieval_params = {
'source': kb.name,
'description': description,
}
dataframe_retrieval_skill = self.collection.skills.create(skill_name, 'retrieval', retrieval_params)
self.skills.append(dataframe_retrieval_skill)
self.collection.update(self.name, self)

def add_database(self, database: str, tables: List[str], description: str):
"""
Add a database to the agent for retrieval.
Expand Down
76 changes: 76 additions & 0 deletions tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,82 @@ def test_add_webpage(self, mock_post, mock_put, mock_get):
}
assert agent_update_json == expected_agent_json

@patch('requests.Session.get')
@patch('requests.Session.put')
@patch('requests.Session.post')
def test_add_dataframe(self, mock_post, mock_put, mock_get):
server = mindsdb_sdk.connect()
responses_mock(mock_get, [
# Existing agent get.
{
'name': 'test_agent',
'model_name': 'test_model',
'skills': [],
'params': {},
'created_at': None,
'updated_at': None,
'provider': 'mindsdb'
},
# get KB
{
'id': 1,
'name': 'my_kb',
'project_id': 1,
'embedding_model': 'openai_emb',
'vector_database': 'pvec',
'vector_database_table': 'tbl1',
'updated_at': '2024-10-04 10:55:25.350799',
'created_at': '2024-10-04 10:55:25.350790',
'params': {}
},
# Skills get in Agent update to check if it exists.
{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}},
# Existing agent get in Agent update.
{
'name':'test_agent',
'model_name':'test_model',
'skills':[],
'params':{},
'created_at':None,
'updated_at':None,
'provider':'mindsdb' # Added provider field
},
])
responses_mock(mock_post, [
# Skill creation.
{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}}
])
responses_mock(mock_put, [
# KB update.
{'name':'test_agent_docs_mdb_ai_kb'},
# Agent update with new skill.
{
'name':'test_agent',
'model_name':'test_model',
'skills':[{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}}],
'params':{},
'created_at':None,
'updated_at':None,
'provider':'mindsdb' # Added provider field
},
])
server.agents.test_agent.add_dataframe(pd.DataFrame([{'content': 'doc'}]), 'Documentation for MindsDB', 'existing_kb')

# Check Agent was updated with a new skill.
agent_update_json = mock_put.call_args[-1]['json']
expected_agent_json = {
'agent':{
'name':'test_agent',
'model_name':'test_model',
# Skill name is a generated UUID.
'skills_to_add':[agent_update_json['agent']['skills_to_add'][0]],
'skills_to_remove':[],
'params':{},
'provider': 'mindsdb'
}
}
assert agent_update_json == expected_agent_json

@patch('requests.Session.get')
@patch('requests.Session.put')
@patch('requests.Session.post')
Expand Down

0 comments on commit 5139504

Please sign in to comment.