Skip to content
This repository has been archived by the owner on Sep 7, 2024. It is now read-only.

Commit

Permalink
feat: make stream_message actually stream messages
Browse files Browse the repository at this point in the history
  • Loading branch information
roushou committed Aug 16, 2024
1 parent 65b25e9 commit 750d036
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 27 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 32 additions & 23 deletions anthropic/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::fmt;
use futures_util::StreamExt;
use futures_util::{stream, Stream, StreamExt};
use reqwest::{
header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE},
Method, RequestBuilder, Url,
Expand Down Expand Up @@ -105,7 +105,10 @@ impl Client {
.map_err(AnthropicError::from)
}

pub async fn stream_message(&self, request: MessageRequest) -> Result<(), AnthropicError> {
pub async fn stream_message(
&self,
request: MessageRequest,
) -> Result<impl Stream<Item = Result<StreamEvent, AnthropicError>>, AnthropicError> {
let response = self
.request(Method::POST, "messages")?
.header(ACCEPT, "text/event-stream")
Expand All @@ -121,29 +124,35 @@ impl Client {
}
}

let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk_str = std::str::from_utf8(&chunk).unwrap();

for event in chunk_str.split("\n\n") {
let event = event.trim();
if event.is_empty() {
continue;
}

let data: Vec<&str> = event.split("\n").collect();

if let Some(content) = data[1].strip_prefix("data: ") {
let content = StreamEvent::from_str(content)?;
if let StreamEvent::ContentBlockDelta(content) = content {
print!("{}", content.delta.text);
}
};
Ok(response.bytes_stream().flat_map(move |chunk| match chunk {
Ok(bytes) => {
let events = Self::parse_stream_chunk(&bytes);
stream::iter(events)
}
}
Err(err) => stream::iter(vec![Err(AnthropicError::from(err))]),
}))
}

Ok(())
fn parse_stream_chunk(bytes: &[u8]) -> Vec<Result<StreamEvent, AnthropicError>> {
let chunk_str = match std::str::from_utf8(bytes).map_err(AnthropicError::Utf8Error) {
Ok(chunk_str) => chunk_str,
Err(err) => return vec![Err(err)],
};
chunk_str
.split("\n\n")
.filter(|event| !event.trim().is_empty())
.map(|event| {
event
.lines()
.find(|line| line.starts_with("data: "))
.and_then(|line| line.strip_prefix("data: "))
.ok_or(AnthropicError::InvalidStreamEvent)
.and_then(|content| {
StreamEvent::from_str(content)
.map_err(|_| AnthropicError::InvalidStreamEvent)
})
})
.collect()
}
}

Expand Down
8 changes: 8 additions & 0 deletions anthropic/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::Utf8Error;

use serde::Deserialize;

use crate::client::ApiVersionError;
Expand Down Expand Up @@ -28,6 +30,12 @@ pub enum AnthropicError {
#[error("Missing API key {0}")]
MissingApiKey(&'static str),

#[error("Invalid Stream Event")]
InvalidStreamEvent,

#[error("UTF8 Error: {0}")]
Utf8Error(#[from] Utf8Error),

#[error("Unexpected error: {0}")]
Unexpected(String),
}
Expand Down
2 changes: 1 addition & 1 deletion examples/basic-message/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn main() {
role: Role::User,
content: vec![Content {
content_type: ContentType::Text,
text: "Hello World".to_string(),
text: "Explain the theory of relativity".to_string(),
}],
}],
..Default::default()
Expand Down
1 change: 1 addition & 0 deletions examples/stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ edition = "2021"

[dependencies]
anthropic-rs = { path = "../../anthropic" }
futures-util = "0.3.30"
tokio = { version = "1.39.2", features = ["full"] }
25 changes: 22 additions & 3 deletions examples/stream/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use anthropic_rs::{
api::message::{Content, ContentType, Message, MessageRequest, Role},
api::{
message::{Content, ContentType, Message, MessageRequest, Role},
stream::StreamEvent,
},
client::Client,
config::Config,
models::model::Model,
};
use futures_util::StreamExt;
use std::io::Write;

#[tokio::main]
async fn main() {
Expand All @@ -21,11 +26,25 @@ async fn main() {
role: Role::User,
content: vec![Content {
content_type: ContentType::Text,
text: "Hello World".to_string(),
text: "Explain the theory of relativity".to_string(),
}],
}],
..Default::default()
};

client.stream_message(message.clone()).await.unwrap();
let mut stream = client.stream_message(message.clone()).await.unwrap();

while let Some(event) = stream.next().await {
match event {
Ok(event) => match event {
StreamEvent::ContentBlockDelta(content) => {
print!("{}", content.delta.text);
std::io::stdout().flush().unwrap();
}
StreamEvent::MessageStop => break,
_ => {}
},
Err(err) => println!("{}", err),
}
}
}

0 comments on commit 750d036

Please sign in to comment.