Commit d276d03

wip
kyle-tennison committed Aug 25, 2024
1 parent 730c316 commit d276d03
Showing 2 changed files with 21 additions and 2 deletions.
21 changes: 19 additions & 2 deletions crates/llm-chain-openai/src/chatgpt/executor.rs
@@ -6,6 +6,7 @@ use async_openai::types::ChatCompletionRequestMessage;
 
 use async_openai::types::ChatCompletionRequestUserMessageContent;
 use llm_chain::options::Opt;
+use llm_chain::options::OptDiscriminants;
 use llm_chain::options::Options;
 use llm_chain::options::OptionsCascade;
 use llm_chain::output::Output;
@@ -103,11 +104,27 @@ impl traits::Executor for Executor {
         Ok(Self { client, options })
     }
 
-    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
+    async fn execute(&self, _: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
+
+        let options = &self.options;
+
+        println!("OPENAI Options\n{:?}\n", &options);
         let opts = self.cascade(Some(options));
         let client = self.client.clone();
         let model = self.get_model_from_invocation_options(&opts);
-        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
+
+        let stop_sequence = match opts.get(OptDiscriminants::StopSequence) {
+            Some(Opt::StopSequence(s)) => s.to_owned(),
+            Some(_) => Vec::new(),
+            None => Vec::new(),
+        };
+
+        let input = create_chat_completion_request(
+            model,
+            prompt,
+            stop_sequence,
+            opts.is_streaming(),
+        ).unwrap();
         if opts.is_streaming() {
             let res = async move { client.chat().create_stream(input).await }
                 .await
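The new lookup only takes effect when a StopSequence option is present in the cascade. A minimal sketch of how a caller might supply one — the OptionsBuilder usage below is an assumption based on llm-chain's options module, not code from this commit:

    use llm_chain::options::{Opt, Options};

    // Build an Options set carrying a stop sequence. The executor's
    // cascade then surfaces it under OptDiscriminants::StopSequence.
    // NOTE: builder API assumed from llm-chain's options module.
    fn stop_sequence_options() -> Options {
        let mut builder = Options::builder();
        builder.add_option(Opt::StopSequence(vec!["\n\n".to_string()]));
        builder.build()
    }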
2 changes: 2 additions & 0 deletions crates/llm-chain-openai/src/chatgpt/prompt.rs
@@ -77,13 +77,15 @@ pub fn format_chat_messages(
 pub fn create_chat_completion_request(
     model: String,
     prompt: &Prompt,
+    stop_tokens: Vec<String>,
     is_streaming: bool,
 ) -> Result<CreateChatCompletionRequest, OpenAIInnerError> {
     let messages = format_chat_messages(prompt.to_chat())?;
     Ok(CreateChatCompletionRequestArgs::default()
         .model(model)
         .stream(is_streaming)
         .messages(messages)
+        .stop(stop_tokens)
         .build()?)
 }
 
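On the async-openai side, the .stop(...) builder call maps to the OpenAI API's stop parameter, which ends generation as soon as the model emits any listed token. A rough standalone sketch of the same request construction — the explicit Stop::StringArray variant is an assumption from async-openai's types module; the commit itself passes the Vec<String> directly:

    use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionRequestArgs, Stop};

    // Sketch: build a non-streaming request with an explicit stop list.
    // Messages are omitted for brevity; real callers chain .messages(...)
    // as in the diff above. Model name and stop token are placeholders.
    fn request_with_stops(model: &str) -> Result<CreateChatCompletionRequest, Box<dyn std::error::Error>> {
        Ok(CreateChatCompletionRequestArgs::default()
            .model(model)
            .stream(false)
            .stop(Stop::StringArray(vec!["Observation:".to_string()]))
            .build()?)
    }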
