Skip to content

Commit

Permalink
fix double reader locking issue in protocol.rs (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
Iamdavidonuh authored Sep 23, 2024
1 parent 958ccc2 commit 45acd5e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
4 changes: 2 additions & 2 deletions ahnlich/ai/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ impl AIStoreHandler {
}

/// Stores storeinput into ahnlich db
#[tracing::instrument(skip(self), fields(input_length=inputs.len()))]
#[tracing::instrument(skip(self, inputs), fields(input_length=inputs.len()))]
pub(crate) async fn set(
&self,
store_name: &StoreName,
Expand All @@ -233,7 +233,7 @@ impl AIStoreHandler {

/// Converts (storekey, storevalue) into (storeinput, storevalue)
/// by removing the reserved_key from storevalue
#[tracing::instrument(skip(self))]
#[tracing::instrument(skip(self, output), fields(output_len=output.len()))]
pub(crate) fn store_key_val_to_store_input_val(
&self,
output: Vec<(StoreKey, StoreValue)>,
Expand Down
33 changes: 17 additions & 16 deletions ahnlich/utils/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use std::fmt::Debug;
use std::io::Error;
use std::io::ErrorKind;
use std::sync::Arc;
use tokio::sync::MutexGuard;

use task_manager::TaskState;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
Expand Down Expand Up @@ -46,51 +48,52 @@ where
match reader.read_exact(&mut magic_bytes_buf).await {
Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
let error = "Hung up on buffered stream";
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
}
Err(e) => {
let error = format!("Error reading from task buffered stream {e}");
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
}
Ok(_) => {
if magic_bytes_buf != MAGIC_BYTES {
let error = "Invalid request stream".to_string();
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
}
if let Err(error) = reader.read_exact(&mut version_buf).await {
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
}
let version = match Version::deserialize_magic_bytes(&version_buf) {
Ok(version) => version,
Err(error) => {
let error = format!("Unable to parse version chunk {error}");
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
}
};
if !VERSION.is_compatible(&version) {
let error = format!(
"Incompatible versions, Server: {:?}, Client {version:?}",
*VERSION
);
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
}
// cap the message size to be of length 1MiB
if let Err(error) = reader.read_exact(&mut length_buf).await {
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
};
let data_length = u64::from_le_bytes(length_buf);
if data_length > self.maximum_message_size() {
let error = format!(
"Message cannot exceed {} bytes, configure `message_size` for higher",
self.maximum_message_size()
);
return self.handle_error(error, false).await;
return self.handle_error(reader, error, true).await;
};

let mut data: Vec<_> = match FallibleVec::try_with_capacity(data_length as usize) {
Err(error) => {
return self
.handle_error(
reader,
format!("Could not allocate buffer for message body {:?}", error),
true,
)
Expand All @@ -101,14 +104,15 @@ where
if let Err(error) = data.try_resize(data_length as usize, 0u8) {
return self
.handle_error(
reader,
format!("Could not resize buffer for message body {:?}", error),
true,
)
.await;
};
if let Err(e) = reader.read_exact(&mut data).await {
let error = format!("Could not read data buffer {e}");
return self.handle_error(error.to_string(), false).await;
return self.handle_error(reader, error.to_string(), false).await;
};
match Self::ServerQuery::deserialize(&data) {
Ok(queries) => {
Expand All @@ -119,14 +123,14 @@ where
.map_err(|err| Error::new(ErrorKind::Other, err))
{
Ok(parent_context) => parent_context,
Err(error) => return self.handle_error(error, false).await,
Err(error) => return self.handle_error(reader, error, false).await,
};
span.set_parent(parent_context);
}
let results = self.handle(queries.into_inner()).instrument(span).await;
if let Ok(binary_results) = results.serialize() {
if let Err(error) = reader.get_mut().write_all(&binary_results).await {
return self.handle_error(error, false).await;
return self.handle_error(reader, error, false).await;
};
log::debug!(
"Sent Response of length {}, {:?}",
Expand All @@ -136,7 +140,7 @@ where
}
}
Err(error) => {
return self.handle_error(error, true).await;
return self.handle_error(reader, error, true).await;
}
}
}
Expand All @@ -146,14 +150,13 @@ where

async fn handle_error(
&self,
mut reader: MutexGuard<'_, BufReader<TcpStream>>,
error: impl ToString + Send,
respond_with_error: bool,
) -> TaskState {
let error = self.prefix_log(error.to_string());
log::error!("{error}");
if respond_with_error {
let reader = self.reader();
let mut reader = reader.lock().await;
match Self::ServerResponse::from_error(format!(
"Could not deserialize query, error is {error}"
))
Expand All @@ -166,8 +169,6 @@ where
Ok(deserialize_error) => {
if let Err(error) = reader.get_mut().write_all(&deserialize_error).await {
log::error!("{}", self.prefix_log(format!("{error}")));
} else {
return TaskState::Continue;
}
}
};
Expand Down

0 comments on commit 45acd5e

Please sign in to comment.