introduce loop into chat streaming
olegklimov committed Oct 23, 2023
1 parent 8ba7079 commit 24848d5
Showing 2 changed files with 157 additions and 118 deletions.
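
The substance of the change is in restream.rs: the endpoint connection and the SSE read loop now both live inside the stream! generator, wrapped in a single loop, so a connection failure can be reported to the client as an ordinary data event and the round trip could be repeated later without restarting the response. A minimal sketch of that control flow, assuming a hypothetical connect() and Chunk type in place of the forward_to_*_endpoint helpers and their SSE events (not the crate's actual API):

// Sketch only: `connect` and `Chunk` are hypothetical stand-ins for the
// forward_to_*_endpoint helpers and the events they stream back.
enum Chunk {
    Message(String),
    Done,
}

fn connect() -> Result<Vec<Chunk>, String> {
    // A stub "connection" that yields one message and then a terminator.
    Ok(vec![Chunk::Message("{\"choices\": []}".to_string()), Chunk::Done])
}

fn main() {
    let mut out: Vec<String> = Vec::new();
    loop {
        // Open the streaming connection inside the loop, so a failure becomes
        // an ordinary data event for the client instead of an HTTP error.
        let events = match connect() {
            Ok(events) => events,
            Err(e) => {
                out.push(format!("data: {{\"detail\": \"{}\"}}\n\n", e));
                break;
            }
        };
        // Drain every event this connection produces.
        for ev in events {
            match ev {
                Chunk::Message(text) => out.push(format!("data: {}\n\n", text)),
                Chunk::Done => break,
            }
        }
        // One pass for now; the loop leaves room for another round trip.
        break;
    }
    out.push("data: [DONE]\n\n".to_string());
    for line in &out {
        print!("{}", line);
    }
}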
18 changes: 11 additions & 7 deletions src/http_server.rs
@@ -161,7 +161,7 @@ pub async fn handle_v1_code_completion(
if !code_completion_post.stream {
crate::restream::scratchpad_interaction_not_stream(global_context.clone(), scratchpad, "completion".to_string(), &prompt, model_name, client1, api_key, &code_completion_post.parameters).await
} else {
crate::restream::scratchpad_interaction_stream(global_context.clone(), scratchpad, "completion-stream".to_string(), &prompt, model_name, client1, api_key, &code_completion_post.parameters).await
crate::restream::scratchpad_interaction_stream(global_context.clone(), scratchpad, "completion-stream".to_string(), prompt, model_name, client1, api_key, code_completion_post.parameters.clone()).await
}
}

@@ -220,12 +220,16 @@ async fn handle_v1_chat(
)?;
// info!("chat prompt {:?}\n{}", t1.elapsed(), prompt);
info!("chat prompt {:?}", t1.elapsed());
let streaming = chat_post.stream.unwrap_or(false);
if streaming {
crate::restream::scratchpad_interaction_stream(global_context.clone(), scratchpad, "chat-stream".to_string(), &prompt, model_name, client1, api_key, &chat_post.parameters).await
} else {
crate::restream::scratchpad_interaction_not_stream(global_context.clone(), scratchpad, "chat".to_string(), &prompt, model_name, client1, api_key, &chat_post.parameters).await
}
crate::restream::scratchpad_interaction_stream(
global_context.clone(),
scratchpad,
"chat-stream".to_string(),
prompt,
model_name,
client1,
api_key,
chat_post.parameters.clone()
).await
}
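
In both handlers the prompt and the sampling parameters are now handed to scratchpad_interaction_stream by value (prompt, parameters.clone()) instead of by reference. A plausible reason, inferred from the new restream.rs rather than stated in the diff, is that the request is now built inside the stream! generator, which outlives the handler's stack frame and can only capture owned data. A crate-free illustration of the same constraint, with a hypothetical make_stream standing in for the generator:

// Hypothetical example: a lazily evaluated stream of chunks that must be 'static.
// Taking `prompt: &str` would not compile here, because the returned iterator
// outlives this call; taking an owned `String` and moving it in does.
fn make_stream(prompt: String) -> Box<dyn Iterator<Item = String> + 'static> {
    Box::new((0..3).map(move |i| format!("data: {} #{}\n\n", prompt, i)))
}

fn main() {
    for chunk in make_stream("hello".to_string()) {
        print!("{}", chunk);
    }
}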


257 changes: 146 additions & 111 deletions src/restream.rs
@@ -128,126 +128,130 @@ pub async fn scratchpad_interaction_stream(
global_context: Arc<ARwLock<GlobalContext>>,
mut scratchpad: Box<dyn ScratchpadAbstract>,
scope: String,
prompt: &str,
prompt: String,
mut model_name: String,
client: reqwest::Client,
bearer: String,
parameters: &SamplingParameters,
parameters: SamplingParameters,
) -> Result<Response<Body>, ScratchError> {
let t1 = std::time::SystemTime::now();
let (endpoint_style, endpoint_template, tele_storage) = {
let cx = global_context.write().await;
let caps = cx.caps.clone().unwrap();
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let mut event_source = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming(
&mut save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
}.map_err(|e| {
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
scope.clone(),
false,
e.to_string(),
));
ScratchError::new_but_skip_telemetry(StatusCode::INTERNAL_SERVER_ERROR, format!("forward_to_endpoint: {}", e))
})?;

let evstream = stream! {
let scratch = &mut scratchpad;
let mut finished: bool = false;
let mut problem_reported = false;
let mut was_correct_output_even_if_error = false;
while let Some(event) = event_source.next().await {
match event {
Ok(Event::Open) => {},
Ok(Event::Message(message)) => {
// info!("Message: {:#?}", message);
if message.data.starts_with("[DONE]") {
break;
}
let json = serde_json::from_str::<serde_json::Value>(&message.data).unwrap();
let value_str;
if let Some(token) = json.get("token") { // hf style produces this
let text = token.get("text").unwrap().as_str().unwrap().to_string();
let mut value: serde_json::Value;
(value, finished) = scratch.response_streaming(text, false, false).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
value["model"] = json!(model_name.clone());
value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
was_correct_output_even_if_error |= json.get("generated_text").is_some();
} else if let Some(choices) = json.get("choices") { // openai style
let choice0 = &choices[0];
let text = choice0.get("text").unwrap().as_str().unwrap().to_string();
let finish_reason = choice0.get("finish_reason").unwrap_or(&json!("")).as_str().unwrap().to_string();
let stop_toks = !finish_reason.is_empty() && finish_reason.starts_with("stop");
let stop_length = !finish_reason.is_empty() && !finish_reason.starts_with("stop");
let mut value: serde_json::Value;
(value, finished) = scratch.response_streaming(text, stop_toks, stop_length).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
model_name = json["model"].as_str().unwrap().to_string();
value["model"] = json!(model_name.clone());
value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
} else {
value_str = serde_json::to_string(&json!({"detail": format!("unrecognized response: {:?}", json)})).unwrap();
}
info!("yield: {:?}", value_str);
let scratch: &mut Box<dyn ScratchpadAbstract> = &mut scratchpad;
let (endpoint_style, endpoint_template, tele_storage) = {
let cx = global_context.write().await;
let caps = cx.caps.clone().unwrap();
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
loop {
let event_source_maybe = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming(
&mut save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
};
let mut event_source = match event_source_maybe {
Ok(event_source) => event_source,
Err(e) => {
let e_str = format!("forward_to_endpoint: {:?}", e);
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
scope.clone(),
false,
e_str.to_string(),
));
error!("forward_to_endpoint: {}", e_str);
let value_str = serde_json::to_string(&json!({"detail": e_str})).unwrap();
yield Result::<_, String>::Ok(value_str);
if finished {
break;
}
},
Err(err) => {
if was_correct_output_even_if_error {
// "restream error: Stream ended"
break;
}
error!("restream error: {}\n{:?}", err, err);
let problem_str = format!("restream error: {}", err);
{
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
scope.clone(),
false,
problem_str.clone(),
));
}
yield Result::<_, String>::Ok(serde_json::to_string(&json!({"detail": problem_str})).unwrap());
problem_reported = true;
event_source.close();
break;
},
}
};
let mut finished: bool = false;
let mut problem_reported = false;
let mut was_correct_output_even_if_error = false;
while let Some(event) = event_source.next().await {
match event {
Ok(Event::Open) => {},
Ok(Event::Message(message)) => {
// info!("Message: {:#?}", message);
if message.data.starts_with("[DONE]") {
break;
}
let json = serde_json::from_str::<serde_json::Value>(&message.data).unwrap();
let value_maybe = _push_streaming_json_into_scratchpad(
scratch,
&json,
&mut model_name,
&mut finished,
&mut was_correct_output_even_if_error,
);
if let Ok(mut value) = value_maybe {
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
info!("yield: {:?}", value_str);
yield Result::<_, String>::Ok(value_str);
} else {
let err_str = value_maybe.unwrap_err();
error!("unexpected error: {}", err_str);
let value_str = format!("data: {}\n\n", serde_json::to_string(&json!({"detail": err_str})).unwrap());
yield Result::<_, String>::Ok(value_str);
// TODO: send telemetry
break;
}
if finished {
break;
}
},
Err(err) => {
if was_correct_output_even_if_error {
// "restream error: Stream ended"
break;
}
error!("restream error: {}\n{:?}", err, err);
let problem_str = format!("restream error: {}", err);
{
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
scope.clone(),
false,
problem_str.clone(),
));
}
yield Result::<_, String>::Ok(serde_json::to_string(&json!({"detail": problem_str})).unwrap());
problem_reported = true;
event_source.close();
break;
},
}
}
}
if problem_reported {
return;
} else if !finished {
let mut value: serde_json::Value;
(value, _) = scratch.response_streaming("".to_string(), false, true).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
value["model"] = json!(model_name.clone());
let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
info!("yield final: {:?}", value_str);
yield Result::<_, String>::Ok(value_str);
if problem_reported {
return;
} else if !finished {
let mut value: serde_json::Value;
(value, _) = scratch.response_streaming("".to_string(), false, true).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
value["model"] = json!(model_name.clone());
let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
info!("yield final: {:?}", value_str);
yield Result::<_, String>::Ok(value_str);
}
break;
}
info!("yield: [DONE]");
yield Result::<_, String>::Ok("data: [DONE]\n\n".to_string());
@@ -266,6 +270,37 @@ pub async fn scratchpad_interaction_stream(
return Ok(response);
}

fn _push_streaming_json_into_scratchpad(
scratch: &mut Box<dyn ScratchpadAbstract>,
json: &serde_json::Value,
model_name: &mut String,
finished: &mut bool,
was_correct_output_even_if_error: &mut bool,
) -> Result<serde_json::Value, String> {
if let Some(token) = json.get("token") { // hf style produces this
let text = token.get("text").unwrap().as_str().unwrap().to_string();
let mut value: serde_json::Value;
(value, *finished) = scratch.response_streaming(text, false, false).unwrap();
value["model"] = json!(model_name.clone());
*was_correct_output_even_if_error |= json.get("generated_text").is_some();
// Ok(format!("data: {}\n\n", serde_json::to_string(&value).unwrap()));
Ok(value)
} else if let Some(choices) = json.get("choices") { // openai style
let choice0 = &choices[0];
let text = choice0.get("text").unwrap().as_str().unwrap().to_string();
let finish_reason = choice0.get("finish_reason").unwrap_or(&json!("")).as_str().unwrap().to_string();
let stop_toks = !finish_reason.is_empty() && finish_reason.starts_with("stop");
let stop_length = !finish_reason.is_empty() && !finish_reason.starts_with("stop");
let mut value: serde_json::Value;
(value, *finished) = scratch.response_streaming(text, stop_toks, stop_length).unwrap();
*model_name = json["model"].as_str().unwrap().to_string();
value["model"] = json!(model_name.clone());
Ok(value)
} else {
Err(format!("unrecognized response: {:?}", json))
}
}

pub async fn cached_not_stream(
cached_json_value: &serde_json::Value,
) -> Result<Response<Body>, ScratchError> {
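
The new _push_streaming_json_into_scratchpad helper factors out the dispatch between the two wire formats a streamed chunk can arrive in: Hugging Face style ({"token": {"text": ...}}) and OpenAI style ({"choices": [{"text": ..., "finish_reason": ...}]}). Below is a standalone sketch of that dispatch with the scratchpad call replaced by plain text extraction; extract_text is a hypothetical helper for illustration only, and the snippet needs nothing beyond serde_json, which the crate already depends on:

use serde_json::{json, Value};

// Hypothetical helper mirroring the dispatch in _push_streaming_json_into_scratchpad:
// pull the newly generated text out of either chunk format, or report an error.
fn extract_text(chunk: &Value) -> Result<String, String> {
    if let Some(token) = chunk.get("token") {
        // Hugging Face style: {"token": {"text": "..."}}
        token.get("text")
            .and_then(|t| t.as_str())
            .map(|s| s.to_string())
            .ok_or_else(|| format!("hf chunk without token.text: {:?}", chunk))
    } else if let Some(choices) = chunk.get("choices") {
        // OpenAI style: {"choices": [{"text": "...", "finish_reason": ...}]}
        choices[0].get("text")
            .and_then(|t| t.as_str())
            .map(|s| s.to_string())
            .ok_or_else(|| format!("openai chunk without choices[0].text: {:?}", chunk))
    } else {
        Err(format!("unrecognized response: {:?}", chunk))
    }
}

fn main() {
    let hf = json!({"token": {"text": "Hel"}});
    let oai = json!({"choices": [{"text": "lo", "finish_reason": ""}]});
    let mut answer = String::new();
    for chunk in &[hf, oai] {
        match extract_text(chunk) {
            Ok(text) => answer.push_str(&text),
            Err(e) => eprintln!("{}", e),
        }
    }
    assert_eq!(answer, "Hello");
}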
