diff --git a/Cargo.lock b/Cargo.lock
index 848cbd6d..c39c6efc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1023,7 +1023,6 @@ dependencies = [
  "indexmap 2.9.0",
  "indoc",
  "itertools 0.14.0",
- "json5",
  "log",
  "neo4rs",
  "owo-colors",
diff --git a/Cargo.toml b/Cargo.toml
index 1a19b22b..24ffcb3f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -108,7 +108,6 @@ bytes = "1.10.1"
 rand = "0.9.0"
 indoc = "2.0.6"
 owo-colors = "4.2.0"
-json5 = "0.4.1"
 aws-config = "1.6.2"
 aws-sdk-s3 = "1.85.0"
 aws-sdk-sqs = "1.67.0"
diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs
index 1001f908..3d770e5f 100644
--- a/src/llm/anthropic.rs
+++ b/src/llm/anthropic.rs
@@ -1,10 +1,7 @@
-use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
-};
-use anyhow::{bail, Context, Result};
 use async_trait::async_trait;
-use json5;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
+use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
 use serde_json::Value;
 use crate::api_bail;

@@ -48,9 +45,11 @@ impl LlmGenerationClient for Client {
         });

         // Add system prompt as top-level field if present (required)
-        if let Some(system) = request.system_prompt {
-            payload["system"] = serde_json::json!(system);
+        let mut system_prompt = request.system_prompt.unwrap_or_default();
+        if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
+            system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
         }
+        payload["system"] = serde_json::json!(system_prompt);

         // Extract schema from output_format, error if not JsonSchema
         let schema = match request.output_format.as_ref() {
@@ -67,8 +66,7 @@ impl LlmGenerationClient for Client {

         let encoded_api_key = encode(&self.api_key);

-        let resp = self
-            .client
+        let resp = self.client
             .post(url)
             .header("x-api-key", encoded_api_key.as_ref())
             .header("anthropic-version", "2023-06-01")
@@ -76,60 +74,22 @@ impl LlmGenerationClient for Client {
             .send()
             .await
             .context("HTTP error")?;
-        let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;
+        let resp_json: Value = resp.json().await.context("Invalid JSON")?;

         if let Some(error) = resp_json.get("error") {
             bail!("Anthropic API error: {:?}", error);
         }

-        // Debug print full response
-        // println!("Anthropic API full response: {resp_json:?}");
-
-        let resp_content = &resp_json["content"];
-        let tool_name = "report_result";
-        let mut extracted_json: Option<Value> = None;
-        if let Some(array) = resp_content.as_array() {
-            for item in array {
-                if item.get("type") == Some(&Value::String("tool_use".to_string()))
-                    && item.get("name") == Some(&Value::String(tool_name.to_string()))
-                {
-                    if let Some(input) = item.get("input") {
-                        extracted_json = Some(input.clone());
-                        break;
-                    }
-                }
-            }
-        }
-        let text = if let Some(json) = extracted_json {
-            // Try strict JSON serialization first
-            serde_json::to_string(&json)?
-        } else {
-            // Fallback: try text if no tool output found
-            match &mut resp_json["content"][0]["text"] {
-                Value::String(s) => {
-                    // Try strict JSON parsing first
-                    match serde_json::from_str::<serde_json::Value>(s) {
-                        Ok(_) => std::mem::take(s),
-                        Err(e) => {
-                            // Try permissive json5 parsing as fallback
-                            match json5::from_str::<serde_json::Value>(s) {
-                                Ok(value) => {
-                                    println!("[Anthropic] Used permissive JSON5 parser for output");
-                                    serde_json::to_string(&value)?
-                                },
-                                Err(e2) => return Err(anyhow::anyhow!(format!("No structured tool output or text found in response, and permissive JSON5 parsing also failed: {e}; {e2}")))
-                            }
-                        }
-                    }
-                }
-                _ => {
-                    return Err(anyhow::anyhow!(
-                        "No structured tool output or text found in response"
-                    ))
-                }
-            }
+        // Extract the text response
+        let text = match resp_json["content"][0]["text"].as_str() {
+            Some(s) => s.to_string(),
+            None => bail!("No text in response"),
         };
-        Ok(LlmGenerateResponse { text })
+        // Try to parse as JSON
+        match serde_json::from_str::<serde_json::Value>(&text) {
+            Ok(val) => Ok(LlmGenerateResponse::Json(val)),
+            Err(_) => Ok(LlmGenerateResponse::Text(text)),
+        }
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index 11c34ebb..4a7600cb 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -1,12 +1,10 @@
-use crate::api_bail;
-use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
-};
-use anyhow::{bail, Context, Result};
 use async_trait::async_trait;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
 use serde_json::Value;
+use crate::api_bail;
 use urlencoding::encode;
+use crate::llm::prompt_utils::STRICT_JSON_PROMPT;

 pub struct Client {
     model: String,
@@ -60,11 +58,14 @@ impl LlmGenerationClient for Client {

         // Prepare payload
         let mut payload = serde_json::json!({ "contents": contents });
-        if let Some(system) = request.system_prompt {
-            payload["systemInstruction"] = serde_json::json!({
-                "parts": [ { "text": system } ]
-            });
-        }
+        if let Some(mut system) = request.system_prompt {
+            if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
+                system = format!("{STRICT_JSON_PROMPT}\n\n{system}").into();
+            }
+            payload["systemInstruction"] = serde_json::json!({
+                "parts": [ { "text": system } ]
+            });
+}

         // If structured output is requested, add schema and responseMimeType
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
@@ -79,13 +80,10 @@ impl LlmGenerationClient for Client {
         let api_key = &self.api_key;
         let url = format!(
             "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            encode(&self.model),
-            encode(api_key)
+            encode(&self.model), encode(api_key)
         );

-        let resp = self
-            .client
-            .post(&url)
+        let resp = self.client.post(&url)
             .json(&payload)
             .send()
             .await
@@ -102,7 +100,15 @@ impl LlmGenerationClient for Client {
             _ => bail!("No text in response"),
         };

-        Ok(LlmGenerateResponse { text })
+        // If output_format is JsonSchema, try to parse as JSON
+        if let Some(OutputFormat::JsonSchema { .. }) = request.output_format {
+            match serde_json::from_str::<serde_json::Value>(&text) {
+                Ok(val) => Ok(LlmGenerateResponse::Json(val)),
+                Err(_) => Ok(LlmGenerateResponse::Text(text)),
+            }
+        } else {
+            Ok(LlmGenerateResponse::Text(text))
+        }
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -113,4 +119,4 @@ impl LlmGenerationClient for Client {
             top_level_must_be_object: true,
         }
     }
-}
+}
\ No newline at end of file
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 5a2706aa..b8865e8d 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -22,7 +22,7 @@ pub struct LlmSpec {
     model: String,
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum OutputFormat<'a> {
     JsonSchema {
         name: Cow<'a, str>,
@@ -38,8 +38,9 @@ pub struct LlmGenerateRequest<'a> {
 }

 #[derive(Debug)]
-pub struct LlmGenerateResponse {
-    pub text: String,
+pub enum LlmGenerateResponse {
+    Json(serde_json::Value),
+    Text(String),
 }

 #[async_trait]
@@ -56,6 +57,7 @@ mod anthropic;
 mod gemini;
 mod ollama;
 mod openai;
+mod prompt_utils;

 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs
index f2926077..afddaecf 100644
--- a/src/llm/ollama.rs
+++ b/src/llm/ollama.rs
@@ -1,6 +1,7 @@
 use super::LlmGenerationClient;
 use anyhow::Result;
 use async_trait::async_trait;
+use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
 use schemars::schema::SchemaObject;
 use serde::{Deserialize, Serialize};

@@ -52,6 +53,10 @@ impl LlmGenerationClient for Client {
         &self,
         request: super::LlmGenerateRequest<'req>,
     ) -> Result<super::LlmGenerateResponse> {
+        let mut system_prompt = request.system_prompt.unwrap_or_default();
+        if matches!(request.output_format, Some(super::OutputFormat::JsonSchema { .. })) {
+            system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
+        }
         let req = OllamaRequest {
             model: &self.model,
             prompt: request.user_prompt.as_ref(),
@@ -60,7 +65,7 @@ impl LlmGenerationClient for Client {
                     OllamaFormat::JsonSchema(schema.as_ref())
                 },
             ),
-            system: request.system_prompt.as_ref().map(|s| s.as_ref()),
+            system: Some(&system_prompt),
             stream: Some(false),
         };
         let res = self
@@ -71,9 +76,15 @@ impl LlmGenerationClient for Client {
             .await?;
         let body = res.text().await?;
         let json: OllamaResponse = serde_json::from_str(&body)?;
-        Ok(super::LlmGenerateResponse {
-            text: json.response,
-        })
+        // Check if output_format is JsonSchema, try to parse as JSON
+        if let Some(super::OutputFormat::JsonSchema { .. }) = request.output_format {
+            match serde_json::from_str::<serde_json::Value>(&json.response) {
+                Ok(val) => Ok(super::LlmGenerateResponse::Json(val)),
+                Err(_) => Ok(super::LlmGenerateResponse::Text(json.response)),
+            }
+        } else {
+            Ok(super::LlmGenerateResponse::Text(json.response))
+        }
     }

     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
index 5675fc86..f75e1b66 100644
--- a/src/llm/openai.rs
+++ b/src/llm/openai.rs
@@ -64,8 +64,10 @@ impl LlmGenerationClient for Client {
             },
         ));

+        // Save output_format before it is moved.
+        let output_format = request.output_format.clone();
         // Create the chat completion request
-        let request = CreateChatCompletionRequest {
+        let openai_request = CreateChatCompletionRequest {
             model: self.model.clone(),
             messages,
             response_format: match request.output_format {
@@ -85,7 +87,7 @@ impl LlmGenerationClient for Client {
         };

         // Send request and get response
-        let response = self.client.chat().create(request).await?;
+        let response = self.client.chat().create(openai_request).await?;

         // Extract the response text from the first choice
         let text = response
@@ -95,7 +97,15 @@ impl LlmGenerationClient for Client {
             .and_then(|choice| choice.message.content)
             .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;

-        Ok(super::LlmGenerateResponse { text })
+        // If output_format is JsonSchema, try to parse as JSON
+        if let Some(super::OutputFormat::JsonSchema { .. }) = output_format {
+            match serde_json::from_str::<serde_json::Value>(&text) {
+                Ok(val) => Ok(super::LlmGenerateResponse::Json(val)),
+                Err(_) => Ok(super::LlmGenerateResponse::Text(text)),
+            }
+        } else {
+            Ok(super::LlmGenerateResponse::Text(text))
+        }
     }

     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
diff --git a/src/llm/prompt_utils.rs b/src/llm/prompt_utils.rs
new file mode 100644
index 00000000..fe28f893
--- /dev/null
+++ b/src/llm/prompt_utils.rs
@@ -0,0 +1,4 @@
+// Shared prompt utilities for LLM clients
+// Only import this in clients that require strict JSON output instructions (e.g., Anthropic, Gemini, Ollama)
+
+pub const STRICT_JSON_PROMPT: &str = "IMPORTANT: Output ONLY valid JSON that matches the schema. Do NOT say anything else. Do NOT explain. Do NOT preface. Do NOT add comments. If you cannot answer, output an empty JSON object: {}.";
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index 060956c7..7a3abf23 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -1,7 +1,7 @@
 use crate::prelude::*;

 use crate::llm::{
-    new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
+    new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmGenerateResponse, LlmSpec, OutputFormat,
 };
 use crate::ops::sdk::*;
 use base::json_schema::build_json_schema;
@@ -83,7 +83,10 @@ impl SimpleFunctionExecutor for Executor {
             }),
         };
         let res = self.client.generate(req).await?;
-        let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
+        let json_value = match res {
+            LlmGenerateResponse::Json(val) => val,
+            LlmGenerateResponse::Text(text) => serde_json::from_str(&text)?,
+        };
         let value = self.value_extractor.extract_value(json_value)?;
         Ok(value)
     }
@@ -124,4 +127,4 @@ impl SimpleFunctionFactoryBase for Factory {
     ) -> Result<Box<dyn SimpleFunctionExecutor>> {
         Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
     }
-}
+}
\ No newline at end of file
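Note on consuming the new response type: with LlmGenerateResponse changed from a struct to an enum, call sites no longer read `.text` but match on the variant, taking the already-parsed value when the provider returned valid JSON and parsing the raw text themselves otherwise. The snippet below is an illustrative sketch, not part of the patch; `res` stands for the value returned by `client.generate(req).await?`, and the error message is an assumption.

    // Illustrative caller-side handling of the new enum (not from this diff).
    let json_value: serde_json::Value = match res {
        // The provider produced valid JSON (or native structured output): already parsed.
        LlmGenerateResponse::Json(val) => val,
        // Otherwise the raw text comes back; parse it here and surface a clear error.
        LlmGenerateResponse::Text(text) => serde_json::from_str(&text)
            .map_err(|e| anyhow::anyhow!("LLM did not return valid JSON: {e}"))?,
    };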