diff --git a/src/attribute.rs b/src/attribute.rs index aada6176..357e9f83 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -1,7 +1,6 @@ use crate::configuration::Path; -use crate::filter::http_context::Filter; use chrono::{DateTime, FixedOffset}; -use proxy_wasm::traits::Context; +use proxy_wasm::hostcalls; pub trait Attribute { fn parse(raw_attribute: Vec) -> Result @@ -105,15 +104,12 @@ impl Attribute for DateTime { } #[allow(dead_code)] -pub fn get_attribute(f: &Filter, attr: &str) -> Result +pub fn get_attribute(attr: &str) -> Result where T: Attribute, { - match f.get_property(Path::from(attr).tokens()) { - None => Err(format!( - "#{} get_attribute: not found: {}", - f.context_id, attr - )), - Some(attribute_bytes) => T::parse(attribute_bytes), + match hostcalls::get_property(Path::from(attr).tokens()) { + Ok(Some(attribute_bytes)) => T::parse(attribute_bytes), + _ => Err(format!("get_attribute: not found: {}", attr)), } } diff --git a/src/configuration.rs b/src/configuration.rs index d7495ef0..9526f8cd 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -494,6 +494,13 @@ pub enum FailureMode { Allow, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ExtensionType { + Auth, + RateLimit, +} + #[derive(Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct PluginConfiguration { diff --git a/src/envoy/mod.rs b/src/envoy/mod.rs index 68e18d60..db527204 100644 --- a/src/envoy/mod.rs +++ b/src/envoy/mod.rs @@ -31,6 +31,13 @@ mod token_bucket; mod value; pub use { + address::{Address, SocketAddress}, + attribute_context::{ + AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, + AttributeContext_Request, + }, + base::Metadata, + external_auth::CheckRequest, ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry}, rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code}, }; diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 281eab93..ba2fb690 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,42 +1,19 @@ -use crate::configuration::{FailureMode, FilterConfig}; +use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; -use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; use crate::service::rate_limit::RateLimitService; -use crate::service::Service; +use crate::service::{GrpcServiceHandler, HeaderResolver}; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; -use proxy_wasm::types::{Action, Bytes}; +use proxy_wasm::types::Action; use std::rc::Rc; -// tracing headers -#[derive(Clone)] -pub enum TracingHeader { - Traceparent, - Tracestate, - Baggage, -} - -impl TracingHeader { - fn all() -> [Self; 3] { - [Traceparent, Tracestate, Baggage] - } - - pub fn as_str(&self) -> &'static str { - match self { - Traceparent => "traceparent", - Tracestate => "tracestate", - Baggage => "baggage", - } - } -} - pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub tracing_headers: Vec<(TracingHeader, Bytes)>, + pub header_resolver: Rc, } impl Filter { @@ -63,7 +40,11 @@ impl Filter { return Action::Continue; } - let rls = RateLimitService::new(rlp.service.as_str(), self.tracing_headers.clone()); + let rls = GrpcServiceHandler::new( + ExtensionType::RateLimit, + rlp.service.clone(), + Rc::clone(&self.header_resolver), + ); let message = RateLimitService::message(rlp.domain.clone(), descriptors); match rls.send(message) { @@ -98,12 +79,6 @@ impl HttpContext for Filter { fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { debug!("#{} on_http_request_headers", self.context_id); - for header in TracingHeader::all() { - if let Some(value) = self.get_http_request_header_bytes(header.as_str()) { - self.tracing_headers.push((header, value)) - } - } - match self .config .index diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index ab28c72c..90774e1c 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -1,5 +1,6 @@ use crate::configuration::{FilterConfig, PluginConfiguration}; use crate::filter::http_context::Filter; +use crate::service::HeaderResolver; use const_format::formatcp; use log::{debug, error, info}; use proxy_wasm::traits::{Context, HttpContext, RootContext}; @@ -40,7 +41,7 @@ impl RootContext for FilterRoot { context_id, config: Rc::clone(&self.config), response_headers_to_add: Vec::default(), - tracing_headers: Vec::default(), + header_resolver: Rc::new(HeaderResolver::new()), })) } diff --git a/src/service.rs b/src/service.rs index b63bb827..1f16198a 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,8 +1,124 @@ +pub(crate) mod auth; pub(crate) mod rate_limit; +use crate::configuration::ExtensionType; +use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; use protobuf::Message; -use proxy_wasm::types::Status; +use proxy_wasm::hostcalls; +use proxy_wasm::hostcalls::dispatch_grpc_call; +use proxy_wasm::types::{Bytes, MapType, Status}; +use std::cell::OnceCell; +use std::rc::Rc; +use std::time::Duration; -pub trait Service { - fn send(&self, message: M) -> Result; +pub struct GrpcServiceHandler { + endpoint: String, + service_name: String, + method_name: String, + header_resolver: Rc, +} + +impl GrpcServiceHandler { + fn build( + endpoint: String, + service_name: &str, + method_name: &str, + header_resolver: Rc, + ) -> Self { + Self { + endpoint: endpoint.to_owned(), + service_name: service_name.to_owned(), + method_name: method_name.to_owned(), + header_resolver, + } + } + + pub fn new( + extension_type: ExtensionType, + endpoint: String, + header_resolver: Rc, + ) -> Self { + match extension_type { + ExtensionType::Auth => Self::build( + endpoint, + AUTH_SERVICE_NAME, + AUTH_METHOD_NAME, + header_resolver, + ), + ExtensionType::RateLimit => Self::build( + endpoint, + RATELIMIT_SERVICE_NAME, + RATELIMIT_METHOD_NAME, + header_resolver, + ), + } + } + + pub fn send(&self, message: M) -> Result { + let msg = Message::write_to_bytes(&message).unwrap(); + let metadata = self + .header_resolver + .get() + .iter() + .map(|(header, value)| (*header, value.as_slice())) + .collect(); + + dispatch_grpc_call( + self.endpoint.as_str(), + self.service_name.as_str(), + self.method_name.as_str(), + metadata, + Some(&msg), + Duration::from_secs(5), + ) + } +} + +pub struct HeaderResolver { + headers: OnceCell>, +} + +impl HeaderResolver { + pub fn new() -> Self { + Self { + headers: OnceCell::new(), + } + } + + pub fn get(&self) -> &Vec<(&'static str, Bytes)> { + self.headers.get_or_init(|| { + let mut headers = Vec::new(); + for header in TracingHeader::all() { + if let Ok(Some(value)) = + hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str()) + { + headers.push(((*header).as_str(), value)); + } + } + headers + }) + } +} + +// tracing headers +pub enum TracingHeader { + Traceparent, + Tracestate, + Baggage, +} + +impl TracingHeader { + fn all() -> &'static [Self; 3] { + &[Traceparent, Tracestate, Baggage] + } + + pub fn as_str(&self) -> &'static str { + match self { + Traceparent => "traceparent", + Tracestate => "tracestate", + Baggage => "baggage", + } + } } diff --git a/src/service/auth.rs b/src/service/auth.rs new file mode 100644 index 00000000..0831cd6c --- /dev/null +++ b/src/service/auth.rs @@ -0,0 +1,81 @@ +use crate::attribute::get_attribute; +use crate::envoy::{ + Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, + AttributeContext_Request, CheckRequest, Metadata, SocketAddress, +}; +use chrono::{DateTime, FixedOffset, Timelike}; +use protobuf::well_known_types::Timestamp; +use proxy_wasm::hostcalls; +use proxy_wasm::types::MapType; +use std::collections::HashMap; + +pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; +pub const AUTH_METHOD_NAME: &str = "Check"; + +pub struct AuthService; + +#[allow(dead_code)] +impl AuthService { + pub fn message(ce_host: String) -> CheckRequest { + AuthService::build_check_req(ce_host) + } + + fn build_check_req(ce_host: String) -> CheckRequest { + let mut auth_req = CheckRequest::default(); + let mut attr = AttributeContext::default(); + attr.set_request(AuthService::build_request()); + attr.set_destination(AuthService::build_peer( + get_attribute::("destination.address").unwrap_or_default(), + get_attribute::("destination.port").unwrap_or_default() as u32, + )); + attr.set_source(AuthService::build_peer( + get_attribute::("source.address").unwrap_or_default(), + get_attribute::("source.port").unwrap_or_default() as u32, + )); + // the ce_host is the identifier for authorino to determine which authconfig to use + let context_extensions = HashMap::from([("host".to_string(), ce_host)]); + attr.set_context_extensions(context_extensions); + attr.set_metadata_context(Metadata::default()); + auth_req.set_attributes(attr); + auth_req + } + + fn build_request() -> AttributeContext_Request { + let mut request = AttributeContext_Request::default(); + let mut http = AttributeContext_HttpRequest::default(); + let headers: HashMap = hostcalls::get_map(MapType::HttpRequestHeaders) + .unwrap() + .into_iter() + .collect(); + + http.set_host(get_attribute::("request.host").unwrap_or_default()); + http.set_method(get_attribute::("request.method").unwrap_or_default()); + http.set_scheme(get_attribute::("request.scheme").unwrap_or_default()); + http.set_path(get_attribute::("request.path").unwrap_or_default()); + http.set_protocol(get_attribute::("request.protocol").unwrap_or_default()); + + http.set_headers(headers); + request.set_time(get_attribute("request.time").map_or( + Timestamp::new(), + |date_time: DateTime| Timestamp { + nanos: date_time.nanosecond() as i32, + seconds: date_time.second() as i64, + unknown_fields: Default::default(), + cached_size: Default::default(), + }, + )); + request.set_http(http); + request + } + + fn build_peer(host: String, port: u32) -> AttributeContext_Peer { + let mut peer = AttributeContext_Peer::default(); + let mut address = Address::default(); + let mut socket_address = SocketAddress::default(); + socket_address.set_address(host); + socket_address.set_port_value(port); + address.set_socket_address(socket_address); + peer.set_address(address); + peer + } +} diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 6c4726c5..6dfc3c89 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,25 +1,12 @@ use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use crate::filter::http_context::TracingHeader; -use crate::service::Service; -use protobuf::{Message, RepeatedField}; -use proxy_wasm::hostcalls::dispatch_grpc_call; -use proxy_wasm::types::{Bytes, Status}; -use std::time::Duration; +use protobuf::RepeatedField; -const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; -const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; -pub struct RateLimitService { - endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, -} +pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; +pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; + +pub struct RateLimitService; impl RateLimitService { - pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> RateLimitService { - Self { - endpoint: String::from(endpoint), - tracing_headers: metadata, - } - } pub fn message( domain: String, descriptors: RepeatedField, @@ -34,35 +21,6 @@ impl RateLimitService { } } -fn grpc_call( - upstream_name: &str, - initial_metadata: Vec<(&str, &[u8])>, - message: RateLimitRequest, -) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling - dispatch_grpc_call( - upstream_name, - RATELIMIT_SERVICE_NAME, - RATELIMIT_METHOD_NAME, - initial_metadata, - Some(&msg), - Duration::from_secs(5), - ) -} - -impl Service for RateLimitService { - fn send(&self, message: RateLimitRequest) -> Result { - grpc_call( - self.endpoint.as_str(), - self.tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(), - message, - ) - } -} - #[cfg(test)] mod tests { use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry, RateLimitRequest}; diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 1578afa7..b25bbdda 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -57,12 +57,6 @@ fn it_loads() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) - .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) .expect_log( @@ -161,14 +155,14 @@ fn it_limits() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) + .returning(Some("cars.toystore.com")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) - .returning(Some("cars.toystore.com")) .expect_get_property(Some(vec!["request", "url_path"])) .returning(Some("/admin/toy".as_bytes())) .expect_get_property(Some(vec!["request", "host"])) @@ -299,14 +293,14 @@ fn it_passes_additional_headers() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) + .returning(Some("cars.toystore.com")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) - .returning(Some("cars.toystore.com")) .expect_get_property(Some(vec!["request", "url_path"])) .returning(Some("/admin/toy".as_bytes())) .expect_get_property(Some(vec!["request", "host"])) @@ -431,14 +425,14 @@ fn it_rate_limits_with_empty_conditions() { module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) + .returning(Some("a.com")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) - .returning(Some("a.com")) .expect_log( Some(LogLevel::Debug), Some("#2 ratelimitpolicy selected some-name"), @@ -542,12 +536,6 @@ fn it_does_not_rate_limits_when_selector_does_not_exist_and_misses_default_value module .call_proxy_on_request_headers(http_context, 0, false) .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("baggage")) - .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("a.com")) .expect_get_property(Some(vec!["unknown", "path"]))