Skip to content

Commit

Permalink
feat: support Grok
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Feb 18, 2025
1 parent 2a1623e commit 2264f5b
Show file tree
Hide file tree
Showing 13 changed files with 505 additions and 53 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion anda_engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "anda_engine"
description = "Agents engine for Anda -- an AI agent framework built with Rust, powered by ICP and TEEs."
repository = "https://github.com/ldclabs/anda/tree/main/anda_engine"
publish = true
version = "0.4.2"
version = "0.4.3"
edition.workspace = true
keywords.workspace = true
categories.workspace = true
Expand Down
21 changes: 15 additions & 6 deletions anda_engine/src/model/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::APP_USER_AGENT;
// ================================================================
// Main DeepSeek Client
// ================================================================
const DEEKSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
const API_BASE_URL: &str = "https://api.deepseek.com";
pub static DEEKSEEK_V3: &str = "deepseek-chat";
pub static DEEKSEEK_R1: &str = "deepseek-reasoner";

Expand All @@ -40,16 +40,22 @@ impl Client {
/// # Returns
/// Configured DeepSeek client instance
pub fn new(api_key: &str, endpoint: Option<String>) -> Self {
let endpoint = endpoint.unwrap_or_else(|| API_BASE_URL.to_string());
let endpoint = if endpoint.is_empty() {
API_BASE_URL.to_string()
} else {
endpoint
};
Self {
endpoint: endpoint.unwrap_or_else(|| DEEKSEEK_API_BASE_URL.to_string()),
endpoint,
http: reqwest::Client::builder()
.use_rustls_tls()
.https_only(true)
.http2_keep_alive_interval(Some(Duration::from_secs(25)))
.http2_keep_alive_timeout(Duration::from_secs(15))
.http2_keep_alive_while_idle(true)
.connect_timeout(Duration::from_secs(10))
.timeout(Duration::from_secs(120))
.timeout(Duration::from_secs(180))
.user_agent(APP_USER_AGENT)
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
Expand Down Expand Up @@ -77,7 +83,10 @@ impl Client {

/// Creates a new completion model instance using the default DeepSeek model
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
CompletionModel::new(
self.clone(),
if model.is_empty() { DEEKSEEK_V3 } else { model },
)
}
}

Expand Down Expand Up @@ -122,7 +131,7 @@ impl CompletionResponse {
let choice = self.choices.pop().ok_or("No completion choice")?;
full_history.push(json!(choice.message));
let mut output = AgentOutput {
content: choice.message.content,
content: choice.message.content.unwrap_or_default(),
tool_calls: choice.message.tool_calls.map(|tools| {
tools
.into_iter()
Expand Down Expand Up @@ -162,7 +171,7 @@ pub struct Choice {
pub struct MessageOutput {
pub role: String,
#[serde(default)]
pub content: String,
pub content: Option<String>,
pub refusal: Option<String>,
pub tool_calls: Option<Vec<ToolCallOutput>>,
}
Expand Down
80 changes: 79 additions & 1 deletion anda_engine/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ use anda_core::{
AgentOutput, BoxError, BoxPinFut, CompletionFeatures, CompletionRequest, Embedding,
EmbeddingFeatures, ToolCall,
};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, str::FromStr, sync::Arc};

pub mod cohere;
pub mod deepseek;
pub mod openai;
pub mod xai;

/// Trait for dynamic completion features that can be used across threads
pub trait CompletionFeaturesDyn: Send + Sync + 'static {
Expand Down Expand Up @@ -185,3 +187,79 @@ impl EmbeddingFeatures for Model {
self.embedder.embed_query(text.to_string()).await
}
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum HybridContent {
Text { text: String },
Image { image_url: ImageDetail },
Audio { input_audio: AudioDetail },
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageDetail {
pub url: String,
pub detail: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioDetail {
pub data: String,
pub format: String,
}

impl From<String> for HybridContent {
fn from(text: String) -> Self {
HybridContent::Text { text }
}
}

impl From<&str> for HybridContent {
fn from(text: &str) -> Self {
text.to_owned().into()
}
}

impl FromStr for HybridContent {
type Err = Infallible;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(s.into())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_hybrid_content() {
let content = HybridContent::Text {
text: "Hello, world!".to_string(),
};
let json = serde_json::to_string(&content).unwrap();
assert_eq!(json, r#"{"type":"text","text":"Hello, world!"}"#);

let ct: HybridContent = serde_json::from_str(&json).unwrap();
assert_eq!(ct, content);

let ct = HybridContent::from("Hello, world!");
assert_eq!(ct, content);

let content = HybridContent::Image {
image_url: ImageDetail {
url: "https://example.com/image.jpg".to_string(),
detail: "high".to_string(),
},
};

let json = serde_json::to_string(&content).unwrap();
assert_eq!(
json,
r#"{"type":"image","image_url":{"url":"https://example.com/image.jpg","detail":"high"}}"#
);

let ct: HybridContent = serde_json::from_str(&json).unwrap();
assert_eq!(ct, content);
}
}
14 changes: 10 additions & 4 deletions anda_engine/src/model/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::APP_USER_AGENT;
// ================================================================
// Main OpenAI Client
// ================================================================
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
const API_BASE_URL: &str = "https://api.openai.com/v1";

// ================================================================
// OpenAI Embedding API
Expand Down Expand Up @@ -54,16 +54,22 @@ impl Client {
/// # Arguments
/// * `api_key` - OpenAI API key for authentication
pub fn new(api_key: &str, endpoint: Option<String>) -> Self {
let endpoint = endpoint.unwrap_or_else(|| API_BASE_URL.to_string());
let endpoint = if endpoint.is_empty() {
API_BASE_URL.to_string()
} else {
endpoint
};
Self {
endpoint: endpoint.unwrap_or_else(|| OPENAI_API_BASE_URL.to_string()),
endpoint,
http: reqwest::Client::builder()
.use_rustls_tls()
.https_only(true)
.http2_keep_alive_interval(Some(Duration::from_secs(25)))
.http2_keep_alive_timeout(Duration::from_secs(15))
.http2_keep_alive_while_idle(true)
.connect_timeout(Duration::from_secs(10))
.timeout(Duration::from_secs(120))
.timeout(Duration::from_secs(180))
.user_agent(APP_USER_AGENT)
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
Expand Down Expand Up @@ -110,7 +116,7 @@ impl Client {
/// # Arguments
/// * `model` - Name of the completion model to use
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
CompletionModel::new(self.clone(), if model.is_empty() { O3_MINI } else { model })
}
}

Expand Down
Loading

0 comments on commit 2264f5b

Please sign in to comment.