diff --git a/Cargo.lock b/Cargo.lock
index ae547da..e576595 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -978,6 +978,7 @@ name = "puppet"
 version = "0.1.0"
 dependencies = [
  "base64",
+ "bytes",
  "clap",
  "clap_complete",
  "futures",
diff --git a/Cargo.toml b/Cargo.toml
index eb41969..9c41844 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,3 +19,4 @@ futures = "0"
 rand = "0"
 base64 = "0"
 uuid = { version = "1", features = ["v4"] }
+bytes = "1"
diff --git a/src/gcloud/gemini.rs b/src/gcloud/gemini.rs
index 82a00aa..fa9aa7d 100644
--- a/src/gcloud/gemini.rs
+++ b/src/gcloud/gemini.rs
@@ -3,9 +3,11 @@
 use std::mem;
 use std::ops::Not;
 use std::path::Path;
 use std::rc::Rc;
+use std::str;
 use base64::prelude::BASE64_STANDARD;
 use base64::Engine;
+use bytes::Bytes;
 use futures::StreamExt;
 use reqwest::Response;
 use tokio::sync::mpsc::channel;
@@ -117,6 +119,7 @@ impl Gemini {
 
     async fn process_response(&mut self, mut rx: Receiver, handler: &impl ChatHandler) -> Option {
         let mut model_message = String::new();
+        let mut function_call = None;
         while let Some(response) = rx.recv().await {
             if let Some(usage) = response.usage_metadata {
                 self.usage.request_tokens += usage.prompt_token_count;
@@ -124,13 +127,17 @@
             }
 
             let candidate = response.candidates.into_iter().next().unwrap();
+            if let Some(reason) = candidate.finish_reason.as_ref() {
+                if reason == "STOP" {
+                    continue;
+                }
+            }
 
             match candidate.content {
                 Some(content) => {
                     let part = content.parts.into_iter().next().unwrap();
-                    if let Some(function_call) = part.function_call {
-                        self.add_message(Content::new_function_call(function_call.clone()));
-                        return Some(function_call);
+                    if let Some(call) = part.function_call {
+                        function_call = Some(call);
                     } else if let Some(text) = part.text {
                         model_message.push_str(&text);
                         handler.on_event(ChatEvent::Delta(text));
@@ -145,6 +152,11 @@
             }
         }
 
+        if let Some(call) = function_call {
+            self.add_message(Content::new_function_call(call.clone()));
+            return Some(call);
+        }
+
         if !model_message.is_empty() {
             self.add_message(Content::new_text(Role::Model, model_message));
         }
@@ -172,18 +184,20 @@
         };
 
         let body = json::to_json(&request)?;
-        // info!("body={body}");
+        let body = Bytes::from(body);
         let response = http_client::http_client()
            .post(&self.url)
            .bearer_auth(token())
            .header("Content-Type", "application/json")
            .header("Accept", "application/json")
-            .body(body)
+            .body(body.clone())
            .send()
            .await?;
 
         let status = response.status();
         if status != 200 {
+            let body = str::from_utf8(&body).unwrap();
+            info!("body={}", body);
             let response_text = response.text().await?;
             return Err(Exception::ExternalError(format!(
                 "failed to call gcloud api, status={status}, response={response_text}"
@@ -201,7 +215,7 @@ async fn read_response_stream(response: Response, tx: Sender {
-                buffer.push_str(std::str::from_utf8(&chunk).unwrap());
+                buffer.push_str(str::from_utf8(&chunk).unwrap());
                 // first char is '[' or ','
                 if !is_valid_json(&buffer[1..]) {