From 7448ad267bcab06fed0bb3a278e5f91b9d885b81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Wed, 20 Nov 2024 08:35:38 +0100 Subject: [PATCH 1/7] Add comment about tokio runtime and getaddrinfo --- mullvad-jni/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 755cfce62341..1dd7c8694263 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -122,6 +122,9 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_shutdow if let Some(context) = DAEMON_CONTEXT.lock().unwrap().take() { _ = context.daemon_command_tx.shutdown(); _ = context.runtime.block_on(context.running_daemon); + + // Dropping the tokio runtime will block if there are any tasks in flight. + // That is, until all async tasks yield *and* all blocking threads have stopped. } } From 8ababf0f77b23f7245a1aed3d8c8c4a5e3c06192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20G=C3=B6ransson?= Date: Tue, 19 Nov 2024 11:04:30 +0100 Subject: [PATCH 2/7] Expose current dns servers --- .../mullvad/talpid/ConnectivityListener.kt | 22 +++++++++++++++++++ .../net/mullvad/talpid/TalpidVpnService.kt | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index edeec9a6fe9d..c2f0aef20f2d 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -3,9 +3,12 @@ package net.mullvad.talpid import android.content.Context import android.net.ConnectivityManager import android.net.ConnectivityManager.NetworkCallback +import android.net.LinkProperties import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkRequest +import co.touchlab.kermit.Logger +import java.net.InetAddress import kotlin.properties.Delegates.observable class ConnectivityListener { @@ -24,6 +27,14 @@ class ConnectivityListener { } } + private val defaultNetworkCallback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) { + super.onLinkPropertiesChanged(network, linkProperties) + currentDnsServers = ArrayList(linkProperties.dnsServers) + } + } + private lateinit var connectivityManager: ConnectivityManager // Used by JNI @@ -36,6 +47,12 @@ class ConnectivityListener { } } + var currentDnsServers: ArrayList = ArrayList() + private set(value) { + field = ArrayList(value.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }) + Logger.d("New currentDnsServers: $field") + } + var senderAddress = 0L fun register(context: Context) { @@ -49,10 +66,15 @@ class ConnectivityListener { context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager connectivityManager.registerNetworkCallback(request, callback) + currentDnsServers = + connectivityManager.getLinkProperties(connectivityManager.activeNetwork)?.dnsServers?.let { ArrayList(it) } + ?: ArrayList() + connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback) } fun unregister() { connectivityManager.unregisterNetworkCallback(callback) + connectivityManager.unregisterNetworkCallback(defaultNetworkCallback) } // DROID-1401 diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 9470c88318b6..61c0be2ccf65 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -161,7 +161,7 @@ open class TalpidVpnService : LifecycleVpnService() { private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean) companion object { - private const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1" + const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1" private const val IPV4_PREFIX_LENGTH = 32 private const val IPV6_PREFIX_LENGTH = 128 From f4db85b3a552f60d2454bfa69912c7ced51b41b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Tue, 19 Nov 2024 10:25:44 +0100 Subject: [PATCH 3/7] Add non-blocking DNS resolver for Android API requests --- Cargo.lock | 3 + Cargo.toml | 3 + mullvad-api/Cargo.toml | 1 + mullvad-api/src/bin/relay_list.rs | 6 +- mullvad-api/src/https_client_with_sni.rs | 36 +-- mullvad-api/src/lib.rs | 52 +++- mullvad-api/src/rest.rs | 3 + mullvad-daemon/Cargo.toml | 2 + mullvad-daemon/src/android_dns.rs | 49 ++++ mullvad-daemon/src/lib.rs | 23 ++ mullvad-encrypted-dns-proxy/Cargo.toml | 2 +- mullvad-problem-report/src/lib.rs | 3 +- mullvad-setup/src/main.rs | 4 +- talpid-core/Cargo.toml | 4 +- talpid-core/src/connectivity_listener.rs | 249 ++++++++++++++++++++ talpid-core/src/lib.rs | 4 + talpid-core/src/offline/android.rs | 211 +---------------- talpid-core/src/offline/mod.rs | 8 +- talpid-core/src/tunnel_state_machine/mod.rs | 13 +- test/Cargo.lock | 62 ++--- test/test-manager/src/tests/account.rs | 7 +- 21 files changed, 480 insertions(+), 265 deletions(-) create mode 100644 mullvad-daemon/src/android_dns.rs create mode 100644 talpid-core/src/connectivity_listener.rs diff --git a/Cargo.lock b/Cargo.lock index d17347167318..6d660b472340 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2318,6 +2318,7 @@ dependencies = [ name = "mullvad-api" version = "0.0.0" dependencies = [ + "async-trait", "cbindgen", "chrono", "futures", @@ -2373,6 +2374,7 @@ name = "mullvad-daemon" version = "0.0.0" dependencies = [ "android_logger", + "async-trait", "chrono", "clap", "ctrlc", @@ -2380,6 +2382,7 @@ dependencies = [ "either", "fern", "futures", + "hickory-resolver", "libc", "log", "log-panics", diff --git a/Cargo.toml b/Cargo.toml index da59a6a6c84c..8c2d22a043c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,9 @@ single_use_lifetimes = "warn" unused_async = "deny" [workspace.dependencies] +hickory-proto = "0.24.1" +hickory-resolver = "0.24.1" +hickory-server = { version = "0.24.1", features = ["resolver"] } tokio = { version = "1.8" } parity-tokio-ipc = "0.9" futures = "0.3.15" diff --git a/mullvad-api/Cargo.toml b/mullvad-api/Cargo.toml index e617d942b522..a822593600e0 100644 --- a/mullvad-api/Cargo.toml +++ b/mullvad-api/Cargo.toml @@ -15,6 +15,7 @@ workspace = true api-override = [] [dependencies] +async-trait = "0.1" libc = "0.2" chrono = { workspace = true } thiserror = { workspace = true } diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index def32303eaef..22190abd63db 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,13 +2,15 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; +use mullvad_api::{ + proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy, +}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) + let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver) .expect("Failed to load runtime"); let relay_list_request = diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 898927513f96..09e198ca3bcf 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -2,17 +2,14 @@ use crate::{ abortable_stream::{AbortableStream, AbortableStreamHandle}, proxy::{ApiConnection, ApiConnectionMode, ProxyConfig}, tls_stream::TlsStream, - AddressCache, + AddressCache, DnsResolver, }; use futures::{channel::mpsc, future, pin_mut, StreamExt}; #[cfg(target_os = "android")] use futures::{channel::oneshot, sink::SinkExt}; use http::uri::Scheme; use hyper::Uri; -use hyper_util::{ - client::legacy::connect::dns::{GaiResolver, Name}, - rt::TokioIo, -}; +use hyper_util::rt::TokioIo; use mullvad_encrypted_dns_proxy::{ config::ProxyConfig as EncryptedDNSConfig, Forwarder as EncryptedDNSForwarder, }; @@ -291,6 +288,7 @@ pub struct HttpsConnectorWithSni { sni_hostname: Option, address_cache: AddressCache, abort_notify: Arc, + dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, } @@ -307,6 +305,7 @@ impl HttpsConnectorWithSni { pub fn new( sni_hostname: Option, address_cache: AddressCache, + dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> (Self, HttpsConnectorWithSniHandle) { let (tx, mut rx) = mpsc::unbounded(); @@ -355,6 +354,7 @@ impl HttpsConnectorWithSni { sni_hostname, address_cache, abort_notify, + dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, }, @@ -388,7 +388,14 @@ impl HttpsConnectorWithSni { .map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))? } - async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result { + /// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used. + /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback. + /// If the URI contains a port, then that port will be used. + async fn resolve_address( + address_cache: AddressCache, + dns_resolver: &dyn DnsResolver, + uri: Uri, + ) -> io::Result { const DEFAULT_PORT: u16 = 443; let hostname = uri.host().ok_or_else(|| { @@ -408,19 +415,13 @@ impl HttpsConnectorWithSni { )); } - // Use getaddrinfo as a fallback + // Use DNS resolution as fallback // - let mut addrs = GaiResolver::new() - .call( - Name::from_str(hostname) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?, - ) - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let addrs = dns_resolver.resolve(hostname.to_owned()).await?; let addr = addrs - .next() + .first() .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; - Ok(SocketAddr::new(addr.ip(), port.unwrap_or(DEFAULT_PORT))) + Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT))) } } @@ -455,6 +456,7 @@ impl Service for HttpsConnectorWithSni { #[cfg(target_os = "android")] let socket_bypass_tx = self.socket_bypass_tx.clone(); let address_cache = self.address_cache.clone(); + let dns_resolver = self.dns_resolver.clone(); let fut = async move { if uri.scheme() != Some(&Scheme::HTTPS) { @@ -465,7 +467,7 @@ impl Service for HttpsConnectorWithSni { } let hostname = sni_hostname?; - let addr = Self::resolve_address(address_cache, uri).await?; + let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?; // Loop until we have established a connection. This starts over if a new endpoint // is selected while connecting. diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 6b3ac3c951b7..1f47d600b30f 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -1,4 +1,5 @@ #![allow(rustdoc::private_intra_doc_links)] +use async_trait::async_trait; #[cfg(target_os = "android")] use futures::channel::mpsc; #[cfg(target_os = "android")] @@ -12,10 +13,11 @@ use std::{ cell::Cell, collections::BTreeMap, future::Future, + io, net::{IpAddr, Ipv4Addr, SocketAddr}, ops::Deref, path::Path, - sync::OnceLock, + sync::{Arc, OnceLock}, }; use talpid_types::ErrorExt; @@ -304,11 +306,43 @@ impl ApiEndpoint { } } +#[async_trait] +pub trait DnsResolver: 'static + Send + Sync { + async fn resolve(&self, host: String) -> io::Result>; +} + +/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`). +pub struct DefaultDnsResolver; + +#[async_trait] +impl DnsResolver for DefaultDnsResolver { + async fn resolve(&self, host: String) -> io::Result> { + use std::net::ToSocketAddrs; + // Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which + // blocks and either has no timeout or a very long one. + let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs()) + .await + .expect("DNS task panicked")?; + Ok(addrs.map(|addr| addr.ip()).collect()) + } +} + +/// DNS resolver that always returns no results +pub struct NullDnsResolver; + +#[async_trait] +impl DnsResolver for NullDnsResolver { + async fn resolve(&self, _host: String) -> io::Result> { + Ok(vec![]) + } +} + /// A type that helps with the creation of API connections. pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, + dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, } @@ -323,13 +357,20 @@ pub enum Error { #[error("API availability check failed")] ApiCheckError(#[from] availability::Error), + + #[error("DNS resolution error")] + ResolutionFailed(#[from] std::io::Error), } impl Runtime { /// Create a new `Runtime`. - pub fn new(handle: tokio::runtime::Handle) -> Result { + pub fn new( + handle: tokio::runtime::Handle, + dns_resolver: impl DnsResolver, + ) -> Result { Self::new_inner( handle, + dns_resolver, #[cfg(target_os = "android")] None, ) @@ -346,12 +387,14 @@ impl Runtime { fn new_inner( handle: tokio::runtime::Handle, + dns_resolver: impl DnsResolver, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> Result { Ok(Runtime { handle, address_cache: AddressCache::new(None)?, api_availability: ApiAvailability::default(), + dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -360,15 +403,18 @@ impl Runtime { /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. pub async fn with_cache( + dns_resolver: impl DnsResolver, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> Result { let handle = tokio::runtime::Handle::current(); + #[cfg(feature = "api-override")] if API.disable_address_cache { return Self::new_inner( handle, + dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, ); @@ -402,6 +448,7 @@ impl Runtime { handle, address_cache, api_availability, + dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -419,6 +466,7 @@ impl Runtime { self.api_availability.clone(), self.address_cache.clone(), connection_mode_provider, + self.dns_resolver.clone(), #[cfg(target_os = "android")] socket_bypass_tx, ) diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index f6098c3b49c0..54a32f63f937 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -6,6 +6,7 @@ use crate::{ availability::ApiAvailability, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, proxy::ConnectionModeProvider, + DnsResolver, }; use futures::{ channel::{mpsc, oneshot}, @@ -154,11 +155,13 @@ impl RequestService { api_availability: ApiAvailability, address_cache: AddressCache, connection_mode_provider: T, + dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( sni_hostname, address_cache.clone(), + dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx.clone(), ); diff --git a/mullvad-daemon/Cargo.toml b/mullvad-daemon/Cargo.toml index 778ef02d7f9f..d0986fa70f71 100644 --- a/mullvad-daemon/Cargo.toml +++ b/mullvad-daemon/Cargo.toml @@ -51,6 +51,8 @@ tokio = { workspace = true, features = ["test-util"] } [target.'cfg(target_os="android")'.dependencies] android_logger = "0.8" +async-trait = "0.1" +hickory-resolver = { workspace = true } [target.'cfg(unix)'.dependencies] nix = "0.23" diff --git a/mullvad-daemon/src/android_dns.rs b/mullvad-daemon/src/android_dns.rs new file mode 100644 index 000000000000..ed44f5dc8c6d --- /dev/null +++ b/mullvad-daemon/src/android_dns.rs @@ -0,0 +1,49 @@ +#![cfg(target_os = "android")] +//! See [AndroidDnsResolver]. + +use async_trait::async_trait; +use hickory_resolver::{ + config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, + TokioAsyncResolver, +}; +use mullvad_api::DnsResolver; +use std::{io, net::IpAddr}; +use talpid_core::connectivity_listener::ConnectivityListener; + +/// A non-blocking DNS resolver. The default resolver uses `getaddrinfo`, which often prevents the +/// tokio runtime from being dropped, since it waits indefinitely on blocking threads. This is +/// particularly bad on Android, so we use a non-blocking resolver instead. +pub struct AndroidDnsResolver { + connectivity_listener: ConnectivityListener, +} + +impl AndroidDnsResolver { + pub fn new(connectivity_listener: ConnectivityListener) -> Self { + Self { + connectivity_listener, + } + } +} + +#[async_trait] +impl DnsResolver for AndroidDnsResolver { + async fn resolve(&self, host: String) -> io::Result> { + let ips = self + .connectivity_listener + .current_dns_servers() + .map_err(|err| { + io::Error::other(format!("Failed to retrieve current servers: {err}")) + })?; + let group = NameServerConfigGroup::from_ips_clear(&ips, 53, false); + + let config = ResolverConfig::from_parts(None, vec![], group); + let resolver = TokioAsyncResolver::tokio(config, ResolverOpts::default()); + + let lookup = resolver + .lookup_ip(host) + .await + .map_err(|err| io::Error::other(format!("lookup_ip failed: {err}")))?; + + Ok(lookup.into_iter().collect()) + } +} diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 8299e1d20c6d..0f2914fdff8a 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -3,6 +3,7 @@ mod access_method; pub mod account_history; +mod android_dns; mod api; mod api_address_updater; #[cfg(not(target_os = "android"))] @@ -38,6 +39,8 @@ use futures::{ }; use geoip::GeoIpHandler; use management_interface::ManagementInterfaceServer; +#[cfg(not(target_os = "android"))] +use mullvad_api::DefaultDnsResolver; use mullvad_relay_selector::{RelaySelector, SelectorConfig}; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; @@ -91,6 +94,9 @@ use talpid_types::{ }; use tokio::io; +#[cfg(target_os = "android")] +use talpid_core::connectivity_listener::ConnectivityListener; + /// Delay between generating a new WireGuard key and reconnecting const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60); @@ -604,8 +610,23 @@ impl Daemon { let (internal_event_tx, internal_event_rx) = daemon_command_channel.destructure(); + #[cfg(target_os = "android")] + let connectivity_listener = ConnectivityListener::new(android_context.clone()) + .inspect_err(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to start connectivity listener") + ); + }) + .map_err(|_| Error::DaemonUnavailable)?; + mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; let api_runtime = mullvad_api::Runtime::with_cache( + // FIXME: clone is bad (single sender) + #[cfg(target_os = "android")] + android_dns::AndroidDnsResolver::new(connectivity_listener.clone()), + #[cfg(not(target_os = "android"))] + DefaultDnsResolver, &cache_dir, true, #[cfg(target_os = "android")] @@ -777,6 +798,8 @@ impl Daemon { volume_update_rx, #[cfg(target_os = "android")] android_context, + #[cfg(target_os = "android")] + connectivity_listener, #[cfg(target_os = "linux")] tunnel_state_machine::LinuxNetworkingIdentifiers { fwmark: mullvad_types::TUNNEL_FWMARK, diff --git a/mullvad-encrypted-dns-proxy/Cargo.toml b/mullvad-encrypted-dns-proxy/Cargo.toml index e5ed53056cac..1326337a66a1 100644 --- a/mullvad-encrypted-dns-proxy/Cargo.toml +++ b/mullvad-encrypted-dns-proxy/Cargo.toml @@ -13,7 +13,7 @@ workspace = true [dependencies] tokio = { workspace = true, features = [ "macros" ] } log = { workspace = true } -hickory-resolver = { version = "0.24.1", features = [ "dns-over-https-rustls" ]} +hickory-resolver = { workspace = true, features = [ "dns-over-https-rustls" ]} serde = { workspace = true, features = ["derive"] } webpki-roots = "0.25.0" rustls = "0.21" diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 270de55f9589..91b790e5f6fe 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -1,4 +1,4 @@ -use mullvad_api::proxy::ApiConnectionMode; +use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver}; use regex::Regex; use std::{ borrow::Cow, @@ -292,6 +292,7 @@ async fn send_problem_report_inner( ) -> Result<(), Error> { let metadata = ProblemReport::parse_metadata(report_content).unwrap_or_else(metadata::collect); let api_runtime = mullvad_api::Runtime::with_cache( + NullDnsResolver, cache_dir, false, #[cfg(target_os = "android")] diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index d3dfd6de8ac4..4a444aa63cdc 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -1,7 +1,7 @@ use clap::Parser; use std::{path::PathBuf, process, str::FromStr, sync::LazyLock, time::Duration}; -use mullvad_api::{proxy::ApiConnectionMode, DEVICE_NOT_FOUND}; +use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver, DEVICE_NOT_FOUND}; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::version::ParsedAppVersion; use talpid_core::firewall::{self, Firewall}; @@ -152,7 +152,7 @@ async fn remove_device() -> Result<(), Error> { .await .map_err(Error::ReadDeviceCacheError)?; if let Some(device) = state.into_device() { - let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false) + let api_runtime = mullvad_api::Runtime::with_cache(NullDnsResolver, &cache_path, false) .await .map_err(Error::RpcInitializationError)?; diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 690d5797cc0f..7ec3fd20f288 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -50,8 +50,8 @@ duct = "0.13" pfctl = "0.6.1" subslice = "0.2" system-configuration = "0.5.1" -hickory-proto = "0.24.1" -hickory-server = { version = "0.24.1", features = ["resolver"] } +hickory-proto = { workspace = true } +hickory-server = { workspace = true, features = ["resolver"] } talpid-platform-metadata = { path = "../talpid-platform-metadata" } pcap = { version = "2.1", features = ["capture-stream"] } pnet_packet = "0.34" diff --git a/talpid-core/src/connectivity_listener.rs b/talpid-core/src/connectivity_listener.rs new file mode 100644 index 000000000000..b28f58cefbec --- /dev/null +++ b/talpid-core/src/connectivity_listener.rs @@ -0,0 +1,249 @@ +//! Rust wrapper around Android connectivity listener + +use futures::channel::mpsc::UnboundedSender; +use jnix::{ + jni::{ + self, + objects::{GlobalRef, JObject, JValue}, + signature::{JavaType, Primitive}, + sys::{jboolean, jlong, JNI_TRUE}, + JNIEnv, JavaVM, + }, + JnixEnv, + FromJava, +}; +use std::{net::IpAddr, sync::{Arc, Weak}}; +use talpid_types::{android::AndroidContext, net::Connectivity, ErrorExt}; + +/// Error related to Android connectivity monitor +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Failed to attach Java VM to tunnel thread + #[error("Failed to attach Java VM to tunnel thread")] + AttachJvmToThread(#[source] jni::errors::Error), + + /// Failed to call Java method + #[error("Failed to call Java method {0}.{1}")] + CallMethod(&'static str, &'static str, #[source] jni::errors::Error), + + /// Failed to create global reference to Java object + #[error("Failed to create global reference to Java object")] + CreateGlobalRef(#[source] jni::errors::Error), + + /// Failed to find method + #[error("Failed to find {0}.{1} method")] + FindMethod(&'static str, &'static str, #[source] jni::errors::Error), + + /// Method returned invalid result + #[error("Received an invalid result from {0}.{1}: {2}")] + InvalidMethodResult(&'static str, &'static str, String), +} + +/// Android connectivity listener +#[derive(Clone)] +pub struct ConnectivityListener { + jvm: Arc, + class: GlobalRef, + object: GlobalRef, + _sender: Option>>, +} + +impl ConnectivityListener { + /// Create a new connectivity listener + pub fn new(android_context: AndroidContext) -> Result { + let env = JnixEnv::from( + android_context + .jvm + .attach_current_thread_as_daemon() + .map_err(Error::AttachJvmToThread)?, + ); + + let get_connectivity_listener_method = env + .get_method_id( + &env.get_class("net/mullvad/talpid/TalpidVpnService"), + "getConnectivityListener", + "()Lnet/mullvad/talpid/ConnectivityListener;", + ) + .map_err(|cause| { + Error::FindMethod("MullvadVpnService", "getConnectivityListener", cause) + })?; + + let result = env + .call_method_unchecked( + android_context.vpn_service.as_obj(), + get_connectivity_listener_method, + JavaType::Object("Lnet/mullvad/talpid/ConnectivityListener;".to_owned()), + &[], + ) + .map_err(|cause| { + Error::CallMethod("MullvadVpnService", "getConnectivityListener", cause) + })?; + + let object = match result { + JValue::Object(object) => env.new_global_ref(object).map_err(Error::CreateGlobalRef)?, + value => { + return Err(Error::InvalidMethodResult( + "MullvadVpnService", + "getConnectivityListener", + format!("{:?}", value), + )) + } + }; + + let class = env.get_class("net/mullvad/talpid/ConnectivityListener"); + + Ok(ConnectivityListener { + jvm: android_context.jvm, + class, + object, + _sender: None, + }) + } + + /// Register a channel that receives changes about the offline state + pub fn set_connectivity_listener( + &mut self, + sender: UnboundedSender, + ) -> Result<(), Error> { + let sender = Arc::new(sender); + + let weak_sender = Arc::downgrade(&sender); + + let weak_sender_ptr = Box::new(weak_sender); + let weak_sender_address = Box::into_raw(weak_sender_ptr) as jlong; + + let result = self.call_method( + "setSenderAddress", + "(J)V", + &[JValue::Long(weak_sender_address)], + JavaType::Primitive(Primitive::Void), + )?; + + match result { + JValue::Void => Ok(()), + value => Err(Error::InvalidMethodResult( + "ConnectivityListener", + "setSenderAddress", + format!("{:?}", value), + )), + }?; + + self._sender = Some(sender); + + Ok(()) + } + + /// Return the current offline/connectivity state + pub fn connectivity(&self) -> Connectivity { + self.get_is_connected() + .map(|connected| Connectivity::Status { connected }) + .unwrap_or_else(|error| { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check connectivity status") + ); + Connectivity::PresumeOnline + }) + } + + fn get_is_connected(&self) -> Result { + let is_connected = self.call_method( + "isConnected", + "()Z", + &[], + JavaType::Primitive(Primitive::Boolean), + )?; + + match is_connected { + JValue::Bool(JNI_TRUE) => Ok(true), + JValue::Bool(_) => Ok(false), + value => Err(Error::InvalidMethodResult( + "ConnectivityListener", + "isConnected", + format!("{:?}", value), + )), + } + } + + /// Return the current DNS servers according to Android + pub fn current_dns_servers(&self) -> Result, Error> { + let env = JnixEnv::from( + self.jvm + .attach_current_thread_as_daemon() + .map_err(Error::AttachJvmToThread)?, + ); + + let current_dns_servers = self.call_method( + "getCurrentDnsServers", + "()Ljava/util/ArrayList;", + &[], + JavaType::Object("java/util/ArrayList".to_owned()), + )?; + + match current_dns_servers { + JValue::Object(jaddrs) => Ok(Vec::from_java(&env, jaddrs)), + value => Err(Error::InvalidMethodResult( + "ConnectivityListener", + "currentDnsServers", + format!("{:?}", value), + )), + } + } + + fn call_method( + &self, + method: &'static str, + signature: &str, + parameters: &[JValue<'_>], + return_type: JavaType, + ) -> Result, Error> { + let env = JnixEnv::from( + self.jvm + .attach_current_thread_as_daemon() + .map_err(Error::AttachJvmToThread)?, + ); + + let method_id = env + .get_method_id(&self.class, method, signature) + .map_err(|cause| Error::FindMethod("ConnectivityListener", method, cause))?; + + env.call_method_unchecked(self.object.as_obj(), method_id, return_type, parameters) + .map_err(|cause| Error::CallMethod("ConnectivityListener", method, cause)) + } +} + +/// Entry point for Android Java code to notify the connectivity status. +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnectivityChange( + _: JNIEnv<'_>, + _: JObject<'_>, + connected: jboolean, + sender_address: jlong, +) { + let connected = JNI_TRUE == connected; + let sender_ref = Box::leak(unsafe { get_sender_from_address(sender_address) }); + if let Some(sender) = sender_ref.upgrade() { + if sender + .unbounded_send(Connectivity::Status { connected }) + .is_err() + { + log::warn!("Failed to send offline change event"); + } + } +} + +/// Entry point for Android Java code to return ownership of the sender reference. +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_destroySender( + _: JNIEnv<'_>, + _: JObject<'_>, + sender_address: jlong, +) { + let _ = unsafe { get_sender_from_address(sender_address) }; +} + +unsafe fn get_sender_from_address(address: jlong) -> Box>> { + Box::from_raw(address as *mut Weak>) +} diff --git a/talpid-core/src/lib.rs b/talpid-core/src/lib.rs index bb81d7a48fa3..585c0274ee0a 100644 --- a/talpid-core/src/lib.rs +++ b/talpid-core/src/lib.rs @@ -42,3 +42,7 @@ mod linux; /// A resolver that's controlled by the tunnel state machine #[cfg(target_os = "macos")] pub(crate) mod resolver; + +/// Connectivity monitor for Android +#[cfg(target_os = "android")] +pub mod connectivity_listener; diff --git a/talpid-core/src/offline/android.rs b/talpid-core/src/offline/android.rs index 7dc8389ed39b..7280ee792f09 100644 --- a/talpid-core/src/offline/android.rs +++ b/talpid-core/src/offline/android.rs @@ -1,217 +1,32 @@ +use crate::connectivity_listener::{ConnectivityListener, Error}; use futures::channel::mpsc::UnboundedSender; -use jnix::{ - jni::{ - self, - objects::{GlobalRef, JObject, JValue}, - signature::{JavaType, Primitive}, - sys::{jboolean, jlong, JNI_TRUE}, - JNIEnv, JavaVM, - }, - JnixEnv, -}; -use std::sync::{Arc, Weak}; -use talpid_types::{android::AndroidContext, net::Connectivity, ErrorExt}; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Failed to attach Java VM to tunnel thread")] - AttachJvmToThread(#[source] jni::errors::Error), - - #[error("Failed to call Java method {0}.{1}")] - CallMethod(&'static str, &'static str, #[source] jni::errors::Error), - - #[error("Failed to create global reference to Java object")] - CreateGlobalRef(#[source] jni::errors::Error), - - #[error("Failed to find {0}.{1} method")] - FindMethod(&'static str, &'static str, #[source] jni::errors::Error), - - #[error("Received an invalid result from {0}.{1}: {2}")] - InvalidMethodResult(&'static str, &'static str, String), -} +use talpid_types::net::Connectivity; pub struct MonitorHandle { - jvm: Arc, - class: GlobalRef, - object: GlobalRef, - _sender: Arc>, + connectivity_listener: ConnectivityListener, } impl MonitorHandle { - pub fn new( - android_context: AndroidContext, - sender: Arc>, - ) -> Result { - let env = JnixEnv::from( - android_context - .jvm - .attach_current_thread_as_daemon() - .map_err(Error::AttachJvmToThread)?, - ); - - let get_connectivity_listener_method = env - .get_method_id( - &env.get_class("net/mullvad/talpid/TalpidVpnService"), - "getConnectivityListener", - "()Lnet/mullvad/talpid/ConnectivityListener;", - ) - .map_err(|cause| { - Error::FindMethod("MullvadVpnService", "getConnectivityListener", cause) - })?; - - let result = env - .call_method_unchecked( - android_context.vpn_service.as_obj(), - get_connectivity_listener_method, - JavaType::Object("Lnet/mullvad/talpid/ConnectivityListener;".to_owned()), - &[], - ) - .map_err(|cause| { - Error::CallMethod("MullvadVpnService", "getConnectivityListener", cause) - })?; - - let object = match result { - JValue::Object(object) => env.new_global_ref(object).map_err(Error::CreateGlobalRef)?, - value => { - return Err(Error::InvalidMethodResult( - "MullvadVpnService", - "getConnectivityListener", - format!("{:?}", value), - )) - } - }; - - let class = env.get_class("net/mullvad/talpid/ConnectivityListener"); - - Ok(MonitorHandle { - jvm: android_context.jvm, - class, - object, - _sender: sender, - }) + fn new(connectivity_listener: ConnectivityListener) -> Self { + MonitorHandle { + connectivity_listener, + } } #[allow(clippy::unused_async)] pub async fn connectivity(&self) -> Connectivity { - self.get_is_connected() - .map(|connected| Connectivity::Status { connected }) - .unwrap_or_else(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check connectivity status") - ); - Connectivity::PresumeOnline - }) - } - - fn get_is_connected(&self) -> Result { - let is_connected = self.call_method( - "isConnected", - "()Z", - &[], - JavaType::Primitive(Primitive::Boolean), - )?; - - match is_connected { - JValue::Bool(JNI_TRUE) => Ok(true), - JValue::Bool(_) => Ok(false), - value => Err(Error::InvalidMethodResult( - "ConnectivityListener", - "isConnected", - format!("{:?}", value), - )), - } - } - - fn set_sender(&self, sender: Weak>) -> Result<(), Error> { - let sender_ptr = Box::new(sender); - let sender_address = Box::into_raw(sender_ptr) as jlong; - - let result = self.call_method( - "setSenderAddress", - "(J)V", - &[JValue::Long(sender_address)], - JavaType::Primitive(Primitive::Void), - )?; - - match result { - JValue::Void => Ok(()), - value => Err(Error::InvalidMethodResult( - "ConnectivityListener", - "setSenderAddress", - format!("{:?}", value), - )), - } - } - - fn call_method( - &self, - method: &'static str, - signature: &str, - parameters: &[JValue<'_>], - return_type: JavaType, - ) -> Result, Error> { - let env = JnixEnv::from( - self.jvm - .attach_current_thread_as_daemon() - .map_err(Error::AttachJvmToThread)?, - ); - - let method_id = env - .get_method_id(&self.class, method, signature) - .map_err(|cause| Error::FindMethod("ConnectivityListener", method, cause))?; - - env.call_method_unchecked(self.object.as_obj(), method_id, return_type, parameters) - .map_err(|cause| Error::CallMethod("ConnectivityListener", method, cause)) + self.connectivity_listener.connectivity() } } -/// Entry point for Android Java code to notify the connectivity status. -#[no_mangle] -#[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnectivityChange( - _: JNIEnv<'_>, - _: JObject<'_>, - connected: jboolean, - sender_address: jlong, -) { - let connected = JNI_TRUE == connected; - let sender_ref = Box::leak(unsafe { get_sender_from_address(sender_address) }); - if let Some(sender) = sender_ref.upgrade() { - if sender - .unbounded_send(Connectivity::Status { connected }) - .is_err() - { - log::warn!("Failed to send offline change event"); - } - } -} - -/// Entry point for Android Java code to return ownership of the sender reference. -#[no_mangle] -#[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_destroySender( - _: JNIEnv<'_>, - _: JObject<'_>, - sender_address: jlong, -) { - let _ = unsafe { get_sender_from_address(sender_address) }; -} - -unsafe fn get_sender_from_address(address: jlong) -> Box>> { - Box::from_raw(address as *mut Weak>) -} - #[allow(clippy::unused_async)] pub async fn spawn_monitor( sender: UnboundedSender, - android_context: AndroidContext, + connectivity_listener: ConnectivityListener, ) -> Result { - let sender = Arc::new(sender); - let weak_sender = Arc::downgrade(&sender); - let monitor_handle = MonitorHandle::new(android_context, sender)?; - - monitor_handle.set_sender(weak_sender)?; - + let mut monitor_handle = MonitorHandle::new(connectivity_listener); + monitor_handle + .connectivity_listener + .set_connectivity_listener(sender)?; Ok(monitor_handle) } diff --git a/talpid-core/src/offline/mod.rs b/talpid-core/src/offline/mod.rs index 0e1d55c27375..6605bd3358fd 100644 --- a/talpid-core/src/offline/mod.rs +++ b/talpid-core/src/offline/mod.rs @@ -1,9 +1,9 @@ +#[cfg(target_os = "android")] +use crate::connectivity_listener::ConnectivityListener; use futures::channel::mpsc::UnboundedSender; use std::sync::LazyLock; #[cfg(not(target_os = "android"))] use talpid_routing::RouteManagerHandle; -#[cfg(target_os = "android")] -use talpid_types::android::AndroidContext; use talpid_types::{net::Connectivity, ErrorExt}; #[cfg(target_os = "macos")] @@ -44,7 +44,7 @@ pub async fn spawn_monitor( sender: UnboundedSender, #[cfg(not(target_os = "android"))] route_manager: RouteManagerHandle, #[cfg(target_os = "linux")] fwmark: Option, - #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(target_os = "android")] connectivity_listener: ConnectivityListener, ) -> MonitorHandle { let monitor = if *FORCE_DISABLE_OFFLINE_MONITOR { None @@ -56,7 +56,7 @@ pub async fn spawn_monitor( #[cfg(target_os = "linux")] fwmark, #[cfg(target_os = "android")] - android_context, + connectivity_listener, ) .await .inspect_err(|error| { diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 6a1d779be812..2541bc88e614 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -49,6 +49,9 @@ use talpid_types::{ tunnel::{ErrorStateCause, ParameterGenerationError, TunnelStateTransition}, }; +#[cfg(target_os = "android")] +use crate::connectivity_listener::ConnectivityListener; + const TUNNEL_STATE_MACHINE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); /// Errors that can happen when setting up or using the state machine. @@ -119,6 +122,7 @@ pub struct LinuxNetworkingIdentifiers { } /// Spawn the tunnel state machine thread, returning a channel for sending tunnel commands. +#[allow(clippy::too_many_arguments)] pub async fn spawn( initial_settings: InitialTunnelState, tunnel_parameters_generator: impl TunnelParametersGenerator, @@ -128,6 +132,7 @@ pub async fn spawn( offline_state_listener: mpsc::UnboundedSender, #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, #[cfg(target_os = "android")] android_context: AndroidContext, + #[cfg(target_os = "android")] connectivity_listener: ConnectivityListener, #[cfg(target_os = "linux")] linux_ids: LinuxNetworkingIdentifiers, ) -> Result { let (command_tx, command_rx) = mpsc::unbounded(); @@ -155,7 +160,7 @@ pub async fn spawn( #[cfg(target_os = "windows")] volume_update_rx, #[cfg(target_os = "android")] - android_context, + connectivity_listener, #[cfg(target_os = "linux")] linux_ids, }; @@ -251,7 +256,7 @@ struct TunnelStateMachineInitArgs { #[cfg(target_os = "windows")] volume_update_rx: mpsc::UnboundedReceiver<()>, #[cfg(target_os = "android")] - android_context: AndroidContext, + connectivity_listener: ConnectivityListener, #[cfg(target_os = "linux")] linux_ids: LinuxNetworkingIdentifiers, } @@ -263,7 +268,7 @@ impl TunnelStateMachine { #[cfg(target_os = "windows")] let volume_update_rx = args.volume_update_rx; #[cfg(target_os = "android")] - let android_context = args.android_context; + let connectivity_listener = args.connectivity_listener; let runtime = tokio::runtime::Handle::current(); @@ -339,7 +344,7 @@ impl TunnelStateMachine { #[cfg(target_os = "linux")] Some(args.linux_ids.fwmark), #[cfg(target_os = "android")] - android_context, + connectivity_listener, ) .await; let connectivity = offline_monitor.connectivity().await; diff --git a/test/Cargo.lock b/test/Cargo.lock index 24852c3356e4..941a0a08da5d 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -300,9 +300,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "blake3" @@ -1324,9 +1324,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.8.0" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" [[package]] name = "httpdate" @@ -1702,7 +1702,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] @@ -1866,6 +1866,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +dependencies = [ + "hermit-abi", + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.52.0", +] + [[package]] name = "mio-serial" version = "5.0.5" @@ -1873,7 +1885,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20a4c60ca5c9c0e114b3bd66ff4aa5f9b2b175442be51ca6c4365d687a97a2ac" dependencies = [ "log", - "mio", + "mio 0.8.11", "nix 0.26.4", "serialport", "winapi", @@ -1883,6 +1895,7 @@ dependencies = [ name = "mullvad-api" version = "0.0.0" dependencies = [ + "async-trait", "cbindgen", "chrono", "futures", @@ -2042,7 +2055,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", "cfg_aliases", "libc", @@ -2061,7 +2074,7 @@ version = "6.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "crossbeam-channel", "filetime", "fsevent-sys", @@ -2069,7 +2082,7 @@ dependencies = [ "kqueue", "libc", "log", - "mio", + "mio 0.8.11", "walkdir", "windows-sys 0.48.0", ] @@ -2109,16 +2122,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - [[package]] name = "object" version = "0.32.2" @@ -2883,7 +2886,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno 0.3.8", "libc", "linux-raw-sys", @@ -3079,7 +3082,7 @@ version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f5a15d0be940df84846264b09b51b10b931fb2f275becb80934e3568a016828" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cfg-if", "core-foundation-sys", "io-kit-sys", @@ -3624,28 +3627,27 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", "libc", - "mio", - "num_cpus", + "mio 1.0.2", "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2 0.5.6", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", @@ -4525,9 +4527,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" dependencies = [ "zeroize_derive", ] diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs index 7fe14ae58ee1..45151070a9d9 100644 --- a/test/test-manager/src/tests/account.rs +++ b/test/test-manager/src/tests/account.rs @@ -295,8 +295,11 @@ pub async fn new_device_client() -> anyhow::Result { ..api_endpoint }); - let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) - .expect("failed to create api runtime"); + let api = mullvad_api::Runtime::new( + tokio::runtime::Handle::current(), + mullvad_api::DefaultDnsResolver, + ) + .expect("failed to create api runtime"); let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()); Ok(DevicesProxy::new(rest_handle)) } From 168c9afb19e9bec61b40ecfb5ab12ed7983f35e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Tue, 19 Nov 2024 18:17:06 +0100 Subject: [PATCH 4/7] Simplify ConnectivityListener --- .../mullvad/talpid/ConnectivityListener.kt | 13 ++--- mullvad-daemon/src/lib.rs | 1 - talpid-core/src/connectivity_listener.rs | 47 ++++++++----------- 3 files changed, 25 insertions(+), 36 deletions(-) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index c2f0aef20f2d..f1fe3ca807b4 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -75,15 +75,12 @@ class ConnectivityListener { fun unregister() { connectivityManager.unregisterNetworkCallback(callback) connectivityManager.unregisterNetworkCallback(defaultNetworkCallback) - } - // DROID-1401 - // This function has never been used and should most likely be merged into unregister(), - // along with ensuring that the lifecycle of it is correct. - @Suppress("UnusedPrivateMember") - private fun finalize() { - destroySender(senderAddress) - senderAddress = 0L + if (senderAddress != 0L) { + var oldSender = senderAddress + senderAddress = 0L + destroySender(oldSender) + } } private external fun notifyConnectivityChange(isConnected: Boolean, senderAddress: Long) diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 0f2914fdff8a..0665c2736ebb 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -622,7 +622,6 @@ impl Daemon { mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; let api_runtime = mullvad_api::Runtime::with_cache( - // FIXME: clone is bad (single sender) #[cfg(target_os = "android")] android_dns::AndroidDnsResolver::new(connectivity_listener.clone()), #[cfg(not(target_os = "android"))] diff --git a/talpid-core/src/connectivity_listener.rs b/talpid-core/src/connectivity_listener.rs index b28f58cefbec..2d121d264bcf 100644 --- a/talpid-core/src/connectivity_listener.rs +++ b/talpid-core/src/connectivity_listener.rs @@ -9,10 +9,9 @@ use jnix::{ sys::{jboolean, jlong, JNI_TRUE}, JNIEnv, JavaVM, }, - JnixEnv, - FromJava, + FromJava, JnixEnv, }; -use std::{net::IpAddr, sync::{Arc, Weak}}; +use std::{net::IpAddr, sync::Arc}; use talpid_types::{android::AndroidContext, net::Connectivity, ErrorExt}; /// Error related to Android connectivity monitor @@ -45,7 +44,6 @@ pub struct ConnectivityListener { jvm: Arc, class: GlobalRef, object: GlobalRef, - _sender: Option>>, } impl ConnectivityListener { @@ -96,26 +94,24 @@ impl ConnectivityListener { jvm: android_context.jvm, class, object, - _sender: None, }) } /// Register a channel that receives changes about the offline state + /// + /// # Note + /// + /// The listener is shared by all instances of the struct. pub fn set_connectivity_listener( &mut self, sender: UnboundedSender, ) -> Result<(), Error> { - let sender = Arc::new(sender); - - let weak_sender = Arc::downgrade(&sender); - - let weak_sender_ptr = Box::new(weak_sender); - let weak_sender_address = Box::into_raw(weak_sender_ptr) as jlong; + let sender_ptr = Box::into_raw(Box::new(sender)) as jlong; let result = self.call_method( "setSenderAddress", "(J)V", - &[JValue::Long(weak_sender_address)], + &[JValue::Long(sender_ptr)], JavaType::Primitive(Primitive::Void), )?; @@ -128,8 +124,6 @@ impl ConnectivityListener { )), }?; - self._sender = Some(sender); - Ok(()) } @@ -222,15 +216,18 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnec sender_address: jlong, ) { let connected = JNI_TRUE == connected; - let sender_ref = Box::leak(unsafe { get_sender_from_address(sender_address) }); - if let Some(sender) = sender_ref.upgrade() { - if sender - .unbounded_send(Connectivity::Status { connected }) - .is_err() - { - log::warn!("Failed to send offline change event"); - } + + let sender = unsafe { Box::from_raw(sender_address as *mut UnboundedSender) }; + + if sender + .unbounded_send(Connectivity::Status { connected }) + .is_err() + { + log::warn!("Failed to send offline change event"); } + + // Do not destroy + std::mem::forget(sender); } /// Entry point for Android Java code to return ownership of the sender reference. @@ -241,9 +238,5 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_destroySende _: JObject<'_>, sender_address: jlong, ) { - let _ = unsafe { get_sender_from_address(sender_address) }; -} - -unsafe fn get_sender_from_address(address: jlong) -> Box>> { - Box::from_raw(address as *mut Weak>) + let _ = unsafe { Box::from_raw(sender_address as *mut UnboundedSender) }; } From 133845955492ecafb6447eaa9ceba34cb972f488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20G=C3=B6ransson?= Date: Wed, 20 Nov 2024 08:45:25 +0100 Subject: [PATCH 5/7] Refactor ConnectivityListener --- android/lib/talpid/build.gradle.kts | 1 + .../mullvad/talpid/ConnectivityListener.kt | 129 +++++++++-------- .../net/mullvad/talpid/TalpidVpnService.kt | 8 +- .../talpid/util/ConnectivityManagerUtil.kt | 132 ++++++++++++++++++ .../mullvadvpn/service/MullvadVpnService.kt | 2 +- 5 files changed, 209 insertions(+), 63 deletions(-) create mode 100644 android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt diff --git a/android/lib/talpid/build.gradle.kts b/android/lib/talpid/build.gradle.kts index a5cd613de189..c53c2add28dd 100644 --- a/android/lib/talpid/build.gradle.kts +++ b/android/lib/talpid/build.gradle.kts @@ -31,6 +31,7 @@ android { dependencies { implementation(projects.lib.model) + implementation(libs.androidx.ktx) implementation(libs.androidx.lifecycle.service) implementation(libs.kermit) implementation(libs.kotlin.stdlib) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index f1fe3ca807b4..a37cf18578df 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -1,86 +1,95 @@ package net.mullvad.talpid -import android.content.Context import android.net.ConnectivityManager -import android.net.ConnectivityManager.NetworkCallback import android.net.LinkProperties import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkRequest -import co.touchlab.kermit.Logger import java.net.InetAddress -import kotlin.properties.Delegates.observable - -class ConnectivityListener { - private val availableNetworks = HashSet() - - private val callback = - object : NetworkCallback() { - override fun onAvailable(network: Network) { - availableNetworks.add(network) - isConnected = true - } - - override fun onLost(network: Network) { - availableNetworks.remove(network) - isConnected = availableNetworks.isNotEmpty() - } - } - - private val defaultNetworkCallback = - object : NetworkCallback() { - override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) { - super.onLinkPropertiesChanged(network, linkProperties) - currentDnsServers = ArrayList(linkProperties.dnsServers) +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.distinctUntilChanged +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.scan +import kotlinx.coroutines.flow.stateIn +import net.mullvad.talpid.util.NetworkEvent +import net.mullvad.talpid.util.defaultNetworkFlow +import net.mullvad.talpid.util.networkFlow + +class ConnectivityListener(val connectivityManager: ConnectivityManager) { + // Used by JNI + var senderAddress = 0L + set(value) { + if (value == 0L) { + destroySender(field) } + field = value } - private lateinit var connectivityManager: ConnectivityManager + private lateinit var _isConnected: StateFlow + // Used by JNI + val isConnected + get() = _isConnected.value + private lateinit var _currentDnsServers: StateFlow> // Used by JNI - var isConnected by - observable(false) { _, oldValue, newValue -> - if (newValue != oldValue) { - if (senderAddress != 0L) { - notifyConnectivityChange(newValue, senderAddress) + val currentDnsServers + get() = ArrayList(_currentDnsServers.value) + + fun register(scope: CoroutineScope) { + _currentDnsServers = + dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers()) + + _isConnected = + hasInternetCapability() + .onEach { + if (senderAddress != 0L) { + notifyConnectivityChange(it, senderAddress) + } } - } - } + .stateIn(scope, SharingStarted.Eagerly, false) + } - var currentDnsServers: ArrayList = ArrayList() - private set(value) { - field = ArrayList(value.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }) - Logger.d("New currentDnsServers: $field") - } + fun unregister() { + senderAddress = 0L + } - var senderAddress = 0L + private fun dnsServerChanges(): Flow> = + connectivityManager + .defaultNetworkFlow() + .filterIsInstance() + .map { it.linkProperties.dnsServersWithoutFallback() } + + private fun currentDnsServers(): List = + connectivityManager + .getLinkProperties(connectivityManager.activeNetwork) + ?.dnsServersWithoutFallback() ?: emptyList() - fun register(context: Context) { + private fun LinkProperties.dnsServersWithoutFallback(): List = + dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } + + private fun hasInternetCapability(): Flow { val request = NetworkRequest.Builder() .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) .build() - connectivityManager = - context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager - - connectivityManager.registerNetworkCallback(request, callback) - currentDnsServers = - connectivityManager.getLinkProperties(connectivityManager.activeNetwork)?.dnsServers?.let { ArrayList(it) } - ?: ArrayList() - connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback) - } - - fun unregister() { - connectivityManager.unregisterNetworkCallback(callback) - connectivityManager.unregisterNetworkCallback(defaultNetworkCallback) - - if (senderAddress != 0L) { - var oldSender = senderAddress - senderAddress = 0L - destroySender(oldSender) - } + return connectivityManager + .networkFlow(request) + .scan(setOf()) { networks, event -> + when (event) { + is NetworkEvent.Available -> networks + event.network + is NetworkEvent.Lost -> networks - event.network + else -> networks + } + } + .map { it.isNotEmpty() } + .distinctUntilChanged() } private external fun notifyConnectivityChange(isConnected: Boolean, senderAddress: Long) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 61c0be2ccf65..dfd6699b1e33 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -1,7 +1,10 @@ package net.mullvad.talpid +import android.net.ConnectivityManager import android.os.ParcelFileDescriptor import androidx.annotation.CallSuper +import androidx.core.content.getSystemService +import androidx.lifecycle.lifecycleScope import co.touchlab.kermit.Logger import java.net.Inet4Address import java.net.Inet6Address @@ -29,12 +32,13 @@ open class TalpidVpnService : LifecycleVpnService() { private var currentTunConfig: TunConfig? = null // Used by JNI - val connectivityListener = ConnectivityListener() + lateinit var connectivityListener: ConnectivityListener @CallSuper override fun onCreate() { super.onCreate() - connectivityListener.register(this) + connectivityListener = ConnectivityListener(getSystemService()!!) + connectivityListener.register(lifecycleScope) } @CallSuper diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt new file mode 100644 index 000000000000..daf155c6e8ef --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -0,0 +1,132 @@ +package net.mullvad.talpid.util + +import android.net.ConnectivityManager +import android.net.ConnectivityManager.NetworkCallback +import android.net.LinkProperties +import android.net.Network +import android.net.NetworkCapabilities +import android.net.NetworkRequest +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.channels.trySendBlocking +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.callbackFlow + +fun ConnectivityManager.defaultNetworkFlow(): Flow = + callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged( + network: Network, + linkProperties: LinkProperties, + ) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } + + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } + + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } + + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } + + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } + + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } + + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) + } + } + registerDefaultNetworkCallback(callback) + + awaitClose { unregisterNetworkCallback(callback) } + } + +fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow = + callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged( + network: Network, + linkProperties: LinkProperties, + ) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } + + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } + + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } + + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } + + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } + + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } + + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) + } + } + registerNetworkCallback(networkRequest, callback) + + awaitClose { unregisterNetworkCallback(callback) } + } + +sealed interface NetworkEvent { + data class Available(val network: Network) : NetworkEvent + + data object Unavailable : NetworkEvent + + data class LinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) : + NetworkEvent + + data class CapabilitiesChanged( + val network: Network, + val networkCapabilities: NetworkCapabilities, + ) : NetworkEvent + + data class BlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent + + data class Losing(val network: Network, val maxMsToLive: Int) : NetworkEvent + + data class Lost(val network: Network) : NetworkEvent +} diff --git a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt index ebdcbec78019..55aa416e537b 100644 --- a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt +++ b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt @@ -203,6 +203,7 @@ class MullvadVpnService : TalpidVpnService() { } override fun onDestroy() { + super.onDestroy() Logger.i("MullvadVpnService: onDestroy") // Shutting down the daemon gracefully managementService.stop() @@ -214,7 +215,6 @@ class MullvadVpnService : TalpidVpnService() { managementService.enterIdle() Logger.i("Shutdown complete") - super.onDestroy() } // If an intent is from the system it is because of the OS starting/stopping the VPN. From e1f1f6f74dac0c36f211131d110da2c6892a14cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Wed, 20 Nov 2024 13:22:14 +0100 Subject: [PATCH 6/7] Update changelog --- android/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/android/CHANGELOG.md b/android/CHANGELOG.md index ac65d38672a7..d97bf45ed7c2 100644 --- a/android/CHANGELOG.md +++ b/android/CHANGELOG.md @@ -32,6 +32,7 @@ Line wrap the file at 100 chars. Th ### Fixed - Fix a bug where the Android account expiry notifications would not be updated if the app was running in the background for a long time. +- Fix ANR due to the tokio runtime being blocked by `getaddrinfo` when dropped. ## [android/2024.8] - 2024-11-01 From ad2fc60c64c78ad78bd9f995a9f5e72f978ac1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Thu, 21 Nov 2024 14:56:30 +0100 Subject: [PATCH 7/7] Make connectivity sender static --- .../mullvad/talpid/ConnectivityListener.kt | 23 +-- .../net/mullvad/talpid/TalpidVpnService.kt | 6 - talpid-core/src/connectivity_listener.rs | 137 +++++------------- talpid-core/src/offline/android.rs | 2 +- 4 files changed, 41 insertions(+), 127 deletions(-) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index a37cf18578df..4cb67f9945e6 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -21,15 +21,6 @@ import net.mullvad.talpid.util.defaultNetworkFlow import net.mullvad.talpid.util.networkFlow class ConnectivityListener(val connectivityManager: ConnectivityManager) { - // Used by JNI - var senderAddress = 0L - set(value) { - if (value == 0L) { - destroySender(field) - } - field = value - } - private lateinit var _isConnected: StateFlow // Used by JNI val isConnected @@ -46,18 +37,10 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { _isConnected = hasInternetCapability() - .onEach { - if (senderAddress != 0L) { - notifyConnectivityChange(it, senderAddress) - } - } + .onEach { notifyConnectivityChange(it) } .stateIn(scope, SharingStarted.Eagerly, false) } - fun unregister() { - senderAddress = 0L - } - private fun dnsServerChanges(): Flow> = connectivityManager .defaultNetworkFlow() @@ -92,7 +75,5 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .distinctUntilChanged() } - private external fun notifyConnectivityChange(isConnected: Boolean, senderAddress: Long) - - private external fun destroySender(senderAddress: Long) + private external fun notifyConnectivityChange(isConnected: Boolean) } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index dfd6699b1e33..dc1f8d23ca90 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -41,12 +41,6 @@ open class TalpidVpnService : LifecycleVpnService() { connectivityListener.register(lifecycleScope) } - @CallSuper - override fun onDestroy() { - super.onDestroy() - connectivityListener.unregister() - } - fun openTun(config: TunConfig): CreateTunResult { synchronized(this) { val tunStatus = activeTunStatus diff --git a/talpid-core/src/connectivity_listener.rs b/talpid-core/src/connectivity_listener.rs index 2d121d264bcf..9bdf4bf87ab7 100644 --- a/talpid-core/src/connectivity_listener.rs +++ b/talpid-core/src/connectivity_listener.rs @@ -5,13 +5,15 @@ use jnix::{ jni::{ self, objects::{GlobalRef, JObject, JValue}, - signature::{JavaType, Primitive}, - sys::{jboolean, jlong, JNI_TRUE}, + sys::{jboolean, JNI_TRUE}, JNIEnv, JavaVM, }, FromJava, JnixEnv, }; -use std::{net::IpAddr, sync::Arc}; +use std::{ + net::IpAddr, + sync::{Arc, Mutex}, +}; use talpid_types::{android::AndroidContext, net::Connectivity, ErrorExt}; /// Error related to Android connectivity monitor @@ -42,10 +44,11 @@ pub enum Error { #[derive(Clone)] pub struct ConnectivityListener { jvm: Arc, - class: GlobalRef, - object: GlobalRef, + android_listener: GlobalRef, } +static CONNECTIVITY_TX: Mutex>> = Mutex::new(None); + impl ConnectivityListener { /// Create a new connectivity listener pub fn new(android_context: AndroidContext) -> Result { @@ -56,28 +59,18 @@ impl ConnectivityListener { .map_err(Error::AttachJvmToThread)?, ); - let get_connectivity_listener_method = env - .get_method_id( - &env.get_class("net/mullvad/talpid/TalpidVpnService"), - "getConnectivityListener", - "()Lnet/mullvad/talpid/ConnectivityListener;", - ) - .map_err(|cause| { - Error::FindMethod("MullvadVpnService", "getConnectivityListener", cause) - })?; - let result = env - .call_method_unchecked( + .call_method( android_context.vpn_service.as_obj(), - get_connectivity_listener_method, - JavaType::Object("Lnet/mullvad/talpid/ConnectivityListener;".to_owned()), + "getConnectivityListener", + "()Lnet/mullvad/talpid/ConnectivityListener;", &[], ) .map_err(|cause| { Error::CallMethod("MullvadVpnService", "getConnectivityListener", cause) })?; - let object = match result { + let android_listener = match result { JValue::Object(object) => env.new_global_ref(object).map_err(Error::CreateGlobalRef)?, value => { return Err(Error::InvalidMethodResult( @@ -88,43 +81,19 @@ impl ConnectivityListener { } }; - let class = env.get_class("net/mullvad/talpid/ConnectivityListener"); - Ok(ConnectivityListener { jvm: android_context.jvm, - class, - object, + android_listener, }) } - /// Register a channel that receives changes about the offline state + /// Register a channel that receives changes about the offline state. /// /// # Note /// /// The listener is shared by all instances of the struct. - pub fn set_connectivity_listener( - &mut self, - sender: UnboundedSender, - ) -> Result<(), Error> { - let sender_ptr = Box::into_raw(Box::new(sender)) as jlong; - - let result = self.call_method( - "setSenderAddress", - "(J)V", - &[JValue::Long(sender_ptr)], - JavaType::Primitive(Primitive::Void), - )?; - - match result { - JValue::Void => Ok(()), - value => Err(Error::InvalidMethodResult( - "ConnectivityListener", - "setSenderAddress", - format!("{:?}", value), - )), - }?; - - Ok(()) + pub fn set_connectivity_listener(&mut self, sender: UnboundedSender) { + *CONNECTIVITY_TX.lock().unwrap() = Some(sender); } /// Return the current offline/connectivity state @@ -141,16 +110,18 @@ impl ConnectivityListener { } fn get_is_connected(&self) -> Result { - let is_connected = self.call_method( - "isConnected", - "()Z", - &[], - JavaType::Primitive(Primitive::Boolean), - )?; + let env = JnixEnv::from( + self.jvm + .attach_current_thread_as_daemon() + .map_err(Error::AttachJvmToThread)?, + ); + + let is_connected = + env.call_method(self.android_listener.as_obj(), "isConnected", "()Z", &[]); match is_connected { - JValue::Bool(JNI_TRUE) => Ok(true), - JValue::Bool(_) => Ok(false), + Ok(JValue::Bool(JNI_TRUE)) => Ok(true), + Ok(JValue::Bool(_)) => Ok(false), value => Err(Error::InvalidMethodResult( "ConnectivityListener", "isConnected", @@ -167,43 +138,22 @@ impl ConnectivityListener { .map_err(Error::AttachJvmToThread)?, ); - let current_dns_servers = self.call_method( + let current_dns_servers = env.call_method( + self.android_listener.as_obj(), "getCurrentDnsServers", "()Ljava/util/ArrayList;", &[], - JavaType::Object("java/util/ArrayList".to_owned()), - )?; + ); match current_dns_servers { - JValue::Object(jaddrs) => Ok(Vec::from_java(&env, jaddrs)), + Ok(JValue::Object(jaddrs)) => Ok(Vec::from_java(&env, jaddrs)), value => Err(Error::InvalidMethodResult( "ConnectivityListener", - "currentDnsServers", + "getCurrentDnsServers", format!("{:?}", value), )), } } - - fn call_method( - &self, - method: &'static str, - signature: &str, - parameters: &[JValue<'_>], - return_type: JavaType, - ) -> Result, Error> { - let env = JnixEnv::from( - self.jvm - .attach_current_thread_as_daemon() - .map_err(Error::AttachJvmToThread)?, - ); - - let method_id = env - .get_method_id(&self.class, method, signature) - .map_err(|cause| Error::FindMethod("ConnectivityListener", method, cause))?; - - env.call_method_unchecked(self.object.as_obj(), method_id, return_type, parameters) - .map_err(|cause| Error::CallMethod("ConnectivityListener", method, cause)) - } } /// Entry point for Android Java code to notify the connectivity status. @@ -213,30 +163,19 @@ pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyConnec _: JNIEnv<'_>, _: JObject<'_>, connected: jboolean, - sender_address: jlong, ) { - let connected = JNI_TRUE == connected; + let Some(tx) = &*CONNECTIVITY_TX.lock().unwrap() else { + // No sender has been registered + log::trace!("Received connectivity notification wíth no channel"); + return; + }; - let sender = unsafe { Box::from_raw(sender_address as *mut UnboundedSender) }; + let connected = JNI_TRUE == connected; - if sender + if tx .unbounded_send(Connectivity::Status { connected }) .is_err() { log::warn!("Failed to send offline change event"); } - - // Do not destroy - std::mem::forget(sender); -} - -/// Entry point for Android Java code to return ownership of the sender reference. -#[no_mangle] -#[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_destroySender( - _: JNIEnv<'_>, - _: JObject<'_>, - sender_address: jlong, -) { - let _ = unsafe { Box::from_raw(sender_address as *mut UnboundedSender) }; } diff --git a/talpid-core/src/offline/android.rs b/talpid-core/src/offline/android.rs index 7280ee792f09..4947c7f61ebd 100644 --- a/talpid-core/src/offline/android.rs +++ b/talpid-core/src/offline/android.rs @@ -27,6 +27,6 @@ pub async fn spawn_monitor( let mut monitor_handle = MonitorHandle::new(connectivity_listener); monitor_handle .connectivity_listener - .set_connectivity_listener(sender)?; + .set_connectivity_listener(sender); Ok(monitor_handle) }