From 973155bccb65b6f6faf5509c6f53175bd63c2ff9 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 16 Aug 2024 13:59:05 +0100 Subject: [PATCH 1/8] Remove use of filter in get_attribute Signed-off-by: Adam Cattermole --- src/attribute.rs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/attribute.rs b/src/attribute.rs index aada6176..357e9f83 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -1,7 +1,6 @@ use crate::configuration::Path; -use crate::filter::http_context::Filter; use chrono::{DateTime, FixedOffset}; -use proxy_wasm::traits::Context; +use proxy_wasm::hostcalls; pub trait Attribute { fn parse(raw_attribute: Vec) -> Result @@ -105,15 +104,12 @@ impl Attribute for DateTime { } #[allow(dead_code)] -pub fn get_attribute(f: &Filter, attr: &str) -> Result +pub fn get_attribute(attr: &str) -> Result where T: Attribute, { - match f.get_property(Path::from(attr).tokens()) { - None => Err(format!( - "#{} get_attribute: not found: {}", - f.context_id, attr - )), - Some(attribute_bytes) => T::parse(attribute_bytes), + match hostcalls::get_property(Path::from(attr).tokens()) { + Ok(Some(attribute_bytes)) => T::parse(attribute_bytes), + _ => Err(format!("get_attribute: not found: {}", attr)), } } From 0e87387f084972ed4ec77d2187f311c332591ccc Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 16 Aug 2024 14:39:37 +0100 Subject: [PATCH 2/8] Minor refactor of rate-limit service Signed-off-by: Adam Cattermole --- src/service/rate_limit.rs | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 6c4726c5..5fe03b41 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -14,12 +14,13 @@ pub struct RateLimitService { } impl RateLimitService { - pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> RateLimitService { + pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> Self { Self { endpoint: String::from(endpoint), tracing_headers: metadata, } } + pub fn message( domain: String, descriptors: RepeatedField, @@ -34,31 +35,22 @@ impl RateLimitService { } } -fn grpc_call( - upstream_name: &str, - initial_metadata: Vec<(&str, &[u8])>, - message: RateLimitRequest, -) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling - dispatch_grpc_call( - upstream_name, - RATELIMIT_SERVICE_NAME, - RATELIMIT_METHOD_NAME, - initial_metadata, - Some(&msg), - Duration::from_secs(5), - ) -} - impl Service for RateLimitService { fn send(&self, message: RateLimitRequest) -> Result { - grpc_call( + let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling + let metadata = self + .tracing_headers + .iter() + .map(|(header, value)| (header.as_str(), value.as_slice())) + .collect(); + + dispatch_grpc_call( self.endpoint.as_str(), - self.tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(), - message, + RATELIMIT_SERVICE_NAME, + RATELIMIT_METHOD_NAME, + metadata, + Some(&msg), + Duration::from_secs(5), ) } } From 6d9729543c783485344a765ce143ef7e0623d8d4 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 16 Aug 2024 14:41:06 +0100 Subject: [PATCH 3/8] Add initial implementation of auth service Signed-off-by: Adam Cattermole --- src/envoy/mod.rs | 7 +++ src/service.rs | 1 + src/service/auth.rs | 119 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 src/service/auth.rs diff --git a/src/envoy/mod.rs b/src/envoy/mod.rs index 68e18d60..db527204 100644 --- a/src/envoy/mod.rs +++ b/src/envoy/mod.rs @@ -31,6 +31,13 @@ mod token_bucket; mod value; pub use { + address::{Address, SocketAddress}, + attribute_context::{ + AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, + AttributeContext_Request, + }, + base::Metadata, + external_auth::CheckRequest, ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry}, rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code}, }; diff --git a/src/service.rs b/src/service.rs index b63bb827..3cad9530 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,3 +1,4 @@ +pub(crate) mod auth; pub(crate) mod rate_limit; use protobuf::Message; diff --git a/src/service/auth.rs b/src/service/auth.rs new file mode 100644 index 00000000..a623bc8e --- /dev/null +++ b/src/service/auth.rs @@ -0,0 +1,119 @@ +use crate::attribute::get_attribute; +use crate::envoy::{ + Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, + AttributeContext_Request, CheckRequest, Metadata, SocketAddress, +}; +use crate::filter::http_context::TracingHeader; +use crate::service::Service; +use chrono::{DateTime, FixedOffset, Timelike}; +use protobuf::well_known_types::Timestamp; +use protobuf::Message; +use proxy_wasm::hostcalls; +use proxy_wasm::hostcalls::dispatch_grpc_call; +use proxy_wasm::types::{Bytes, MapType, Status}; +use std::collections::HashMap; +use std::time::Duration; + +const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; +const AUTH_METHOD_NAME: &str = "Check"; + +pub struct AuthService { + endpoint: String, + tracing_headers: Vec<(TracingHeader, Bytes)>, +} + +impl AuthService { + pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> Self { + Self { + endpoint: String::from(endpoint), + tracing_headers: metadata, + } + } + + pub fn message() -> CheckRequest { + AuthService::build_check_req() + } + + fn build_check_req() -> CheckRequest { + let mut auth_req = CheckRequest::default(); + let mut attr = AttributeContext::default(); + attr.set_request(AuthService::build_request()); + attr.set_destination(AuthService::build_peer( + get_attribute::("destination.address").unwrap_or_default(), + get_attribute::("destination.port").unwrap_or_default() as u32, + )); + attr.set_source(AuthService::build_peer( + get_attribute::("source.address").unwrap_or_default(), + get_attribute::("source.port").unwrap_or_default() as u32, + )); + // todo(adam-cattermole): for now we set the context_extensions to the request host + // but this should take other info into account + let context_extensions = HashMap::from([( + "host".to_string(), + attr.get_request().get_http().host.to_owned(), + )]); + attr.set_context_extensions(context_extensions); + attr.set_metadata_context(Metadata::default()); + auth_req.set_attributes(attr); + auth_req + } + + fn build_request() -> AttributeContext_Request { + let mut request = AttributeContext_Request::default(); + let mut http = AttributeContext_HttpRequest::default(); + let headers: HashMap = hostcalls::get_map(MapType::HttpRequestHeaders) + .unwrap() + .into_iter() + .collect(); + + http.set_host(get_attribute::("request.host").unwrap_or_default()); + http.set_method(get_attribute::("request.method").unwrap_or_default()); + http.set_scheme(get_attribute::("request.scheme").unwrap_or_default()); + http.set_path(get_attribute::("request.path").unwrap_or_default()); + http.set_protocol(get_attribute::("request.protocol").unwrap_or_default()); + + http.set_headers(headers); + request.set_time(get_attribute("request.time").map_or( + Timestamp::new(), + |date_time: DateTime| Timestamp { + nanos: date_time.nanosecond() as i32, + seconds: date_time.second() as i64, + unknown_fields: Default::default(), + cached_size: Default::default(), + }, + )); + request.set_http(http); + request + } + + fn build_peer(host: String, port: u32) -> AttributeContext_Peer { + let mut peer = AttributeContext_Peer::default(); + let mut address = Address::default(); + let mut socket_address = SocketAddress::default(); + socket_address.set_address(host); + socket_address.set_port_value(port); + address.set_socket_address(socket_address); + peer.set_address(address); + peer + } +} + +impl Service for AuthService { + fn send(&self, message: CheckRequest) -> Result { + let msg = Message::write_to_bytes(&message).unwrap(); // TODO(adam-cattermole): Error Handling + let metadata = self + .tracing_headers + .iter() + .map(|(header, value)| (header.as_str(), value.as_slice())) + .collect(); + + dispatch_grpc_call( + self.endpoint.as_str(), + AUTH_SERVICE_NAME, + AUTH_METHOD_NAME, + metadata, + Some(&msg), + Duration::from_secs(5), + ) + } +} From db9ab76711446a701208394dc5be3d664a59cbd2 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 16 Aug 2024 16:00:29 +0100 Subject: [PATCH 4/8] Pass in the host to set in context_extensions Signed-off-by: Adam Cattermole --- src/service/auth.rs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/service/auth.rs b/src/service/auth.rs index a623bc8e..35a82cd9 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -30,11 +30,11 @@ impl AuthService { } } - pub fn message() -> CheckRequest { - AuthService::build_check_req() + pub fn message(ce_host: String) -> CheckRequest { + AuthService::build_check_req(ce_host) } - fn build_check_req() -> CheckRequest { + fn build_check_req(ce_host: String) -> CheckRequest { let mut auth_req = CheckRequest::default(); let mut attr = AttributeContext::default(); attr.set_request(AuthService::build_request()); @@ -46,12 +46,8 @@ impl AuthService { get_attribute::("source.address").unwrap_or_default(), get_attribute::("source.port").unwrap_or_default() as u32, )); - // todo(adam-cattermole): for now we set the context_extensions to the request host - // but this should take other info into account - let context_extensions = HashMap::from([( - "host".to_string(), - attr.get_request().get_http().host.to_owned(), - )]); + // the ce_host is the identifier for authorino to determine which authconfig to use + let context_extensions = HashMap::from([("host".to_string(), ce_host)]); attr.set_context_extensions(context_extensions); attr.set_metadata_context(Metadata::default()); auth_req.set_attributes(attr); From 0ce1b2c437478e0dc0ebd36df075f366066e87c8 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Wed, 21 Aug 2024 11:37:09 +0100 Subject: [PATCH 5/8] Genericise the service to send any type of M Signed-off-by: Adam Cattermole --- src/configuration.rs | 7 +++ src/filter/http_context.rs | 12 +++-- src/service.rs | 93 ++++++++++++++++++++++++++++++++++++-- src/service/auth.rs | 42 ++--------------- src/service/rate_limit.rs | 44 ++---------------- 5 files changed, 114 insertions(+), 84 deletions(-) diff --git a/src/configuration.rs b/src/configuration.rs index d7495ef0..9526f8cd 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -494,6 +494,13 @@ pub enum FailureMode { Allow, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ExtensionType { + Auth, + RateLimit, +} + #[derive(Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct PluginConfiguration { diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 281eab93..9feedb7a 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,9 +1,9 @@ -use crate::configuration::{FailureMode, FilterConfig}; +use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; use crate::service::rate_limit::RateLimitService; -use crate::service::Service; +use crate::service::GrpcServiceHandler; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; @@ -19,7 +19,7 @@ pub enum TracingHeader { } impl TracingHeader { - fn all() -> [Self; 3] { + pub fn all() -> [Self; 3] { [Traceparent, Tracestate, Baggage] } @@ -63,7 +63,11 @@ impl Filter { return Action::Continue; } - let rls = RateLimitService::new(rlp.service.as_str(), self.tracing_headers.clone()); + let rls = GrpcServiceHandler::new( + ExtensionType::RateLimit, + rlp.service.clone(), + self.tracing_headers.clone(), + ); let message = RateLimitService::message(rlp.domain.clone(), descriptors); match rls.send(message) { diff --git a/src/service.rs b/src/service.rs index 3cad9530..99424631 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,9 +1,96 @@ pub(crate) mod auth; pub(crate) mod rate_limit; +use crate::configuration::ExtensionType; +use crate::filter::http_context::TracingHeader; +use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use protobuf::Message; -use proxy_wasm::types::Status; +use proxy_wasm::hostcalls; +use proxy_wasm::hostcalls::dispatch_grpc_call; +use proxy_wasm::types::{Bytes, MapType, Status}; +use std::cell::OnceCell; +use std::time::Duration; -pub trait Service { - fn send(&self, message: M) -> Result; +pub struct GrpcServiceHandler { + endpoint: String, + service_name: String, + method_name: String, + tracing_headers: Vec<(TracingHeader, Bytes)>, +} + +impl GrpcServiceHandler { + fn build( + endpoint: String, + service_name: &str, + method_name: &str, + tracing_headers: Vec<(TracingHeader, Bytes)>, + ) -> Self { + Self { + endpoint: endpoint.to_owned(), + service_name: service_name.to_owned(), + method_name: method_name.to_owned(), + tracing_headers, + } + } + + pub fn new( + extension_type: ExtensionType, + endpoint: String, + tracing_headers: Vec<(TracingHeader, Bytes)>, + ) -> Self { + match extension_type { + ExtensionType::Auth => Self::build( + endpoint, + AUTH_SERVICE_NAME, + AUTH_METHOD_NAME, + tracing_headers, + ), + ExtensionType::RateLimit => Self::build( + endpoint, + RATELIMIT_SERVICE_NAME, + RATELIMIT_METHOD_NAME, + tracing_headers, + ), + } + } + + pub fn send(&self, message: M) -> Result { + let msg = Message::write_to_bytes(&message).unwrap(); + let metadata = self + .tracing_headers + .iter() + .map(|(header, value)| (header.as_str(), value.as_slice())) + .collect(); + + dispatch_grpc_call( + self.endpoint.as_str(), + self.service_name.as_str(), + self.method_name.as_str(), + metadata, + Some(&msg), + Duration::from_secs(5), + ) + } +} + +pub struct TracingHeaderResolver { + tracing_headers: OnceCell>, +} + +impl TracingHeaderResolver { + pub fn get(&self) -> &Vec<(TracingHeader, Bytes)> { + self.tracing_headers.get_or_init(|| { + let mut headers = Vec::new(); + for header in TracingHeader::all() { + if let Some(value) = + hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, header.as_str()) + .unwrap() + { + headers.push((header, value)); + } + } + headers + }) + } } diff --git a/src/service/auth.rs b/src/service/auth.rs index 35a82cd9..36220f59 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -3,33 +3,19 @@ use crate::envoy::{ Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, AttributeContext_Request, CheckRequest, Metadata, SocketAddress, }; -use crate::filter::http_context::TracingHeader; -use crate::service::Service; use chrono::{DateTime, FixedOffset, Timelike}; use protobuf::well_known_types::Timestamp; use protobuf::Message; use proxy_wasm::hostcalls; -use proxy_wasm::hostcalls::dispatch_grpc_call; -use proxy_wasm::types::{Bytes, MapType, Status}; +use proxy_wasm::types::MapType; use std::collections::HashMap; -use std::time::Duration; -const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; -const AUTH_METHOD_NAME: &str = "Check"; +pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; +pub const AUTH_METHOD_NAME: &str = "Check"; -pub struct AuthService { - endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, -} +pub struct AuthService; impl AuthService { - pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> Self { - Self { - endpoint: String::from(endpoint), - tracing_headers: metadata, - } - } - pub fn message(ce_host: String) -> CheckRequest { AuthService::build_check_req(ce_host) } @@ -93,23 +79,3 @@ impl AuthService { peer } } - -impl Service for AuthService { - fn send(&self, message: CheckRequest) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); // TODO(adam-cattermole): Error Handling - let metadata = self - .tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(); - - dispatch_grpc_call( - self.endpoint.as_str(), - AUTH_SERVICE_NAME, - AUTH_METHOD_NAME, - metadata, - Some(&msg), - Duration::from_secs(5), - ) - } -} diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 5fe03b41..6dfc3c89 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,26 +1,12 @@ use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use crate::filter::http_context::TracingHeader; -use crate::service::Service; -use protobuf::{Message, RepeatedField}; -use proxy_wasm::hostcalls::dispatch_grpc_call; -use proxy_wasm::types::{Bytes, Status}; -use std::time::Duration; +use protobuf::RepeatedField; -const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; -const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; -pub struct RateLimitService { - endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, -} +pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; +pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; -impl RateLimitService { - pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> Self { - Self { - endpoint: String::from(endpoint), - tracing_headers: metadata, - } - } +pub struct RateLimitService; +impl RateLimitService { pub fn message( domain: String, descriptors: RepeatedField, @@ -35,26 +21,6 @@ impl RateLimitService { } } -impl Service for RateLimitService { - fn send(&self, message: RateLimitRequest) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling - let metadata = self - .tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(); - - dispatch_grpc_call( - self.endpoint.as_str(), - RATELIMIT_SERVICE_NAME, - RATELIMIT_METHOD_NAME, - metadata, - Some(&msg), - Duration::from_secs(5), - ) - } -} - #[cfg(test)] mod tests { use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry, RateLimitRequest}; From 226e3af31fd2f44065d49f9c1adfa6b4c4e7e10b Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Thu, 22 Aug 2024 11:40:42 +0100 Subject: [PATCH 6/8] Use new header resolver in the service Signed-off-by: Adam Cattermole --- src/filter/http_context.rs | 37 +++------------------- src/filter/root_context.rs | 3 +- src/service.rs | 64 +++++++++++++++++++++++++++----------- src/service/auth.rs | 1 - 4 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 9feedb7a..ba2fb690 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,42 +1,19 @@ use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; -use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; use crate::service::rate_limit::RateLimitService; -use crate::service::GrpcServiceHandler; +use crate::service::{GrpcServiceHandler, HeaderResolver}; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; -use proxy_wasm::types::{Action, Bytes}; +use proxy_wasm::types::Action; use std::rc::Rc; -// tracing headers -#[derive(Clone)] -pub enum TracingHeader { - Traceparent, - Tracestate, - Baggage, -} - -impl TracingHeader { - pub fn all() -> [Self; 3] { - [Traceparent, Tracestate, Baggage] - } - - pub fn as_str(&self) -> &'static str { - match self { - Traceparent => "traceparent", - Tracestate => "tracestate", - Baggage => "baggage", - } - } -} - pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub tracing_headers: Vec<(TracingHeader, Bytes)>, + pub header_resolver: Rc, } impl Filter { @@ -66,7 +43,7 @@ impl Filter { let rls = GrpcServiceHandler::new( ExtensionType::RateLimit, rlp.service.clone(), - self.tracing_headers.clone(), + Rc::clone(&self.header_resolver), ); let message = RateLimitService::message(rlp.domain.clone(), descriptors); @@ -102,12 +79,6 @@ impl HttpContext for Filter { fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { debug!("#{} on_http_request_headers", self.context_id); - for header in TracingHeader::all() { - if let Some(value) = self.get_http_request_header_bytes(header.as_str()) { - self.tracing_headers.push((header, value)) - } - } - match self .config .index diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index ab28c72c..90774e1c 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -1,5 +1,6 @@ use crate::configuration::{FilterConfig, PluginConfiguration}; use crate::filter::http_context::Filter; +use crate::service::HeaderResolver; use const_format::formatcp; use log::{debug, error, info}; use proxy_wasm::traits::{Context, HttpContext, RootContext}; @@ -40,7 +41,7 @@ impl RootContext for FilterRoot { context_id, config: Rc::clone(&self.config), response_headers_to_add: Vec::default(), - tracing_headers: Vec::default(), + header_resolver: Rc::new(HeaderResolver::new()), })) } diff --git a/src/service.rs b/src/service.rs index 99424631..1f16198a 100644 --- a/src/service.rs +++ b/src/service.rs @@ -2,21 +2,22 @@ pub(crate) mod auth; pub(crate) mod rate_limit; use crate::configuration::ExtensionType; -use crate::filter::http_context::TracingHeader; use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; use protobuf::Message; use proxy_wasm::hostcalls; use proxy_wasm::hostcalls::dispatch_grpc_call; use proxy_wasm::types::{Bytes, MapType, Status}; use std::cell::OnceCell; +use std::rc::Rc; use std::time::Duration; pub struct GrpcServiceHandler { endpoint: String, service_name: String, method_name: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, + header_resolver: Rc, } impl GrpcServiceHandler { @@ -24,33 +25,33 @@ impl GrpcServiceHandler { endpoint: String, service_name: &str, method_name: &str, - tracing_headers: Vec<(TracingHeader, Bytes)>, + header_resolver: Rc, ) -> Self { Self { endpoint: endpoint.to_owned(), service_name: service_name.to_owned(), method_name: method_name.to_owned(), - tracing_headers, + header_resolver, } } pub fn new( extension_type: ExtensionType, endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, + header_resolver: Rc, ) -> Self { match extension_type { ExtensionType::Auth => Self::build( endpoint, AUTH_SERVICE_NAME, AUTH_METHOD_NAME, - tracing_headers, + header_resolver, ), ExtensionType::RateLimit => Self::build( endpoint, RATELIMIT_SERVICE_NAME, RATELIMIT_METHOD_NAME, - tracing_headers, + header_resolver, ), } } @@ -58,9 +59,10 @@ impl GrpcServiceHandler { pub fn send(&self, message: M) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self - .tracing_headers + .header_resolver + .get() .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) + .map(|(header, value)| (*header, value.as_slice())) .collect(); dispatch_grpc_call( @@ -74,23 +76,49 @@ impl GrpcServiceHandler { } } -pub struct TracingHeaderResolver { - tracing_headers: OnceCell>, +pub struct HeaderResolver { + headers: OnceCell>, } -impl TracingHeaderResolver { - pub fn get(&self) -> &Vec<(TracingHeader, Bytes)> { - self.tracing_headers.get_or_init(|| { +impl HeaderResolver { + pub fn new() -> Self { + Self { + headers: OnceCell::new(), + } + } + + pub fn get(&self) -> &Vec<(&'static str, Bytes)> { + self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { - if let Some(value) = - hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, header.as_str()) - .unwrap() + if let Ok(Some(value)) = + hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str()) { - headers.push((header, value)); + headers.push(((*header).as_str(), value)); } } headers }) } } + +// tracing headers +pub enum TracingHeader { + Traceparent, + Tracestate, + Baggage, +} + +impl TracingHeader { + fn all() -> &'static [Self; 3] { + &[Traceparent, Tracestate, Baggage] + } + + pub fn as_str(&self) -> &'static str { + match self { + Traceparent => "traceparent", + Tracestate => "tracestate", + Baggage => "baggage", + } + } +} diff --git a/src/service/auth.rs b/src/service/auth.rs index 36220f59..1e7c7344 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -5,7 +5,6 @@ use crate::envoy::{ }; use chrono::{DateTime, FixedOffset, Timelike}; use protobuf::well_known_types::Timestamp; -use protobuf::Message; use proxy_wasm::hostcalls; use proxy_wasm::types::MapType; use std::collections::HashMap; From 89bb1873741f04c38a9ef787ea97ff0f89106132 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Thu, 22 Aug 2024 14:08:40 +0100 Subject: [PATCH 7/8] Fix tests using header resolver Signed-off-by: Adam Cattermole --- tests/rate_limited.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 1578afa7..b25bbdda 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -57,12 +57,6 @@ fn it_loads() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) - .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) .expect_log( @@ -161,14 +155,14 @@ fn it_limits() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) + .returning(Some("cars.toystore.com")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) - .returning(Some("cars.toystore.com")) .expect_get_property(Some(vec!["request", "url_path"])) .returning(Some("/admin/toy".as_bytes())) .expect_get_property(Some(vec!["request", "host"])) @@ -299,14 +293,14 @@ fn it_passes_additional_headers() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) + .returning(Some("cars.toystore.com")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) - .returning(Some("cars.toystore.com")) .expect_get_property(Some(vec!["request", "url_path"])) .returning(Some("/admin/toy".as_bytes())) .expect_get_property(Some(vec!["request", "host"])) @@ -431,14 +425,14 @@ fn it_rate_limits_with_empty_conditions() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) + .returning(Some("a.com")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) - .returning(Some("a.com")) .expect_log( Some(LogLevel::Debug), Some("#2 ratelimitpolicy selected some-name"), @@ -542,12 +536,6 @@ fn it_does_not_rate_limits_when_selector_does_not_exist_and_misses_default_value module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) - .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("a.com")) .expect_get_property(Some(vec!["unknown", "path"])) From 4ef1adb3833ba996a6d30258d3c387322f782f37 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Wed, 28 Aug 2024 15:25:11 +0100 Subject: [PATCH 8/8] Ignore authservice until used Signed-off-by: Adam Cattermole --- src/service/auth.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/service/auth.rs b/src/service/auth.rs index 1e7c7344..0831cd6c 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -14,6 +14,7 @@ pub const AUTH_METHOD_NAME: &str = "Check"; pub struct AuthService; +#[allow(dead_code)] impl AuthService { pub fn message(ce_host: String) -> CheckRequest { AuthService::build_check_req(ce_host)