Skip to content

Commit

Permalink
Merge branch 'generalize-encrypted-dns-proxy-forwarder'
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Oct 21, 2024
2 parents c32e72a + b786cf4 commit 36a6113
Showing 1 changed file with 62 additions and 50 deletions.
112 changes: 62 additions & 50 deletions mullvad-encrypted-dns-proxy/src/forwarder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
//! Forward TCP traffic over various proxy configurations.
use std::{
io,
task::{ready, Poll},
};
use std::io;

use tokio::{
io::{AsyncRead, AsyncWrite},
Expand All @@ -16,17 +13,18 @@ use crate::config::Obfuscator;
///
/// Obtain [`ProxyConfig`](crate::config::ProxyConfig)s with
/// [resolve_configs](crate::config_resolver::resolve_configs).
pub struct Forwarder {
pub struct Forwarder<S> {
read_obfuscator: Option<Box<dyn Obfuscator>>,
write_obfuscator: Option<Box<dyn Obfuscator>>,
server_connection: TcpStream,
stream: S,
}

impl Forwarder {
/// Create a forwarder that will connect to a given proxy endpoint.
pub async fn connect(proxy_config: &crate::config::ProxyConfig) -> io::Result<Self> {
let server_connection = TcpStream::connect(proxy_config.addr).await?;

impl<S> Forwarder<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
/// Create a [`Forwarder`] with a connected `stream` to an encrypted DNS proxy server
pub fn from_stream(proxy_config: &crate::config::ProxyConfig, stream: S) -> io::Result<Self> {
let (read_obfuscator, write_obfuscator) =
if let Some(obfuscation_config) = &proxy_config.obfuscation {
(
Expand All @@ -40,14 +38,23 @@ impl Forwarder {
Ok(Self {
read_obfuscator,
write_obfuscator,
server_connection,
stream,
})
}
}

/// Forward TCP traffic over various proxy configurations.
impl Forwarder<TcpStream> {
/// Create a forwarder that will connect to a given proxy endpoint.
pub async fn connect(proxy_config: &crate::config::ProxyConfig) -> io::Result<Self> {
let server_connection = TcpStream::connect(proxy_config.addr).await?;
Self::from_stream(proxy_config, server_connection)
}

/// Forwards traffic from the client stream to the remote proxy, obfuscating and deobfuscating
/// it in the process.
pub async fn forward(self, client_stream: TcpStream) {
let (server_read, server_write) = self.server_connection.into_split();
let (server_read, server_write) = self.stream.into_split();
let (client_read, client_write) = client_stream.into_split();
let _ = tokio::join!(
forward(self.read_obfuscator, client_read, server_write),
Expand All @@ -56,13 +63,38 @@ impl Forwarder {
}
}

impl tokio::io::AsyncRead for Forwarder {
async fn forward(
mut obfuscator: Option<Box<dyn Obfuscator>>,
mut source: impl AsyncRead + Unpin,
mut sink: impl AsyncWrite + Unpin,
) -> io::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut buf = vec![0u8; 1024 * 64];
while let Ok(n_bytes_read) = AsyncReadExt::read(&mut source, &mut buf).await {
if n_bytes_read == 0 {
break;
}
let bytes_received = &mut buf[..n_bytes_read];

if let Some(obfuscator) = &mut obfuscator {
obfuscator.obfuscate(bytes_received);
}
sink.write_all(bytes_received).await?;
}
Ok(())
}

impl<S> tokio::io::AsyncRead for Forwarder<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let socket = std::pin::pin!(&mut self.server_connection);
use std::task::{ready, Poll};
let socket = std::pin::pin!(&mut self.stream);
match ready!(socket.poll_read(cx, buf)) {
// in this case, we can read and deobfuscate.
Ok(()) => {
Expand All @@ -76,61 +108,41 @@ impl tokio::io::AsyncRead for Forwarder {
}
}

impl tokio::io::AsyncWrite for Forwarder {
impl<S> tokio::io::AsyncWrite for Forwarder<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let socket = std::pin::pin!(&mut self.server_connection);
if let Err(err) = ready!(socket.poll_write_ready(cx)) {
return Poll::Ready(Err(err));
};

) -> std::task::Poll<Result<usize, io::Error>> {
let mut owned_buf = buf.to_vec();
if let Some(write_obfuscator) = &mut self.write_obfuscator {
write_obfuscator.obfuscate(&mut owned_buf);
}
let socket = std::pin::pin!(&mut self.server_connection);
socket.poll_write(cx, &owned_buf)
let stream = std::pin::pin!(&mut self.stream);
// If the object is not ready for writing, the method returns Poll::Pending
// and arranges for the current task (via cx.waker()) to receive a notification
// when the object becomes writable or is closed.
stream.poll_write(cx, &owned_buf)
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
std::pin::pin!(&mut self.server_connection).poll_flush(cx)
) -> std::task::Poll<Result<(), io::Error>> {
std::pin::pin!(&mut self.stream).poll_flush(cx)
}

fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
std::pin::pin!(&mut self.server_connection).poll_shutdown(cx)
) -> std::task::Poll<Result<(), io::Error>> {
std::pin::pin!(&mut self.stream).poll_shutdown(cx)
}
}

async fn forward(
mut obfuscator: Option<Box<dyn Obfuscator>>,
mut source: impl AsyncRead + Unpin,
mut sink: impl AsyncWrite + Unpin,
) -> io::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut buf = vec![0u8; 1024 * 64];
while let Ok(n_bytes_read) = AsyncReadExt::read(&mut source, &mut buf).await {
if n_bytes_read == 0 {
break;
}
let bytes_received = &mut buf[..n_bytes_read];

if let Some(obfuscator) = &mut obfuscator {
obfuscator.obfuscate(bytes_received);
}
sink.write_all(bytes_received).await?;
}
Ok(())
}

#[cfg(test)]
mod tests {
use std::net::{Ipv4Addr, SocketAddrV4};
Expand Down Expand Up @@ -169,7 +181,7 @@ mod tests {
let mut forwarder = Forwarder {
read_obfuscator: Some(obfuscation_config.create_obfuscator()),
write_obfuscator: Some(obfuscation_config.create_obfuscator()),
server_connection: client_conn,
stream: client_conn,
};
let mut buf = vec![0u8; 1024];
while let Ok(bytes_read) = forwarder.read(&mut buf).await {
Expand Down

0 comments on commit 36a6113

Please sign in to comment.