Skip to content

Commit

Permalink
fix: fix function calling for openai
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Feb 17, 2025
1 parent c9604ed commit 8459b1c
Show file tree
Hide file tree
Showing 18 changed files with 155 additions and 132 deletions.
15 changes: 5 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion anda_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "anda_core"
description = "Core types and traits for Anda -- an AI agent framework built with Rust, powered by ICP and TEEs."
repository = "https://github.com/ldclabs/anda/tree/main/anda_core"
publish = true
version = "0.4.0"
version = "0.4.1"
edition.workspace = true
keywords.workspace = true
categories.workspace = true
Expand All @@ -23,3 +23,4 @@ object_store = { workspace = true }
ic_cose_types = { workspace = true }
tokio-util = { workspace = true }
reqwest = { workspace = true }
schemars = { workspace = true }
47 changes: 47 additions & 0 deletions anda_core/src/json.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use schemars::schema::{RootSchema, Schema, SchemaObject, SingleOrVec};

pub use schemars::{schema_for, JsonSchema};

/// Function Calling has strict requirements for JsonSchema, use fix_json_schema to fix it
/// 1. Remove $schema field
/// 2. Remove $format field
/// 3. Object type Schema must set additionalProperties: false
/// 4. required field should include all properties fields, meaning all struct fields are required (no Option)
pub fn fix_json_schema(schema: &mut RootSchema) {
schema.meta_schema = None; // Remove the $schema field
fix_obj_schema(&mut schema.schema);
}

fn fix_obj_schema(schema: &mut SchemaObject) {
schema.format = None; // Remove the $format field
if let Some(obj) = &mut schema.object {
// https://platform.openai.com/docs/guides/structured-outputs#additionalproperties-false-must-always-be-set-in-objects
obj.additional_properties = Some(Box::new(Schema::Bool(false)));
// if obj.required.len() != obj.properties.len() {
// obj.required = obj.properties.keys().cloned().collect();
// }
for v in obj.properties.values_mut() {
if let Schema::Object(o) = v {
fix_obj_schema(o);
}
}
}
if let Some(arr) = &mut schema.array {
if let Some(v) = &mut arr.items {
match v {
SingleOrVec::Single(v) => {
if let Schema::Object(o) = v.as_mut() {
fix_obj_schema(o);
}
}
SingleOrVec::Vec(arr) => {
for v in arr {
if let Schema::Object(o) = v {
fix_obj_schema(o);
}
}
}
}
}
}
}
2 changes: 2 additions & 0 deletions anda_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use std::{future::Future, pin::Pin};
pub mod agent;
pub mod context;
pub mod http;
pub mod json;
pub mod model;
pub mod tool;

pub use agent::*;
pub use context::*;
pub use http::*;
pub use json::*;
pub use model::*;
pub use tool::*;

Expand Down
4 changes: 2 additions & 2 deletions anda_engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "anda_engine"
description = "Agents engine for Anda -- an AI agent framework built with Rust, powered by ICP and TEEs."
repository = "https://github.com/ldclabs/anda/tree/main/anda_engine"
publish = true
version = "0.4.1"
version = "0.4.2"
edition.workspace = true
keywords.workspace = true
categories.workspace = true
Expand All @@ -26,13 +26,13 @@ ic_cose_types = { workspace = true }
ic_tee_gateway_sdk = { workspace = true }
tokio-util = { workspace = true }
structured-logger = { workspace = true }
schemars = { workspace = true }
reqwest = { workspace = true }
rand = { workspace = true }
moka = { workspace = true }
toml = { workspace = true }
tokio = { workspace = true }
log = { workspace = true }
schemars = { workspace = true }
url = { workspace = true }

[dev-dependencies]
Expand Down
39 changes: 22 additions & 17 deletions anda_engine/src/context/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,25 +272,30 @@ impl CompletionFeatures for AgentCtx {

// remove called tool from req.tools
req.tools.retain(|t| t.name != tool.name);
match self.tool_call(&tool.name, tool.args.clone()).await {
Ok((val, con)) => {
if con {
// need to use LLM to continue processing tool_call result
tool_calls_continue.push(json!(Message {
role: "tool".to_string(),
content: val.clone().into(),
name: None,
tool_call_id: Some(tool.id.clone()),
}));
if self.tools.contains(&tool.name) {
match self.tool_call(&tool.name, tool.args.clone()).await {
Ok((val, con)) => {
if con {
// need to use LLM to continue processing tool_call result
tool_calls_continue.push(json!(Message {
role: "tool".to_string(),
content: val.clone().into(),
name: None,
tool_call_id: Some(tool.id.clone()),
}));
}
tool.result = Some(val);
}
Err(err) => {
res.failed_reason = Some(err.to_string());
return Ok(res);
}
tool.result = Some(val);
}
Err(_err) => {
// TODO:
// support remote_tool_call
// support agent_run
// support remote_agent_run
}
} else {
// TODO:
// support remote_tool_call
// support agent_run
// support remote_agent_run
}
}

Expand Down
9 changes: 5 additions & 4 deletions anda_engine/src/extension/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
//! #[derive(JsonSchema, Serialize, Deserialize)]
//! struct ContactInfo {
//! name: String,
//! phone: Option<String>,
//! phone: String,
//! }
//!
//! let extractor = Extractor::<ContactInfo>::default();
Expand All @@ -42,12 +42,13 @@
//! - These traits can be easily derived using the `derive` macro
use anda_core::{
Agent, AgentOutput, BoxError, CompletionFeatures, CompletionRequest, FunctionDefinition, Tool,
fix_json_schema, Agent, AgentOutput, BoxError, CompletionFeatures, CompletionRequest,
FunctionDefinition, Tool,
};
use schemars::{schema_for, JsonSchema};
use serde_json::{json, Value};
use std::marker::PhantomData;

pub use schemars::{schema_for, JsonSchema};
pub use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::context::{AgentCtx, BaseCtx};
Expand Down Expand Up @@ -83,7 +84,7 @@ where
/// uses the type's title (if available) as the tool name
pub fn new() -> SubmitTool<T> {
let mut schema = schema_for!(T);
schema.meta_schema = None; // Remove the $schema field
fix_json_schema(&mut schema);
let name = schema
.schema
.metadata
Expand Down
4 changes: 2 additions & 2 deletions anda_engine/src/extension/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! .build("default_agent".to_string())?;
//! ```
use anda_core::{BoxError, FunctionDefinition, HttpFeatures, Tool};
use anda_core::{fix_json_schema, BoxError, FunctionDefinition, HttpFeatures, Tool};
use http::header;
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -86,7 +86,7 @@ impl GoogleSearchTool {
/// * `result_number` - Optional number of results to return (defaults to 5)
pub fn new(api_key: String, search_engine_id: String, result_number: Option<u8>) -> Self {
let mut schema = schema_for!(SearchArgs);
schema.meta_schema = None; // Remove the $schema field
fix_json_schema(&mut schema);

GoogleSearchTool {
api_key,
Expand Down
3 changes: 2 additions & 1 deletion anda_engine/src/extension/segmenter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
use anda_core::{
evaluate_tokens, Agent, AgentOutput, BoxError, CompletionFeatures, Tool, ToolCall,
};
use schemars::JsonSchema;

use super::extractor::{Deserialize, Extractor, JsonSchema, Serialize, SubmitTool};
use super::extractor::{Deserialize, Extractor, Serialize, SubmitTool};
use crate::context::AgentCtx;

/// Represents the output of document segmentation containing multiple text segments
Expand Down
12 changes: 9 additions & 3 deletions anda_engine/src/model/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,12 @@ impl CompletionFeaturesDyn for CompletionModel {
let mut body = json!({
"model": model,
"messages": full_history.clone(),
"temperature": req.temperature,
});

let body = body.as_object_mut().unwrap();
if let Some(temperature) = req.temperature {
body.insert("temperature".to_string(), Value::from(temperature));
}

if let Some(max_tokens) = req.max_tokens {
body.insert("max_tokens".to_string(), Value::from(max_tokens));
Expand Down Expand Up @@ -314,7 +317,8 @@ impl CompletionFeaturesDyn for CompletionModel {

let response = client.post("/chat/completions").json(body).send().await?;
if response.status().is_success() {
match response.json::<CompletionResponse>().await {
let text = response.text().await?;
match serde_json::from_str::<CompletionResponse>(&text) {
Ok(res) => {
if log_enabled!(Debug) {
if let Ok(val) = serde_json::to_string(&res) {
Expand All @@ -323,7 +327,9 @@ impl CompletionFeaturesDyn for CompletionModel {
}
res.try_into(full_history)
}
Err(err) => Err(format!("DeepSeek completions error: {}", err).into()),
Err(err) => {
Err(format!("DeepSeek completions error: {}, body: {}", err, text).into())
}
}
} else {
let msg = response.text().await?;
Expand Down
Loading

0 comments on commit 8459b1c

Please sign in to comment.