From d276d03c713241e92047707d219fc21f5e5b4b38 Mon Sep 17 00:00:00 2001
From: Kyle Tennison
Date: Sun, 25 Aug 2024 14:12:59 -0400
Subject: [PATCH] wip

---
 .../llm-chain-openai/src/chatgpt/executor.rs  | 21 +++++++++++++++++--
 crates/llm-chain-openai/src/chatgpt/prompt.rs |  2 ++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/crates/llm-chain-openai/src/chatgpt/executor.rs b/crates/llm-chain-openai/src/chatgpt/executor.rs
index ff300213..fc85dd45 100644
--- a/crates/llm-chain-openai/src/chatgpt/executor.rs
+++ b/crates/llm-chain-openai/src/chatgpt/executor.rs
@@ -6,6 +6,7 @@
 use async_openai::types::ChatCompletionRequestMessage;
 use async_openai::types::ChatCompletionRequestUserMessageContent;
 use llm_chain::options::Opt;
+use llm_chain::options::OptDiscriminants;
 use llm_chain::options::Options;
 use llm_chain::options::OptionsCascade;
 use llm_chain::output::Output;
@@ -103,11 +104,27 @@ impl traits::Executor for Executor {
         Ok(Self { client, options })
     }
 
-    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
+    async fn execute(&self, _: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
+
+        let options = &self.options;
+
+        println!("OPENAI Options\n{:?}\n", &options);
         let opts = self.cascade(Some(options));
         let client = self.client.clone();
         let model = self.get_model_from_invocation_options(&opts);
-        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
+
+        let stop_sequence = match opts.get(OptDiscriminants::StopSequence) {
+            Some(Opt::StopSequence(s)) => s.to_owned(),
+            Some(_) => Vec::new(),
+            None => Vec::new(),
+        };
+
+        let input = create_chat_completion_request(
+            model,
+            prompt,
+            stop_sequence,
+            opts.is_streaming(),
+        ).unwrap();
         if opts.is_streaming() {
             let res = async move { client.chat().create_stream(input).await }
                 .await
diff --git a/crates/llm-chain-openai/src/chatgpt/prompt.rs b/crates/llm-chain-openai/src/chatgpt/prompt.rs
index a8a456cb..a0c30d92 100644
--- a/crates/llm-chain-openai/src/chatgpt/prompt.rs
+++ b/crates/llm-chain-openai/src/chatgpt/prompt.rs
@@ -77,6 +77,7 @@ pub fn format_chat_messages(
 pub fn create_chat_completion_request(
     model: String,
     prompt: &Prompt,
+    stop_tokens: Vec<String>,
     is_streaming: bool,
 ) -> Result<CreateChatCompletionRequest, OpenAIInnerError> {
     let messages = format_chat_messages(prompt.to_chat())?;
@@ -84,6 +85,7 @@ pub fn create_chat_completion_request(
         .model(model)
         .stream(is_streaming)
         .messages(messages)
+        .stop(stop_tokens)
         .build()?)
 }
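
Usage note (not part of the patch): after this change, execute() ignores the
per-call Options argument and instead reads the stop sequence from the options
the executor was constructed with, forwarding it through
create_chat_completion_request() into the OpenAI request. Below is a minimal
sketch of how a caller could supply a StopSequence under that flow, assuming
llm-chain's options!/prompt!/parameters! macros and the trait constructor
new_with_options; the stop strings and prompt text are illustrative
placeholders, not part of this patch.

    use llm_chain::traits::Executor as ExecutorTrait;
    use llm_chain::{options, parameters, prompt};
    use llm_chain_openai::chatgpt::Executor;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Executor-level options carrying a stop sequence; with this patch,
        // execute() pulls StopSequence out of self.options.
        let opts = options!(
            StopSequence: vec!["\n\n".to_string()] // placeholder stop token
        );
        let exec = Executor::new_with_options(opts)?;
        let res = prompt!("You are a helpful assistant.", "Say hello, then stop.")
            .run(&parameters!(), &exec)
            .await?;
        println!("{}", res.to_immediate().await?.as_content());
        Ok(())
    }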