Skip to content

Commit

Permalink
Implement Client trait for StreamContext (#134)
Browse files Browse the repository at this point in the history
Signed-off-by: José Ulises Niño Rivera <[email protected]>
  • Loading branch information
junr03 authored Oct 7, 2024
1 parent 5bfccd3 commit c1cfbcd
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 218 deletions.
12 changes: 12 additions & 0 deletions arch/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions arch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"
derivative = "2.2.0"
sha2 = "0.10.8"

[dev-dependencies]
Expand Down
1 change: 1 addition & 0 deletions arch/src/filter_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl FilterContext {

let call_args = CallArgs::new(
MODEL_SERVER_NAME,
"/embeddings",
vec![
(":method", "POST"),
(":path", "/embeddings"),
Expand Down
16 changes: 12 additions & 4 deletions arch/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::stats::{Gauge, IncrementingMetric};
use derivative::Derivative;
use log::debug;
use proxy_wasm::{traits::Context, types::Status};
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};

#[derive(Debug)]
#[derive(Derivative)]
#[derivative(Debug)]
pub struct CallArgs<'a> {
upstream: &'a str,
path: &'a str,
headers: Vec<(&'a str, &'a str)>,
#[derivative(Debug = "ignore")]
body: Option<&'a [u8]>,
trailers: Vec<(&'a str, &'a str)>,
timeout: Duration,
Expand All @@ -15,13 +19,15 @@ pub struct CallArgs<'a> {
impl<'a> CallArgs<'a> {
pub fn new(
upstream: &'a str,
path: &'a str,
headers: Vec<(&'a str, &'a str)>,
body: Option<&'a [u8]>,
trailers: Vec<(&'a str, &'a str)>,
timeout: Duration,
) -> Self {
CallArgs {
upstream,
path,
headers,
body,
trailers,
Expand All @@ -32,9 +38,10 @@ impl<'a> CallArgs<'a> {

#[derive(thiserror::Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}`, error: {internal_status:?}")]
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
upstream_name: String,
path: String,
internal_status: Status,
},
}
Expand All @@ -46,7 +53,7 @@ pub trait Client: Context {
&self,
call_args: CallArgs,
call_context: Self::CallContext,
) -> Result<(), ClientError> {
) -> Result<u32, ClientError> {
debug!(
"dispatching http call with args={:?} context={:?}",
call_args, call_context
Expand All @@ -61,10 +68,11 @@ pub trait Client: Context {
) {
Ok(id) => {
self.add_call_context(id, call_context);
Ok(())
Ok(id)
}
Err(status) => Err(ClientError::DispatchError {
upstream_name: String::from(call_args.upstream),
path: String::from(call_args.path),
internal_status: status,
}),
}
Expand Down
31 changes: 25 additions & 6 deletions arch/src/ratelimit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use public_types::configuration;
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
use std::{collections::HashMap, sync::OnceLock};
Expand All @@ -28,13 +29,18 @@ pub struct RatelimitMap {
}

// This version of Header demands that the user passes a header value to match on.
#[allow(unused)]
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Header {
pub key: String,
pub value: String,
}

impl Display for Header {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}

impl From<Header> for configuration::Header {
fn from(header: Header) -> Self {
Self {
Expand All @@ -44,6 +50,16 @@ impl From<Header> for configuration::Header {
}
}

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")]
ExceededLimit {
provider: String,
selector: Header,
tokens_used: NonZeroU32,
},
}

impl RatelimitMap {
// n.b new is private so that the only access to the Ratelimits can be done via the static
// reference inside a RwLock via ratelimit::ratelimits().
Expand Down Expand Up @@ -82,7 +98,7 @@ impl RatelimitMap {
provider: String,
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), String> {
) -> Result<(), Error> {
debug!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
Expand All @@ -96,7 +112,7 @@ impl RatelimitMap {
Some(limit) => limit,
};

let mut config_selector = configuration::Header::from(selector);
let mut config_selector = configuration::Header::from(selector.clone());

let (limit, limit_key) = match provider_limits.get(&config_selector) {
// This is a specific limit, i.e one that was configured with both key, and value.
Expand All @@ -119,8 +135,11 @@ impl RatelimitMap {

match limit.check_key_n(&limit_key, tokens_used) {
Ok(Ok(())) => Ok(()),
Ok(Err(_)) => Err(String::from("Not allowed")),
Err(InsufficientCapacity(_)) => Err(String::from("Not allowed")),
Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit {
provider,
selector,
tokens_used,
}),
}
}
}
Expand Down
Loading

0 comments on commit c1cfbcd

Please sign in to comment.