Skip to content

Commit

Permalink
Merge pull request #69 from Kuadrant/authservice
Browse files Browse the repository at this point in the history
Add initial implementation of auth service
  • Loading branch information
adam-cattermole authored Aug 28, 2024
2 parents 354dd42 + 4ef1adb commit 870cafd
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 112 deletions.
14 changes: 5 additions & 9 deletions src/attribute.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::configuration::Path;
use crate::filter::http_context::Filter;
use chrono::{DateTime, FixedOffset};
use proxy_wasm::traits::Context;
use proxy_wasm::hostcalls;

pub trait Attribute {
fn parse(raw_attribute: Vec<u8>) -> Result<Self, String>
Expand Down Expand Up @@ -105,15 +104,12 @@ impl Attribute for DateTime<FixedOffset> {
}

#[allow(dead_code)]
pub fn get_attribute<T>(f: &Filter, attr: &str) -> Result<T, String>
pub fn get_attribute<T>(attr: &str) -> Result<T, String>
where
T: Attribute,
{
match f.get_property(Path::from(attr).tokens()) {
None => Err(format!(
"#{} get_attribute: not found: {}",
f.context_id, attr
)),
Some(attribute_bytes) => T::parse(attribute_bytes),
match hostcalls::get_property(Path::from(attr).tokens()) {
Ok(Some(attribute_bytes)) => T::parse(attribute_bytes),
_ => Err(format!("get_attribute: not found: {}", attr)),
}
}
7 changes: 7 additions & 0 deletions src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ pub enum FailureMode {
Allow,
}

#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ExtensionType {
Auth,
RateLimit,
}

#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PluginConfiguration {
Expand Down
7 changes: 7 additions & 0 deletions src/envoy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ mod token_bucket;
mod value;

pub use {
address::{Address, SocketAddress},
attribute_context::{
AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer,
AttributeContext_Request,
},
base::Metadata,
external_auth::CheckRequest,
ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry},
rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code},
};
Expand Down
43 changes: 9 additions & 34 deletions src/filter/http_context.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,19 @@
use crate::configuration::{FailureMode, FilterConfig};
use crate::configuration::{ExtensionType, FailureMode, FilterConfig};
use crate::envoy::{RateLimitResponse, RateLimitResponse_Code};
use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate};
use crate::policy::Policy;
use crate::service::rate_limit::RateLimitService;
use crate::service::Service;
use crate::service::{GrpcServiceHandler, HeaderResolver};
use log::{debug, warn};
use protobuf::Message;
use proxy_wasm::traits::{Context, HttpContext};
use proxy_wasm::types::{Action, Bytes};
use proxy_wasm::types::Action;
use std::rc::Rc;

// tracing headers
#[derive(Clone)]
pub enum TracingHeader {
Traceparent,
Tracestate,
Baggage,
}

impl TracingHeader {
fn all() -> [Self; 3] {
[Traceparent, Tracestate, Baggage]
}

pub fn as_str(&self) -> &'static str {
match self {
Traceparent => "traceparent",
Tracestate => "tracestate",
Baggage => "baggage",
}
}
}

pub struct Filter {
pub context_id: u32,
pub config: Rc<FilterConfig>,
pub response_headers_to_add: Vec<(String, String)>,
pub tracing_headers: Vec<(TracingHeader, Bytes)>,
pub header_resolver: Rc<HeaderResolver>,
}

impl Filter {
Expand All @@ -63,7 +40,11 @@ impl Filter {
return Action::Continue;
}

let rls = RateLimitService::new(rlp.service.as_str(), self.tracing_headers.clone());
let rls = GrpcServiceHandler::new(
ExtensionType::RateLimit,
rlp.service.clone(),
Rc::clone(&self.header_resolver),
);
let message = RateLimitService::message(rlp.domain.clone(), descriptors);

match rls.send(message) {
Expand Down Expand Up @@ -98,12 +79,6 @@ impl HttpContext for Filter {
fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action {
debug!("#{} on_http_request_headers", self.context_id);

for header in TracingHeader::all() {
if let Some(value) = self.get_http_request_header_bytes(header.as_str()) {
self.tracing_headers.push((header, value))
}
}

match self
.config
.index
Expand Down
3 changes: 2 additions & 1 deletion src/filter/root_context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::configuration::{FilterConfig, PluginConfiguration};
use crate::filter::http_context::Filter;
use crate::service::HeaderResolver;
use const_format::formatcp;
use log::{debug, error, info};
use proxy_wasm::traits::{Context, HttpContext, RootContext};
Expand Down Expand Up @@ -40,7 +41,7 @@ impl RootContext for FilterRoot {
context_id,
config: Rc::clone(&self.config),
response_headers_to_add: Vec::default(),
tracing_headers: Vec::default(),
header_resolver: Rc::new(HeaderResolver::new()),
}))
}

Expand Down
122 changes: 119 additions & 3 deletions src/service.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,124 @@
pub(crate) mod auth;
pub(crate) mod rate_limit;

use crate::configuration::ExtensionType;
use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME};
use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME};
use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate};
use protobuf::Message;
use proxy_wasm::types::Status;
use proxy_wasm::hostcalls;
use proxy_wasm::hostcalls::dispatch_grpc_call;
use proxy_wasm::types::{Bytes, MapType, Status};
use std::cell::OnceCell;
use std::rc::Rc;
use std::time::Duration;

pub trait Service<M: Message> {
fn send(&self, message: M) -> Result<u32, Status>;
pub struct GrpcServiceHandler {
endpoint: String,
service_name: String,
method_name: String,
header_resolver: Rc<HeaderResolver>,
}

impl GrpcServiceHandler {
fn build(
endpoint: String,
service_name: &str,
method_name: &str,
header_resolver: Rc<HeaderResolver>,
) -> Self {
Self {
endpoint: endpoint.to_owned(),
service_name: service_name.to_owned(),
method_name: method_name.to_owned(),
header_resolver,
}
}

pub fn new(
extension_type: ExtensionType,
endpoint: String,
header_resolver: Rc<HeaderResolver>,
) -> Self {
match extension_type {
ExtensionType::Auth => Self::build(
endpoint,
AUTH_SERVICE_NAME,
AUTH_METHOD_NAME,
header_resolver,
),
ExtensionType::RateLimit => Self::build(
endpoint,
RATELIMIT_SERVICE_NAME,
RATELIMIT_METHOD_NAME,
header_resolver,
),
}
}

pub fn send<M: Message>(&self, message: M) -> Result<u32, Status> {
let msg = Message::write_to_bytes(&message).unwrap();
let metadata = self
.header_resolver
.get()
.iter()
.map(|(header, value)| (*header, value.as_slice()))
.collect();

dispatch_grpc_call(
self.endpoint.as_str(),
self.service_name.as_str(),
self.method_name.as_str(),
metadata,
Some(&msg),
Duration::from_secs(5),
)
}
}

pub struct HeaderResolver {
headers: OnceCell<Vec<(&'static str, Bytes)>>,
}

impl HeaderResolver {
pub fn new() -> Self {
Self {
headers: OnceCell::new(),
}
}

pub fn get(&self) -> &Vec<(&'static str, Bytes)> {
self.headers.get_or_init(|| {
let mut headers = Vec::new();
for header in TracingHeader::all() {
if let Ok(Some(value)) =
hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str())
{
headers.push(((*header).as_str(), value));
}
}
headers
})
}
}

// tracing headers
pub enum TracingHeader {
Traceparent,
Tracestate,
Baggage,
}

impl TracingHeader {
fn all() -> &'static [Self; 3] {
&[Traceparent, Tracestate, Baggage]
}

pub fn as_str(&self) -> &'static str {
match self {
Traceparent => "traceparent",
Tracestate => "tracestate",
Baggage => "baggage",
}
}
}
81 changes: 81 additions & 0 deletions src/service/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::attribute::get_attribute;
use crate::envoy::{
Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer,
AttributeContext_Request, CheckRequest, Metadata, SocketAddress,
};
use chrono::{DateTime, FixedOffset, Timelike};
use protobuf::well_known_types::Timestamp;
use proxy_wasm::hostcalls;
use proxy_wasm::types::MapType;
use std::collections::HashMap;

pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization";
pub const AUTH_METHOD_NAME: &str = "Check";

pub struct AuthService;

#[allow(dead_code)]
impl AuthService {
pub fn message(ce_host: String) -> CheckRequest {
AuthService::build_check_req(ce_host)
}

fn build_check_req(ce_host: String) -> CheckRequest {
let mut auth_req = CheckRequest::default();
let mut attr = AttributeContext::default();
attr.set_request(AuthService::build_request());
attr.set_destination(AuthService::build_peer(
get_attribute::<String>("destination.address").unwrap_or_default(),
get_attribute::<i64>("destination.port").unwrap_or_default() as u32,
));
attr.set_source(AuthService::build_peer(
get_attribute::<String>("source.address").unwrap_or_default(),
get_attribute::<i64>("source.port").unwrap_or_default() as u32,
));
// the ce_host is the identifier for authorino to determine which authconfig to use
let context_extensions = HashMap::from([("host".to_string(), ce_host)]);
attr.set_context_extensions(context_extensions);
attr.set_metadata_context(Metadata::default());
auth_req.set_attributes(attr);
auth_req
}

fn build_request() -> AttributeContext_Request {
let mut request = AttributeContext_Request::default();
let mut http = AttributeContext_HttpRequest::default();
let headers: HashMap<String, String> = hostcalls::get_map(MapType::HttpRequestHeaders)
.unwrap()
.into_iter()
.collect();

http.set_host(get_attribute::<String>("request.host").unwrap_or_default());
http.set_method(get_attribute::<String>("request.method").unwrap_or_default());
http.set_scheme(get_attribute::<String>("request.scheme").unwrap_or_default());
http.set_path(get_attribute::<String>("request.path").unwrap_or_default());
http.set_protocol(get_attribute::<String>("request.protocol").unwrap_or_default());

http.set_headers(headers);
request.set_time(get_attribute("request.time").map_or(
Timestamp::new(),
|date_time: DateTime<FixedOffset>| Timestamp {
nanos: date_time.nanosecond() as i32,
seconds: date_time.second() as i64,
unknown_fields: Default::default(),
cached_size: Default::default(),
},
));
request.set_http(http);
request
}

fn build_peer(host: String, port: u32) -> AttributeContext_Peer {
let mut peer = AttributeContext_Peer::default();
let mut address = Address::default();
let mut socket_address = SocketAddress::default();
socket_address.set_address(host);
socket_address.set_port_value(port);
address.set_socket_address(socket_address);
peer.set_address(address);
peer
}
}
Loading

0 comments on commit 870cafd

Please sign in to comment.