Skip to content

Commit

Permalink
Replace old client with hyper_util legacy client
Browse files Browse the repository at this point in the history
  • Loading branch information
Serock3 committed Oct 2, 2024
1 parent fdefcd0 commit 81336c8
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 63 deletions.
31 changes: 21 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mullvad-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ http = "1.1.0"
hyper = { version = "1.4.1", features = ["client", "http1"] }
hyper-util = { workspace = true}
http-body-util = "0.1.2"
tower = "0.5.1"
ipnetwork = { workspace = true }
log = { workspace = true }
serde = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion mullvad-api/src/abortable_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! immediately instead of after the socket times out.
use futures::{channel::oneshot, future::Fuse, FutureExt};
use hyper::client::connect::{Connected, Connection};
use hyper_util::client::legacy::connect::{Connected, Connection};
use std::{
future::Future,
io,
Expand Down
13 changes: 7 additions & 6 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
use hyper::{
client::connect::dns::{GaiResolver, Name},
service::Service,
Uri,
use hyper::Uri;
use hyper_util::{
client::legacy::connect::dns::{GaiResolver, Name},
rt::TokioIo,
};
use shadowsocks::{
config::ServerType,
Expand Down Expand Up @@ -39,6 +39,7 @@ use tokio::{
net::{TcpSocket, TcpStream},
time::timeout,
};
use tower::Service;

#[cfg(feature = "api-override")]
use crate::{proxy::ConnectionDecorator, API};
Expand Down Expand Up @@ -407,7 +408,7 @@ impl fmt::Debug for HttpsConnectorWithSni {
}

impl Service<Uri> for HttpsConnectorWithSni {
type Response = AbortableStream<ApiConnection>;
type Response = TokioIo<AbortableStream<ApiConnection>>;
type Error = io::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
Expand Down Expand Up @@ -472,7 +473,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
inner.stream_handles.push(socket_handle);
}

Ok(stream)
Ok(TokioIo::new(stream))
};

Box::pin(fut)
Expand Down
21 changes: 8 additions & 13 deletions mullvad-api/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use hyper::client::connect::Connected;
use hyper_util::client::legacy::connect::{Connected, Connection};
use serde::{Deserialize, Serialize};
use std::{
fmt, io,
Expand Down Expand Up @@ -192,7 +192,7 @@ impl ApiConnectionMode {
}
}

/// Implements `hyper::client::connect::Connection` by wrapping a type.
/// Implements `Connection` by wrapping a type.
pub struct ConnectionDecorator<T: AsyncRead + AsyncWrite>(pub T);

impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for ConnectionDecorator<T> {
Expand Down Expand Up @@ -223,26 +223,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for ConnectionDecorator<T> {
}
}

impl<T: AsyncRead + AsyncWrite> hyper::client::connect::Connection for ConnectionDecorator<T> {
impl<T: AsyncRead + AsyncWrite> Connection for ConnectionDecorator<T> {
fn connected(&self) -> Connected {
Connected::new()
}
}

trait Connection: AsyncRead + AsyncWrite + Unpin + hyper::client::connect::Connection + Send {}
trait ConnectionMullvad: AsyncRead + AsyncWrite + Unpin + Connection + Send {}

impl<T: AsyncRead + AsyncWrite + Unpin + hyper::client::connect::Connection + Send> Connection
for T
{
}
impl<T: AsyncRead + AsyncWrite + Unpin + Connection + Send> ConnectionMullvad for T {}

/// Stream that represents a Mullvad API connection
pub struct ApiConnection(Box<dyn Connection>);
pub struct ApiConnection(Box<dyn ConnectionMullvad>);

impl ApiConnection {
pub fn new<
T: AsyncRead + AsyncWrite + Unpin + hyper::client::connect::Connection + Send + 'static,
>(
pub fn new<T: AsyncRead + AsyncWrite + Unpin + Connection + Send + 'static>(
conn: Box<T>,
) -> Self {
Self(conn)
Expand Down Expand Up @@ -277,7 +272,7 @@ impl AsyncWrite for ApiConnection {
}
}

impl hyper::client::connect::Connection for ApiConnection {
impl Connection for ApiConnection {
fn connected(&self) -> Connected {
self.0.connected()
}
Expand Down
53 changes: 28 additions & 25 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use hyper::{
Method,
Uri,
};
use hyper_util::client::legacy::connect::Connect;
use mullvad_types::account::AccountNumber;
use std::{
borrow::Cow,
Expand Down Expand Up @@ -46,6 +47,9 @@ pub enum Error {
#[error("Request cancelled")]
Aborted,

#[error("Legacy hyper error")]
LegacyHyperError(#[from] Arc<hyper_util::client::legacy::Error>),

#[error("Hyper error")]
HyperError(#[from] Arc<hyper::Error>),

Expand Down Expand Up @@ -87,14 +91,15 @@ impl Error {
/// Return true if there was no route to the destination
pub fn is_offline(&self) -> bool {
match self {
Error::HyperError(error) if error.is_connect() => {
Error::LegacyHyperError(error) if error.is_connect() => {
if let Some(cause) = error.source() {
if let Some(err) = cause.downcast_ref::<std::io::Error>() {
return err.raw_os_error() == Some(libc::ENETUNREACH);
}
}
false
}
// TODO: Match on `Error::HyperError` too?
_ => false,
}
}
Expand Down Expand Up @@ -129,11 +134,7 @@ pub(crate) struct RequestService<T: ConnectionModeProvider> {
command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
command_rx: mpsc::UnboundedReceiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
// client: hyper_util::client::legacy::Client<
// HttpsConnectorWithSni,
// BoxBody<dyn hyper::body::Buf, Error>,
// >,
client: HttpsConnectorWithSni,
client: hyper_util::client::legacy::Client<HttpsConnectorWithSni, BoxBody<Bytes, Error>>,
connection_mode_provider: T,
connection_mode_generation: usize,
api_availability: ApiAvailability,
Expand All @@ -158,16 +159,17 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
connector_handle.set_connection_mode(connection_mode_provider.initial());

let (command_tx, command_rx) = mpsc::unbounded();
// let client =
// hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);
let client =
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.build(connector);

let command_tx = Arc::new(command_tx);

let service = Self {
command_tx: Arc::downgrade(&command_tx),
command_rx,
connector_handle,
client: connector,
client,
connection_mode_provider,
connection_mode_generation: 0,
api_availability,
Expand Down Expand Up @@ -299,24 +301,24 @@ pub struct Request<B> {
}

// TODO: merge with `RequestFactory::get`
/// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
/// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
pub fn get(uri: &str) -> Result<Request<Empty<Bytes>>> {
let uri = hyper::Uri::from_str(uri)?;

let mut builder = http::request::Builder::new()
.method(Method::GET)
.header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"));
if let Some(host) = uri.host() {
builder = builder.header(
header::HOST,
HeaderValue::from_str(host).map_err(|_e| Error::InvalidHeaderError)?,
);
};
let uri = hyper::Uri::from_str(uri)?;

let mut builder = http::request::Builder::new()
.method(Method::GET)
.header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"));
if let Some(host) = uri.host() {
builder = builder.header(
header::HOST,
HeaderValue::from_str(host).map_err(|_e| Error::InvalidHeaderError)?,
);
};

let request = builder.uri(uri).body(Empty::<Bytes>::new())?;
Ok(Request::new(request, None))
}
let request = builder.uri(uri).body(Empty::<Bytes>::new())?;
Ok(Request::new(request, None))
}

impl<B: Body> Request<B> {
fn new(request: hyper::Request<B>, access_token_store: Option<AccessTokenStore>) -> Self {
Expand Down Expand Up @@ -724,6 +726,7 @@ macro_rules! impl_into_arc_err {
}

impl_into_arc_err!(hyper::Error);
impl_into_arc_err!(hyper_util::client::legacy::Error);
impl_into_arc_err!(serde_json::Error);
impl_into_arc_err!(http::Error);
impl_into_arc_err!(http::uri::InvalidUri);
Loading

0 comments on commit 81336c8

Please sign in to comment.