diff --git a/src/configuration.rs b/src/configuration.rs index d7495ef0..9526f8cd 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -494,6 +494,13 @@ pub enum FailureMode { Allow, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ExtensionType { + Auth, + RateLimit, +} + #[derive(Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct PluginConfiguration { diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 281eab93..9feedb7a 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,9 +1,9 @@ -use crate::configuration::{FailureMode, FilterConfig}; +use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; use crate::service::rate_limit::RateLimitService; -use crate::service::Service; +use crate::service::GrpcServiceHandler; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; @@ -19,7 +19,7 @@ pub enum TracingHeader { } impl TracingHeader { - fn all() -> [Self; 3] { + pub fn all() -> [Self; 3] { [Traceparent, Tracestate, Baggage] } @@ -63,7 +63,11 @@ impl Filter { return Action::Continue; } - let rls = RateLimitService::new(rlp.service.as_str(), self.tracing_headers.clone()); + let rls = GrpcServiceHandler::new( + ExtensionType::RateLimit, + rlp.service.clone(), + self.tracing_headers.clone(), + ); let message = RateLimitService::message(rlp.domain.clone(), descriptors); match rls.send(message) { diff --git a/src/service.rs b/src/service.rs index 3cad9530..3f358550 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,9 +1,96 @@ pub(crate) mod auth; pub(crate) mod rate_limit; +use crate::configuration::ExtensionType; +use crate::filter::http_context::TracingHeader; +use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use protobuf::Message; -use proxy_wasm::types::Status; +use proxy_wasm::hostcalls; +use proxy_wasm::hostcalls::dispatch_grpc_call; +use proxy_wasm::types::{Bytes, MapType, Status}; +use std::cell::OnceCell; +use std::time::Duration; -pub trait Service { - fn send(&self, message: M) -> Result; +pub struct GrpcServiceHandler { + endpoint: String, + service_name: String, + method_name: String, + tracing_headers: Vec<(TracingHeader, Bytes)>, +} + +impl GrpcServiceHandler { + fn new_base( + endpoint: String, + service_name: &str, + method_name: &str, + tracing_headers: Vec<(TracingHeader, Bytes)>, + ) -> Self { + Self { + endpoint: endpoint.to_owned(), + service_name: service_name.to_owned(), + method_name: method_name.to_owned(), + tracing_headers, + } + } + + pub fn new( + extension_type: ExtensionType, + endpoint: String, + tracing_headers: Vec<(TracingHeader, Bytes)>, + ) -> Self { + match extension_type { + ExtensionType::Auth => Self::new_base( + endpoint, + AUTH_SERVICE_NAME, + AUTH_METHOD_NAME, + tracing_headers, + ), + ExtensionType::RateLimit => Self::new_base( + endpoint, + RATELIMIT_SERVICE_NAME, + RATELIMIT_METHOD_NAME, + tracing_headers, + ), + } + } + + pub fn send(&self, message: M) -> Result { + let msg = Message::write_to_bytes(&message).unwrap(); + let metadata = self + .tracing_headers + .iter() + .map(|(header, value)| (header.as_str(), value.as_slice())) + .collect(); + + dispatch_grpc_call( + self.endpoint.as_str(), + self.service_name.as_str(), + self.method_name.as_str(), + metadata, + Some(&msg), + Duration::from_secs(5), + ) + } +} + +pub struct TracingHeaderResolver { + tracing_headers: OnceCell>, +} + +impl TracingHeaderResolver { + pub fn get(&self) -> &Vec<(TracingHeader, Bytes)> { + self.tracing_headers.get_or_init(|| { + let mut headers = Vec::new(); + for header in TracingHeader::all() { + if let Some(value) = + hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, header.as_str()) + .unwrap() + { + headers.push((header, value)); + } + } + headers + }) + } } diff --git a/src/service/auth.rs b/src/service/auth.rs index 35a82cd9..627d0cb0 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -3,33 +3,20 @@ use crate::envoy::{ Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, AttributeContext_Request, CheckRequest, Metadata, SocketAddress, }; -use crate::filter::http_context::TracingHeader; -use crate::service::Service; use chrono::{DateTime, FixedOffset, Timelike}; use protobuf::well_known_types::Timestamp; use protobuf::Message; use proxy_wasm::hostcalls; -use proxy_wasm::hostcalls::dispatch_grpc_call; -use proxy_wasm::types::{Bytes, MapType, Status}; +use proxy_wasm::types::MapType; use std::collections::HashMap; -use std::time::Duration; -const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; -const AUTH_METHOD_NAME: &str = "Check"; +pub const AUTH_DATA_ITEM: &str = "host"; +pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; +pub const AUTH_METHOD_NAME: &str = "Check"; -pub struct AuthService { - endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, -} +pub struct AuthService; impl AuthService { - pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> Self { - Self { - endpoint: String::from(endpoint), - tracing_headers: metadata, - } - } - pub fn message(ce_host: String) -> CheckRequest { AuthService::build_check_req(ce_host) } @@ -93,23 +80,3 @@ impl AuthService { peer } } - -impl Service for AuthService { - fn send(&self, message: CheckRequest) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); // TODO(adam-cattermole): Error Handling - let metadata = self - .tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(); - - dispatch_grpc_call( - self.endpoint.as_str(), - AUTH_SERVICE_NAME, - AUTH_METHOD_NAME, - metadata, - Some(&msg), - Duration::from_secs(5), - ) - } -} diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 5fe03b41..ab986e76 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,26 +1,13 @@ use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use crate::filter::http_context::TracingHeader; -use crate::service::Service; use protobuf::{Message, RepeatedField}; -use proxy_wasm::hostcalls::dispatch_grpc_call; -use proxy_wasm::types::{Bytes, Status}; -use std::time::Duration; -const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; -const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; -pub struct RateLimitService { - endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, -} +pub const RATELIMIT_DATA_ITEM: &str = "domain"; +pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; +pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; -impl RateLimitService { - pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> Self { - Self { - endpoint: String::from(endpoint), - tracing_headers: metadata, - } - } +pub struct RateLimitService; +impl RateLimitService { pub fn message( domain: String, descriptors: RepeatedField, @@ -35,26 +22,6 @@ impl RateLimitService { } } -impl Service for RateLimitService { - fn send(&self, message: RateLimitRequest) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling - let metadata = self - .tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(); - - dispatch_grpc_call( - self.endpoint.as_str(), - RATELIMIT_SERVICE_NAME, - RATELIMIT_METHOD_NAME, - metadata, - Some(&msg), - Duration::from_secs(5), - ) - } -} - #[cfg(test)] mod tests { use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry, RateLimitRequest};