Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for server certificate verification handler #106

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ Once the handler sends a request, these settings become immutable and cannot be
|MaxIdlePerHost|Gets or sets the maximum idle connection per host allowed in the pool. Default is usize::MAX (no limit).|
|Http2Only|Gets or sets a value that indicates whether to force the use of HTTP/2.|
|SkipCertificateVerification|Gets or sets a value that indicates whether to skip certificate verification.|
|OnVerifyServerCertificate|Gets or sets a custom handler that validates server certificates.|
|RootCertificates|Gets or sets a custom root CA. By default, the built-in root CA (Mozilla's root certificates) is used. See also https://github.com/rustls/webpki-roots. |
|ClientAuthCertificates|Gets or sets a custom client auth key.|
|ClientAuthKey|Gets or sets a custom client auth certificates.|
Expand Down Expand Up @@ -280,6 +281,22 @@ using var handler = new YetAnotherHttpHandler() { RootCertificates = rootCerts }
### Ignore certificate validation errors
We strongly not recommend this, but in some cases, you may want to skip certificate validation when connecting via HTTPS. In this scenario, you can ignore certificate errors by setting the `SkipCertificateVerification` property to `true`.

### Handling server certificate verification
You can customize the server certificate verification process by setting the `OnVerifyServerCertificate` property.

The callback should return `true` or `false` based on the verification result. If the property is set, the root CA verification is not performed.

```csharp
using var httpHandler = new YetAnotherHttpHandler()
{
OnVerifyServerCertificate = (serverName, certificate, now) =>
{
var cert = new X509Certificate2(certificate);
return serverName == "api.example.com" &&
cert.Subject == "CN=api.example.com";
}
};
```

## Development
### Build & Tests
Expand Down
13 changes: 12 additions & 1 deletion native/yaha_native/src/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ pub extern "C" fn yaha_client_config_skip_certificate_verification(
let ctx = YahaNativeContextInternal::from_raw_context(ctx);
ctx.skip_certificate_verification = Some(val);
}

#[no_mangle]
pub extern "C" fn yaha_client_config_set_server_certificate_verification_handler(
ctx: *mut YahaNativeContext,
handler: Option<extern "C" fn(state: NonZeroIsize, server_name: *const u8, server_name_len: usize, certificate_der: *const u8, certificate_der_len: usize, now: u64) -> bool>,
callback_state: NonZeroIsize
) {
let ctx = YahaNativeContextInternal::from_raw_context(ctx);
ctx.server_certificate_verification_handler = handler.map(|x| (x, callback_state));
}

#[no_mangle]
pub extern "C" fn yaha_client_config_pool_idle_timeout(
ctx: *mut YahaNativeContext,
Expand Down Expand Up @@ -327,7 +338,7 @@ pub extern "C" fn yaha_client_config_http2_initial_max_send_streams(
#[no_mangle]
pub extern "C" fn yaha_build_client(ctx: *mut YahaNativeContext) {
let ctx = YahaNativeContextInternal::from_raw_context(ctx);
ctx.build_client(ctx.skip_certificate_verification.unwrap_or_default());
ctx.build_client();
}

#[no_mangle]
Expand Down
77 changes: 70 additions & 7 deletions native/yaha_native/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ use hyper_tls::HttpsConnector;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_util::sync::CancellationToken;

use crate::primitives::{YahaHttpVersion, CompletionReason};
use crate::{primitives::{CompletionReason, YahaHttpVersion}};

type OnStatusCodeAndHeadersReceive =
extern "C" fn(req_seq: i32, state: NonZeroIsize, status_code: i32, version: YahaHttpVersion);
type OnReceive = extern "C" fn(req_seq: i32, state: NonZeroIsize, length: usize, buf: *const u8);
type OnComplete = extern "C" fn(req_seq: i32, state: NonZeroIsize, reason: CompletionReason, h2_error_code: u32);
type OnServerCertificateVerificationHandler = extern "C" fn(callback_state: NonZeroIsize, server_name: *const u8, server_name_len: usize, certificate_der: *const u8, certificate_der_len: usize, now: u64) -> bool;

pub struct YahaNativeRuntimeContext;
pub struct YahaNativeRuntimeContextInternal {
Expand All @@ -54,6 +55,7 @@ pub struct YahaNativeContextInternal<'a> {
pub runtime: tokio::runtime::Handle,
pub client_builder: Option<client::legacy::Builder>,
pub skip_certificate_verification: Option<bool>,
pub server_certificate_verification_handler: Option<(OnServerCertificateVerificationHandler, NonZeroIsize)>,
pub root_certificates: Option<rustls::RootCertStore>,
pub override_server_name: Option<String>,
pub connect_timeout: Option<Duration>,
Expand Down Expand Up @@ -81,6 +83,7 @@ impl YahaNativeContextInternal<'_> {
client: None,
client_builder: Some(Client::builder(TokioExecutor::new())),
skip_certificate_verification: None,
server_certificate_verification_handler: None,
root_certificates: None,
override_server_name: None,
connect_timeout: None,
Expand All @@ -92,24 +95,32 @@ impl YahaNativeContextInternal<'_> {
}
}

pub fn build_client(&mut self, skip_verify_certificates: bool) {
pub fn build_client(&mut self) {
let mut builder = self.client_builder.take().unwrap();
let https = self.new_connector(skip_verify_certificates);
let https = self.new_connector();
self.client = Some(builder.timer(TokioTimer::new()).build(https));
}

#[cfg(feature = "rustls")]
fn new_connector(&mut self, skip_verify_certificates: bool) -> HttpsConnector<HttpConnector> {
fn new_connector(&mut self) -> HttpsConnector<HttpConnector> {
let tls_config_builder = rustls::ClientConfig::builder();

// Configure certificate root store.
let tls_config: rustls::ClientConfig;
if skip_verify_certificates {
if let Some(server_certificate_verification_handler) = self.server_certificate_verification_handler {
// Use custom certificate verification handler
tls_config = tls_config_builder
.dangerous()
.with_custom_certificate_verifier(Arc::new(danger::NoCertificateVerification {}))
.with_custom_certificate_verifier(Arc::new(danger::CustomCerficateVerification { handler: server_certificate_verification_handler }))
.with_no_client_auth();
} else if self.skip_certificate_verification.unwrap_or_default() {
// Skip certificate verification
tls_config = tls_config_builder
.dangerous()
.with_custom_certificate_verifier(Arc::new(danger::NoCertificateVerification{}))
.with_no_client_auth();
} else {
// Configure to use built-in certification store and client authentication.
let tls_config_builder_root: rustls::ConfigBuilder<
rustls::ClientConfig,
rustls::client::WantsClientCert,
Expand Down Expand Up @@ -165,21 +176,30 @@ impl YahaNativeContextInternal<'_> {
}

#[cfg(feature = "native")]
fn new_connector(&mut self, skip_verify_certificates: bool) -> HttpsConnector<HttpConnector> {
fn new_connector(&mut self, server_certificate_verification_handler: Option<OnServerCertificateVerificationHandler>) -> HttpsConnector<HttpConnector> {
let https = HttpsConnector::new();
https
}
}

#[cfg(feature = "rustls")]
mod danger {
use std::num::NonZeroIsize;

use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified};
use rustls::{DigitallySignedStruct, Error, SignatureScheme};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};

use super::OnServerCertificateVerificationHandler;

#[derive(Debug)]
pub struct NoCertificateVerification {}

#[derive(Debug)]
pub struct CustomCerficateVerification {
pub handler: (OnServerCertificateVerificationHandler, NonZeroIsize)
}

const ALL_SCHEMES: [SignatureScheme; 12] = [
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
Expand All @@ -194,6 +214,49 @@ mod danger {
SignatureScheme::ED25519,
SignatureScheme::ED448];

impl rustls::client::danger::ServerCertVerifier for CustomCerficateVerification {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
_ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, Error> {
let server_name = server_name.to_str();
let server_name = server_name.as_bytes();
let cetificate_der = end_entity.as_ref();

if (self.handler.0)(self.handler.1, server_name.as_ptr(), server_name.len(), cetificate_der.as_ptr(), cetificate_der.len(), now.as_secs()) {
Ok(ServerCertVerified::assertion())
} else {
Err(Error::InvalidCertificate(rustls::CertificateError::ApplicationVerificationFailure))
}
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
Vec::from(ALL_SCHEMES)
}
}

impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
Expand Down
42 changes: 42 additions & 0 deletions src/YetAnotherHttpHandler/NativeHttpHandlerCore.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Net;
using System.Text;
using System.IO.Pipelines;
Expand All @@ -25,13 +26,15 @@ internal class NativeHttpHandlerCore : IDisposable

//private unsafe YahaNativeContext* _ctx;
private readonly YahaContextSafeHandle _handle;
private GCHandle? _onVerifyServerCertificateHandle; // The handle must be released in Dispose if it is allocated.
private bool _disposed = false;

// NOTE: We need to keep the callback delegates in advance.
// The delegates are kept on the Rust side, so it will crash if they are garbage collected.
private static readonly unsafe NativeMethods.yaha_init_context_on_status_code_and_headers_receive_delegate OnStatusCodeAndHeaderReceiveCallback = OnStatusCodeAndHeaderReceive;
private static readonly unsafe NativeMethods.yaha_init_context_on_receive_delegate OnReceiveCallback = OnReceive;
private static readonly unsafe NativeMethods.yaha_init_context_on_complete_delegate OnCompleteCallback = OnComplete;
private static readonly unsafe NativeMethods.yaha_client_config_set_server_certificate_verification_handler_handler_delegate OnServerCertificateVerificationCallback = OnServerCertificateVerification;

public unsafe NativeHttpHandlerCore(NativeClientSettings settings)
{
Expand Down Expand Up @@ -75,6 +78,16 @@ private unsafe void Initialize(YahaNativeContext* ctx, NativeClientSettings sett
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Info($"Option '{nameof(settings.SkipCertificateVerification)}' = {skipCertificateVerification}");
NativeMethods.yaha_client_config_skip_certificate_verification(ctx, skipCertificateVerification);
}
if (settings.OnVerifyServerCertificate is { } onVerifyServerCertificate)
{
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Info($"Option '{nameof(settings.OnVerifyServerCertificate)}' = {onVerifyServerCertificate}");

// NOTE: We need to keep the handle to call in the static callback method.
// The handle must be released in Dispose if it is allocated.
_onVerifyServerCertificateHandle = GCHandle.Alloc(onVerifyServerCertificate);

NativeMethods.yaha_client_config_set_server_certificate_verification_handler(ctx, OnServerCertificateVerificationCallback, GCHandle.ToIntPtr(_onVerifyServerCertificateHandle.Value));
}
if (settings.RootCertificates is { } rootCertificates)
{
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Info($"Option '{nameof(settings.RootCertificates)}' = Length:{rootCertificates.Length}");
Expand Down Expand Up @@ -395,6 +408,33 @@ private static unsafe void OnStatusCodeAndHeaderReceive(int reqSeq, IntPtr state
requestContext.Response.SetStatusCode(statusCode);
}

[MonoPInvokeCallback(typeof(NativeMethods.yaha_client_config_set_server_certificate_verification_handler_handler_delegate))]
private static unsafe bool OnServerCertificateVerification(IntPtr callbackState, byte* serverNamePtr, UIntPtr /*nuint*/ serverNameLength, byte* certificateDerPtr, UIntPtr /*nuint*/ certificateDerLength, ulong now)
{
var serverName = UnsafeUtilities.GetStringFromUtf8Bytes(new ReadOnlySpan<byte>(serverNamePtr, (int)serverNameLength));
var certificateDer = new ReadOnlySpan<byte>(certificateDerPtr, (int)certificateDerLength);
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Trace($"OnServerCertificateVerification: State=0x{callbackState:X}; ServerName={serverName}; CertificateDer.Length={certificateDer.Length}; Now={now}");

var onServerCertificateVerification = (ServerCertificateVerificationHandler)GCHandle.FromIntPtr(callbackState).Target;
Debug.Assert(onServerCertificateVerification != null);
if (onServerCertificateVerification == null)
{
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Warning($"OnServerVerification: The verification callback was called, but onServerCertificateVerification is null.");
return false;
}
try
{
var success = onServerCertificateVerification(serverName, certificateDer, DateTimeOffset.FromUnixTimeSeconds((long)now));
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Trace($"OnServerVerification: Success = {success}");
return success;
}
catch (Exception e)
{
if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Error($"OnServerVerification: The verification callback thrown an exception: {e.ToString()}");
return false;
}
}

[MonoPInvokeCallback(typeof(NativeMethods.yaha_init_context_on_receive_delegate))]
private static unsafe void OnReceive(int reqSeq, IntPtr state, UIntPtr length, byte* buf)
{
Expand Down Expand Up @@ -496,6 +536,8 @@ private void Dispose(bool disposing)

if (YahaEventSource.Log.IsEnabled()) YahaEventSource.Log.Info($"Dispose {nameof(NativeHttpHandlerCore)}; disposing={disposing}");

_onVerifyServerCertificateHandle?.Free();

NativeRuntime.Instance.Release(); // We always need to release runtime.

if (disposing)
Expand Down
6 changes: 6 additions & 0 deletions src/YetAnotherHttpHandler/NativeMethods.Uwp.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ internal static unsafe partial class NativeMethods
[DllImport(__DllName, EntryPoint = "yaha_client_config_skip_certificate_verification", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void yaha_client_config_skip_certificate_verification(YahaNativeContext* ctx, [MarshalAs(UnmanagedType.U1)] bool val);

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
public delegate bool yaha_client_config_set_server_certificate_verification_handler_handler_delegate(nint state, byte* server_name, nuint server_name_len, byte* certificate_der, nuint certificate_der_len, ulong now);

[DllImport(__DllName, EntryPoint = "yaha_client_config_set_server_certificate_verification_handler", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void yaha_client_config_set_server_certificate_verification_handler(YahaNativeContext* ctx, yaha_client_config_set_server_certificate_verification_handler_handler_delegate handler, nint callback_state);

[DllImport(__DllName, EntryPoint = "yaha_client_config_pool_idle_timeout", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void yaha_client_config_pool_idle_timeout(YahaNativeContext* ctx, ulong val_milliseconds);

Expand Down
6 changes: 6 additions & 0 deletions src/YetAnotherHttpHandler/NativeMethods.g.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ internal static unsafe partial class NativeMethods
[DllImport(__DllName, EntryPoint = "yaha_client_config_skip_certificate_verification", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void yaha_client_config_skip_certificate_verification(YahaNativeContext* ctx, [MarshalAs(UnmanagedType.U1)] bool val);

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
public delegate bool yaha_client_config_set_server_certificate_verification_handler_handler_delegate(nint state, byte* server_name, nuint server_name_len, byte* certificate_der, nuint certificate_der_len, ulong now);

[DllImport(__DllName, EntryPoint = "yaha_client_config_set_server_certificate_verification_handler", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void yaha_client_config_set_server_certificate_verification_handler(YahaNativeContext* ctx, yaha_client_config_set_server_certificate_verification_handler_handler_delegate handler, nint callback_state);

[DllImport(__DllName, EntryPoint = "yaha_client_config_pool_idle_timeout", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void yaha_client_config_pool_idle_timeout(YahaNativeContext* ctx, ulong val_milliseconds);

Expand Down
Loading
Loading