Skip to content

Commit

Permalink
fix gemini function call
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jun 28, 2024
1 parent c8a51df commit 1bcc2f2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ futures = "0"
rand = "0"
base64 = "0"
uuid = { version = "1", features = ["v4"] }
bytes = "1"
26 changes: 20 additions & 6 deletions src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ use std::mem;
use std::ops::Not;
use std::path::Path;
use std::rc::Rc;
use std::str;

use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use futures::StreamExt;
use reqwest::Response;
use tokio::sync::mpsc::channel;
Expand Down Expand Up @@ -117,20 +119,25 @@ impl Gemini {

async fn process_response(&mut self, mut rx: Receiver<GenerateContentResponse>, handler: &impl ChatHandler) -> Option<FunctionCall> {
let mut model_message = String::new();
let mut function_call = None;
while let Some(response) = rx.recv().await {
if let Some(usage) = response.usage_metadata {
self.usage.request_tokens += usage.prompt_token_count;
self.usage.response_tokens += usage.candidates_token_count;
}

let candidate = response.candidates.into_iter().next().unwrap();
if let Some(reason) = candidate.finish_reason.as_ref() {
if reason == "STOP" {
continue;
}
}
match candidate.content {
Some(content) => {
let part = content.parts.into_iter().next().unwrap();

if let Some(function_call) = part.function_call {
self.add_message(Content::new_function_call(function_call.clone()));
return Some(function_call);
if let Some(call) = part.function_call {
function_call = Some(call);
} else if let Some(text) = part.text {
model_message.push_str(&text);
handler.on_event(ChatEvent::Delta(text));
Expand All @@ -145,6 +152,11 @@ impl Gemini {
}
}

if let Some(call) = function_call {
self.add_message(Content::new_function_call(call.clone()));
return Some(call);
}

if !model_message.is_empty() {
self.add_message(Content::new_text(Role::Model, model_message));
}
Expand Down Expand Up @@ -172,18 +184,20 @@ impl Gemini {
};

let body = json::to_json(&request)?;
// info!("body={body}");
let body = Bytes::from(body);
let response = http_client::http_client()
.post(&self.url)
.bearer_auth(token())
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.body(body)
.body(body.clone())
.send()
.await?;

let status = response.status();
if status != 200 {
let body = str::from_utf8(&body).unwrap();
info!("body={}", body);
let response_text = response.text().await?;
return Err(Exception::ExternalError(format!(
"failed to call gcloud api, status={status}, response={response_text}"
Expand All @@ -201,7 +215,7 @@ async fn read_response_stream(response: Response, tx: Sender<GenerateContentResp
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
buffer.push_str(std::str::from_utf8(&chunk).unwrap());
buffer.push_str(str::from_utf8(&chunk).unwrap());

// first char is '[' or ','
if !is_valid_json(&buffer[1..]) {
Expand Down

0 comments on commit 1bcc2f2

Please sign in to comment.