Skip to content

Commit

Permalink
improve: don't use tokio TcpStream::connect
Browse files Browse the repository at this point in the history
instead use rama's tcp connector,
as to use our context and HickoryDNS

more in line with production-like code
  • Loading branch information
GlenDC committed Oct 31, 2024
1 parent 77ccf5e commit d7f5a3e
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 31 deletions.
8 changes: 4 additions & 4 deletions examples/http_connect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ use rama::{
net::{address::Domain, user::Basic},
rt::Executor,
service::service_fn,
tcp::{server::TcpListener, utils::is_connection_error},
tcp::{client::default_tcp_connect, server::TcpListener, utils::is_connection_error},
username::{
UsernameLabelParser, UsernameLabelState, UsernameLabels, UsernameOpaqueLabelParser,
},
Expand Down Expand Up @@ -191,9 +191,9 @@ where
.get::<RequestContext>()
.unwrap()
.authority
.to_string();
tracing::info!("CONNECT to {}", authority);
let mut stream = match tokio::net::TcpStream::connect(authority).await {
.clone();
tracing::info!("CONNECT to {authority}");
let (mut stream, _) = match default_tcp_connect(&ctx, authority).await {
Ok(stream) => stream,
Err(err) => {
tracing::error!(error = %err, "error connecting to host");
Expand Down
8 changes: 4 additions & 4 deletions examples/https_connect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use rama::{
net::user::Basic,
rt::Executor,
service::service_fn,
tcp::{server::TcpListener, utils::is_connection_error},
tcp::{client::default_tcp_connect, server::TcpListener, utils::is_connection_error},
tls::std::server::TlsAcceptorLayer,
Context, Layer, Service,
};
Expand Down Expand Up @@ -155,9 +155,9 @@ where
.get::<RequestContext>()
.unwrap()
.authority
.to_string();
tracing::info!("CONNECT to {}", authority);
let mut stream = match tokio::net::TcpStream::connect(authority).await {
.clone();
tracing::info!("CONNECT to {authority}");
let (mut stream, _) = match default_tcp_connect(&ctx, authority).await {
Ok(stream) => stream,
Err(err) => {
tracing::error!(error = %err, "error connecting to host");
Expand Down
8 changes: 4 additions & 4 deletions rama-cli/src/cmd/proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use rama::{
net::stream::layer::http::BodyLimitLayer,
rt::Executor,
service::service_fn,
tcp::{server::TcpListener, utils::is_connection_error},
tcp::{client::default_tcp_connect, server::TcpListener, utils::is_connection_error},
Context, Layer, Service,
};
use std::{convert::Infallible, time::Duration};
Expand Down Expand Up @@ -128,9 +128,9 @@ where
.get::<RequestContext>()
.unwrap()
.authority
.to_string();
tracing::info!("CONNECT to {}", authority);
let mut stream = match tokio::net::TcpStream::connect(authority).await {
.clone();
tracing::info!("CONNECT to {authority}");
let (mut stream, _) = match default_tcp_connect(&ctx, authority).await {
Ok(stream) => stream,
Err(err) => {
tracing::error!(error = %err, "error connecting to host");
Expand Down
1 change: 0 additions & 1 deletion rama-core/src/matcher/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ mod private {
}
}

#[cfg(test)]
#[cfg(test)]
mod test {
use super::*;
Expand Down
18 changes: 17 additions & 1 deletion rama-tcp/src/client/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rama_core::{
error::{BoxError, ErrorContext, OpaqueError},
Context,
};
use rama_dns::{DnsOverwrite, DnsResolver};
use rama_dns::{DnsOverwrite, DnsResolver, HickoryDns};
use rama_net::address::{Authority, Domain, Host};
use std::{
future::Future,
Expand Down Expand Up @@ -100,6 +100,22 @@ macro_rules! impl_stream_connector_either {

::rama_core::combinators::impl_either!(impl_stream_connector_either);

#[inline]
/// Establish a [`TcpStream`] connection for the given [`Authority`],
/// using the default settings and no custom state.
///
/// Use [`tcp_connect`] in case you want to customise any of these settings,
/// or use a [`rama_net::client::ConnectorService`] for even more advanced possibilities.
pub async fn default_tcp_connect<State>(
ctx: &Context<State>,
authority: Authority,
) -> Result<(TcpStream, SocketAddr), OpaqueError>
where
State: Clone + Send + Sync + 'static,
{
tcp_connect(ctx, authority, true, HickoryDns::default(), ()).await
}

/// Establish a [`TcpStream`] connection for the given [`Authority`].
pub async fn tcp_connect<State, Dns, Connector>(
ctx: &Context<State>,
Expand Down
2 changes: 1 addition & 1 deletion rama-tcp/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod service;

mod connect;
#[doc(inline)]
pub use connect::{tcp_connect, TcpStreamConnector};
pub use connect::{default_tcp_connect, tcp_connect, TcpStreamConnector};

#[cfg(feature = "http")]
mod request;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use super::utils;
use rama::{tcp::client::default_tcp_connect, Context};
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
#[ignore]
async fn test_tcp_listener_layers() {
utils::init_tracing();

let runner = utils::ExampleRunner::interactive("tcp_listener_layers", None);
let _runner = utils::ExampleRunner::<()>::interactive("tcp_listener_layers", None);

let mut stream = None;
let ctx = Context::default();
for i in 0..5 {
match runner.connect_tcp("127.0.0.1:62501").await {
Ok(s) => stream = Some(s),
match default_tcp_connect(&ctx, ([127, 0, 0, 1], 62501).into()).await {
Ok((s, _)) => stream = Some(s),
Err(e) => {
eprintln!("connect_tcp error: {}", e);
tokio::time::sleep(std::time::Duration::from_millis(500 + 250 * i)).await;
Expand Down
14 changes: 1 addition & 13 deletions tests/integration/examples/example_tests/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(dead_code)]

use rama::{
error::{BoxError, OpaqueError},
error::BoxError,
http::client::proxy::layer::SetProxyAuthHttpHeaderLayer,
http::service::client::{HttpClientExt, IntoUrl, RequestBuilder},
http::{
Expand All @@ -15,7 +15,6 @@ use rama::{
Request, Response,
},
layer::MapResultLayer,
net::stream::Stream,
service::BoxService,
utils::{backoff::ExponentialBackoff, rng::HasherRng},
Layer, Service,
Expand All @@ -25,7 +24,6 @@ use std::{
sync::Once,
time::Duration,
};
use tokio::net::ToSocketAddrs;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};

Expand Down Expand Up @@ -201,16 +199,6 @@ impl ExampleRunner<()> {
.await
.unwrap()
}

/// Establish an async R/W to the TCP server behind this [`ExampleRunner`].
pub(super) async fn connect_tcp(
&self,
addr: impl ToSocketAddrs,
) -> Result<impl Stream, OpaqueError> {
tokio::net::TcpStream::connect(addr)
.await
.map_err(OpaqueError::from_std)
}
}

impl<State> std::ops::Drop for ExampleRunner<State> {
Expand Down

0 comments on commit d7f5a3e

Please sign in to comment.