Skip to content

Commit

Permalink
fix: Add request timeout.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugoch committed Oct 14, 2024
1 parent 31f1af3 commit 9f91e3a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho
}
};
let tokenizer = Arc::new(tokenizer);
// let backend = OpenAITextGenerationBackend::new("".to_string(), "http://10.90.11.68:8000".to_string());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
run_config.url.clone(),
run_config.tokenizer_name.clone(),
tokenizer,
run_config.duration,
)?;

let config = BenchmarkConfig {
Expand Down
66 changes: 63 additions & 3 deletions src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::fmt::Display;
use std::path::PathBuf;
use std::sync::atomic::AtomicI64;
use std::sync::{Arc, Mutex};
use std::time;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tokio::sync::mpsc::Sender;

Expand Down Expand Up @@ -58,6 +59,7 @@ pub struct OpenAITextGenerationBackend {
pub model_name: String,
pub client: reqwest::Client,
pub tokenizer: Arc<Tokenizer>,
pub timeout: time::Duration,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
Expand Down Expand Up @@ -90,6 +92,7 @@ pub struct OpenAITextGenerationRequest {
pub max_tokens: Option<u64>,
pub stream: bool,
pub stop: Option<String>,
pub temperature: f64,
}

impl OpenAITextGenerationBackend {
Expand All @@ -98,13 +101,15 @@ impl OpenAITextGenerationBackend {
base_url: String,
model_name: String,
tokenizer: Arc<Tokenizer>,
timeout: time::Duration,
) -> anyhow::Result<Self> {
Ok(Self {
client: reqwest::Client::new(),
api_key,
base_url,
model_name,
tokenizer,
timeout,
})
}
}
Expand Down Expand Up @@ -140,6 +145,7 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
max_tokens: request.num_decode_tokens,
stream: true,
stop: None,
temperature: 0.0
};
let req = self
.client
Expand All @@ -148,7 +154,8 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
"Authorization",
format!("Bearer {token}", token = self.api_key),
)
.json(&serde_json::json!(body));
.json(&serde_json::json!(body))
.timeout(self.timeout);
// start timer
aggregated_response.start(request.num_prompt_tokens);
let mut es = EventSource::new(req).unwrap();
Expand Down Expand Up @@ -354,8 +361,7 @@ impl ConversationTextRequestGenerator {
bar.set_style(
ProgressStyle::with_template(
"Tokenizing prompts [{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}",
)
.unwrap(),
)?,
);
split(data, entry_splitter).for_each(|subrange| {
for entry in subrange {
Expand Down Expand Up @@ -687,6 +693,7 @@ mod tests {
url,
"gpt2".to_string(),
tokenizer,
time::Duration::from_secs(10),
)
.unwrap();
let request = TextGenerationRequest {
Expand Down Expand Up @@ -745,6 +752,7 @@ mod tests {
url,
"gpt2".to_string(),
tokenizer,
time::Duration::from_secs(10),
)
.unwrap();
let request = TextGenerationRequest {
Expand Down Expand Up @@ -829,6 +837,7 @@ mod tests {
url,
"gpt2".to_string(),
tokenizer,
time::Duration::from_secs(10),
)
.unwrap();
let request = TextGenerationRequest {
Expand Down Expand Up @@ -874,6 +883,7 @@ mod tests {
url,
"gpt2".to_string(),
tokenizer,
time::Duration::from_secs(10),
)
.unwrap();
let request = TextGenerationRequest {
Expand Down Expand Up @@ -919,6 +929,7 @@ mod tests {
url,
"gpt2".to_string(),
tokenizer,
time::Duration::from_secs(10),
)
.unwrap();
let request = TextGenerationRequest {
Expand Down Expand Up @@ -967,6 +978,7 @@ mod tests {
url,
"gpt2".to_string(),
tokenizer,
time::Duration::from_secs(10),
)
.unwrap();
let request = TextGenerationRequest {
Expand All @@ -993,4 +1005,52 @@ mod tests {
assert_eq!(responses[0].failed, true);
assert_eq!(responses[0].num_generated_tokens, 8);
}

/// A request that stalls past the configured client timeout must come back
/// as a single failed response.
#[tokio::test]
async fn test_timeout_should_fail_request() {
    let mut server = mockito::Server::new_async().await;
    server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_chunked_body(|writer| {
            // Emit one SSE chunk, then stall for 5s — longer than the 1s client timeout below.
            writer.write_all(b"data: {\"choices\": [{\"message\": null, \"finish_reason\": null, \"delta\": {\"content\": \"Hello, world!\"}}]}\n\n").unwrap();
            // sleep for 5s
            sleep(std::time::Duration::from_secs(5));
            writer.write_all(b"data: [DONE]\n\n")
        })
        .create_async()
        .await;
    let url = server.url();
    let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
    // Backend timeout (1s) is deliberately shorter than the server stall (5s).
    let backend = OpenAITextGenerationBackend::try_new(
        "".to_string(),
        url,
        "gpt2".to_string(),
        tokenizer,
        time::Duration::from_secs(1),
    )
    .unwrap();
    let request = Arc::new(TextGenerationRequest {
        prompt: "Hello, world!".to_string(),
        num_prompt_tokens: 2,
        num_decode_tokens: Some(16),
        system_prompt: None,
    });
    let (tx, mut rx) = tokio::sync::mpsc::channel(1);
    tokio::spawn(async move {
        backend.generate(request.clone(), tx.clone()).await;
    });
    // Drain every response the backend emits; the loop ends when `generate`
    // returns and the sender side of the channel is dropped.
    let collected = Arc::new(RwLock::new(Vec::new()));
    let sink = collected.clone();
    let drain = tokio::spawn(async move {
        while let Some(item) = rx.recv().await {
            sink.write().await.push(item);
        }
    });
    drain.await.unwrap();
    let responses = collected.read().await;
    assert_eq!(responses.len(), 1);
    assert_eq!(responses[0].failed, true);
}
}

0 comments on commit 9f91e3a

Please sign in to comment.