Skip to content

Commit

Permalink
Merge pull request #13 from fetchai/feat/migrate-to-pydantic-v2
Browse files Browse the repository at this point in the history
Migrate to pydantic v2
  • Loading branch information
XaviPeiro authored Aug 28, 2024
2 parents 6bda180 + 8a05a93 commit e2b6c54
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ share/python-wheels/
*.egg
MANIFEST

#VS Code extensions
.history/*

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
Expand Down
15 changes: 6 additions & 9 deletions ai_engine_sdk/api_models/parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,21 @@

def get_options_from_raw_api_response(raw_api_response: dict) -> list[dict[str, str]]:
return [
{
'key': str(o['key']), 'title': o['value']
}
for o in raw_api_response['agent_json']['options']
{"key": str(o["key"]), "title": o["value"]}
for o in raw_api_response["agent_json"]["options"]
]


def get_task_options_from_options(options: list[dict[str, str]]) -> list[TaskOption]:
return [
TaskOption.parse_obj({
"key": option['key'],
"title": option['title']
})
TaskOption.model_validate({"key": option["key"], "title": option["title"]})
for option in options
]


def get_indexed_task_options_from_raw_api_response(raw_api_response: dict) -> dict[TaskOption]:
def get_indexed_task_options_from_raw_api_response(
raw_api_response: dict,
) -> dict[TaskOption]:
options_list = get_options_from_raw_api_response(raw_api_response=raw_api_response)
task_options_list = get_task_options_from_options(options=options_list)

Expand Down
30 changes: 15 additions & 15 deletions ai_engine_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ async def _submit_message(self, payload: ApiMessagePayload):
api_key=self._api_key,
method='POST',
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}/submit",
payload={'payload': payload.dict()}
payload={'payload': payload.model_dump()}
)

async def start(self, objective: str, context: Optional[str] = None):
await self._submit_message(
payload=ApiStartMessage.parse_obj({
payload=ApiStartMessage.model_validate({
'session_id': self.session_id,
'bucket_id': self.function_group,
'message_id': str(uuid4()).lower(),
Expand All @@ -123,7 +123,7 @@ async def start(self, objective: str, context: Optional[str] = None):

async def submit_task_selection(self, selection: TaskSelectionMessage, options: list[TaskOption]):
await self._submit_message(
payload=ApiUserJsonMessage.parse_obj({
payload=ApiUserJsonMessage.model_validate({
'session_id': self.session_id,
'message_id': str(uuid4()).lower(),
'referral_id': selection.id,
Expand All @@ -136,7 +136,7 @@ async def submit_task_selection(self, selection: TaskSelectionMessage, options:

async def submit_response(self, query: AgentMessage, response: str):
await self._submit_message(
payload=ApiUserMessageMessage.parse_obj(
payload=ApiUserMessageMessage.model_validate(
{
'session_id': self.session_id,
'message_id': str(uuid4()).lower(),
Expand All @@ -148,7 +148,7 @@ async def submit_response(self, query: AgentMessage, response: str):

async def submit_confirmation(self, confirmation: ConfirmationMessage):
await self._submit_message(
payload=ApiUserMessageMessage.parse_obj({
payload=ApiUserMessageMessage.model_validate({
'session_id': self.session_id,
'message_id': str(uuid4()).lower(),
'referral_id': confirmation.id,
Expand All @@ -158,7 +158,7 @@ async def submit_confirmation(self, confirmation: ConfirmationMessage):

async def reject_confirmation(self, confirmation: ConfirmationMessage, reason: str):
await self._submit_message(
payload=ApiUserMessageMessage.parse_obj({
payload=ApiUserMessageMessage.model_validate({
'session_id': self.session_id,
'message_id': str(uuid4()).lower(),
'referral_id': confirmation.id,
Expand Down Expand Up @@ -188,7 +188,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
if is_task_selection_message(message_type=agent_json_type):
indexed_task_options: dict = get_indexed_task_options_from_raw_api_response(raw_api_response=message)
newMessages.append(
TaskSelectionMessage.parse_obj({
TaskSelectionMessage.model_validate({
'type': agent_json_type,
'id': message['message_id'],
'timestamp': message['timestamp'],
Expand All @@ -198,7 +198,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
)
elif is_api_context_json(message_type=agent_json_type, agent_json_text=agent_json['text']):
newMessages.append(
ConfirmationMessage.parse_obj({
ConfirmationMessage.model_validate({
'id': message['message_id'],
'timestamp': message['timestamp'],
'text': agent_json['text'],
Expand All @@ -208,7 +208,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
)
elif is_data_request_message(message_type=agent_json_type):
newMessages.append(
DataRequestMessage.parse_obj({
DataRequestMessage.model_validate({
"id": message['message_id'],
"text": agent_json['text'],
"type": agent_json_type,
Expand All @@ -220,7 +220,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
print(f"UNKNOWN-JSON: {message}")
elif is_api_agent_info_message(message):
newMessages.append(
AiEngineMessage.parse_obj({
AiEngineMessage.model_validate({
'id': message['message_id'],
'type': 'ai-engine',
'timestamp': message['timestamp'],
Expand All @@ -229,7 +229,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
)
elif is_api_agent_message_message(message):
newMessages.append(
AgentMessage.parse_obj({
AgentMessage.model_validate({
'id': message['message_id'],
'type': 'agent',
'timestamp': message['timestamp'],
Expand All @@ -239,7 +239,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
elif is_api_stop_message(message):
print(f"STOP: {message}")
newMessages.append(
StopMessage.parse_obj({
StopMessage.model_validate({
'id': message['message_id'],
'timestamp': message['timestamp'],
'type': 'stop',
Expand Down Expand Up @@ -289,7 +289,7 @@ async def get_public_function_groups(self) -> List[FunctionGroup]:
)
return list(
map(
lambda item: FunctionGroup.parse_obj(item),
lambda item: FunctionGroup.model_validate(item),
raw_response
)
)
Expand All @@ -303,7 +303,7 @@ async def get_private_function_groups(self) -> List[FunctionGroup]:
)
return list(
map(
lambda item: FunctionGroup.parse_obj(item),
lambda item: FunctionGroup.model_validate(item),
raw_response
)
)
Expand Down Expand Up @@ -393,7 +393,7 @@ async def create_session(self, function_group: str, opts: Optional[dict] = None)
api_key=self._api_key,
method='POST',
endpoint="/v1beta1/engine/chat/sessions",
payload=request_payload.dict()
payload=request_payload.model_dump()
)

return Session(self._api_base_url, self._api_key, response['session_id'], function_group)
Expand Down
Loading

0 comments on commit e2b6c54

Please sign in to comment.