diff --git a/src/azure/chatgpt.rs b/src/azure/chatgpt.rs index 482db06..ee9721e 100644 --- a/src/azure/chatgpt.rs +++ b/src/azure/chatgpt.rs @@ -252,6 +252,7 @@ async fn read_sse_response(http_response: Response, tx: &mpsc::Sender) - if let Some(finish_reason) = stream_choice.finish_reason { choice.finish_reason = finish_reason; if choice.finish_reason == "stop" { + // chatgpt doesn't return '\n' at end of message tx.send("\n".to_string()).await?; } } diff --git a/src/command/complete.rs b/src/command/complete.rs index e4761be..f160c2e 100644 --- a/src/command/complete.rs +++ b/src/command/complete.rs @@ -95,7 +95,7 @@ impl Complete { if !message.is_empty() { return Err(anyhow!("system message must be at first")); } - if let Some(option) = parse_option(line) { + if let Some(option) = parse_option(line)? { info!("option: {:?}", option); model.option(option); } @@ -171,13 +171,13 @@ async fn add_message(model: &mut llm::Model, state: &ParserState, message: Strin Ok(()) } -fn parse_option(line: &str) -> Option { - let regex = Regex::new(r".*temperature=(\d+\.\d+).*").unwrap(); +fn parse_option(line: &str) -> Result> { + let regex = Regex::new(r".*temperature=(\d+\.\d+).*")?; if let Some(capture) = regex.captures(line) { - let temperature = f32::from_str(&capture[1]).unwrap(); - Some(ChatOption { temperature }) + let temperature = f32::from_str(&capture[1])?; + Ok(Some(ChatOption { temperature })) } else { - None + Ok(None) } } @@ -194,6 +194,6 @@ mod tests { #[test] fn parse_option() { let option = super::parse_option("# system, temperature=2.0, top_p=0.95"); - assert_eq!(option.unwrap().temperature, 2.0); + assert_eq!(option.unwrap().unwrap().temperature, 2.0); } }