Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove global API endpoint #7373

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import java.io.File
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride

const val PROBLEM_REPORT_LOGS_FILE = "problem_report.txt"

Expand Down Expand Up @@ -39,7 +40,10 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher
collectReport(logDirectory.absolutePath, logsPath.absolutePath)
}

suspend fun sendReport(userReport: UserReport): SendProblemReportResult {
suspend fun sendReport(
userReport: UserReport,
apiEndpointOverride: ApiEndpointOverride?,
): SendProblemReportResult {
// If report is not collected then, collect it, if it fails then return error
if (!logsExists() && !collectLogs()) {
return SendProblemReportResult.Error.CollectLog
Expand All @@ -52,6 +56,7 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher
userReport.description,
logsPath.absolutePath,
cacheDirectory.absolutePath,
apiEndpointOverride
)
}

Expand Down Expand Up @@ -89,5 +94,6 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher
userMessage: String,
reportPath: String,
cacheDirectory: String,
apiEndpointOverride: ApiEndpointOverride?,
): Boolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import net.mullvad.mullvadvpn.applist.ApplicationsProvider
import net.mullvad.mullvadvpn.compose.state.RelayListType
import net.mullvad.mullvadvpn.constant.IS_PLAY_BUILD
import net.mullvad.mullvadvpn.dataproxy.MullvadProblemReport
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.lib.payment.PaymentProvider
import net.mullvad.mullvadvpn.lib.shared.VoucherRepository
import net.mullvad.mullvadvpn.receiver.BootCompletedReceiver
Expand All @@ -29,6 +30,7 @@ import net.mullvad.mullvadvpn.repository.SettingsRepository
import net.mullvad.mullvadvpn.repository.SplashCompleteRepository
import net.mullvad.mullvadvpn.repository.SplitTunnelingRepository
import net.mullvad.mullvadvpn.repository.WireguardConstraintsRepository
import net.mullvad.mullvadvpn.service.DaemonConfig
import net.mullvad.mullvadvpn.ui.MainActivity
import net.mullvad.mullvadvpn.ui.serviceconnection.AppVersionInfoRepository
import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager
Expand Down Expand Up @@ -223,7 +225,7 @@ val uiModule = module {
viewModel { VoucherDialogViewModel(get()) }
viewModel { VpnSettingsViewModel(get(), get(), get(), get(), get()) }
viewModel { WelcomeViewModel(get(), get(), get(), get(), isPlayBuild = IS_PLAY_BUILD) }
viewModel { ReportProblemViewModel(get(), get()) }
viewModel { ReportProblemViewModel(get(), get(), get<DaemonConfig>().apiEndpointOverride) }
viewModel { ViewLogsViewModel(get()) }
viewModel { OutOfTimeViewModel(get(), get(), get(), get(), get(), isPlayBuild = IS_PLAY_BUILD) }
viewModel { PaymentViewModel(get()) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import net.mullvad.mullvadvpn.constant.MINIMUM_LOADING_TIME_MILLIS
import net.mullvad.mullvadvpn.dataproxy.MullvadProblemReport
import net.mullvad.mullvadvpn.dataproxy.SendProblemReportResult
import net.mullvad.mullvadvpn.dataproxy.UserReport
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.repository.ProblemReportRepository

data class ReportProblemUiState(
Expand All @@ -38,6 +39,7 @@ sealed interface ReportProblemSideEffect {
class ReportProblemViewModel(
private val mullvadProblemReporter: MullvadProblemReport,
private val problemReportRepository: ProblemReportRepository,
private val apiEndpointOverride: ApiEndpointOverride?,
) : ViewModel() {

private val sendingState: MutableStateFlow<SendingReportUiState?> = MutableStateFlow(null)
Expand Down Expand Up @@ -66,7 +68,10 @@ class ReportProblemViewModel(

// Ensure we show loading for at least MINIMUM_LOADING_TIME_MILLIS
val deferredResult = async {
mullvadProblemReporter.sendReport(UserReport(nullableEmail, description))
mullvadProblemReporter.sendReport(
UserReport(nullableEmail, description),
apiEndpointOverride,
)
}
delay(MINIMUM_LOADING_TIME_MILLIS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import kotlinx.parcelize.Parcelize
data class ApiEndpointOverride(
val hostname: String,
val port: Int = CUSTOM_ENDPOINT_HTTPS_PORT,
val disableAddressCache: Boolean = true,
val disableTls: Boolean = false,
val forceDirectConnection: Boolean = true,
) : Parcelable {
Expand Down
3 changes: 2 additions & 1 deletion ios/MullvadVPNUITests/MullvadApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class MullvadApi {
let result = mullvad_api_client_initialize(
&clientContext,
apiAddress,
hostname
hostname,
false
)
try ApiError(result).throwIfErr()
}
Expand Down
5 changes: 2 additions & 3 deletions mullvad-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tokio = { workspace = true, features = ["macros", "time", "rt-multi-thread", "ne
tokio-rustls = { version = "0.26.0", features = ["logging", "tls12", "ring"], default-features = false}
tokio-socks = "0.5.1"
rustls-pemfile = "2.1.3"
uuid = { version = "1.4.1", features = ["v4"] }

mullvad-encrypted-dns-proxy = { path = "../mullvad-encrypted-dns-proxy" }
mullvad-fs = { path = "../mullvad-fs" }
Expand All @@ -45,13 +46,11 @@ shadowsocks = { workspace = true, features = [ "stream-cipher" ] }
[dev-dependencies]
talpid-time = { path = "../talpid-time", features = ["test"] }
tokio = { workspace = true, features = ["test-util", "time"] }
mockito = "1.6.1"

[build-dependencies]
cbindgen = { version = "0.24.3", default-features = false }

[target.'cfg(target_os = "ios")'.dependencies]
uuid = { version = "1.4.1", features = ["v4"] }

[lib]
crate-type = [ "rlib", "staticlib" ]
bench = false
10 changes: 5 additions & 5 deletions mullvad-api/include/mullvad-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ typedef struct MullvadApiDevice {
* struct.
*
* * `api_address`: pointer to nul-terminated UTF-8 string containing a socket address
* representation
* ("143.32.4.32:9090"), the port is mandatory.
* representation ("143.32.4.32:9090"), the port is mandatory.
*
* * `hostname`: pointer to a null-terminated UTF-8 string representing the hostname that will be
* used for TLS validation.
*/
struct MullvadApiError mullvad_api_client_initialize(struct MullvadApiClient *client_ptr,
const char *api_address_ptr,
const char *hostname);
const char *hostname,
bool disable_tls);

/**
* Removes all devices from a given account
Expand Down Expand Up @@ -98,8 +98,8 @@ struct MullvadApiError mullvad_api_get_expiry(struct MullvadApiClient client_ptr
* * `account_str_ptr`: pointer to nul-terminated UTF-8 string containing the account number of the
* account that will have all of it's devices removed.
*
* * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function
* doesn't return an error, the pointer will be initialized with a valid instance of
* * `device_iter_ptr`: a pointer to a `device::MullvadApiDeviceIterator`. If this function doesn't
* return an error, the pointer will be initialized with a valid instance of
* `device::MullvadApiDeviceIterator`, which can be used to iterate through the devices.
*/
struct MullvadApiError mullvad_api_list_devices(struct MullvadApiClient client_ptr,
Expand Down
29 changes: 14 additions & 15 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.

use super::API;
use crate::DnsResolver;
use crate::{ApiEndpoint, DnsResolver};
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
Expand Down Expand Up @@ -38,42 +37,42 @@ impl DnsResolver for AddressCache {

#[derive(Clone)]
pub struct AddressCache {
hostname: String,
inner: Arc<Mutex<AddressCacheInner>>,
write_path: Option<Arc<Path>>,
}

impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
pub fn new(write_path: Option<Box<Path>>) -> Self {
Self::new_inner(API.address(), write_path)
}

pub fn with_static_addr(address: SocketAddr) -> Self {
Self::new_inner(address, None)
pub fn new(endpoint: &ApiEndpoint, write_path: Option<Box<Path>>) -> Self {
Self::new_inner(endpoint.address(), endpoint.host().to_owned(), write_path)
}

/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
pub async fn from_file(
read_path: &Path,
write_path: Option<Box<Path>>,
hostname: String,
) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
Ok(Self::new_inner(
read_address_file(read_path).await?,
write_path,
))
let address = read_address_file(read_path).await?;
Ok(Self::new_inner(address, hostname, write_path))
}

fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
fn new_inner(address: SocketAddr, hostname: String, write_path: Option<Box<Path>>) -> Self {
let cache = AddressCacheInner::from_address(address);
log::debug!("Using API address: {}", cache.address);

Self {
inner: Arc::new(Mutex::new(cache)),
write_path: write_path.map(Arc::from),
hostname,
}
}

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
if hostname.eq_ignore_ascii_case(&self.hostname) {
Some(self.get_address().await)
} else {
None
Expand Down
10 changes: 7 additions & 3 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.

use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, ApiEndpoint, RelayListProxy,
};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");
let runtime = mullvad_api::Runtime::new(
tokio::runtime::Handle::current(),
&ApiEndpoint::from_env_vars(),
);

let relay_list_request =
RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()))
Expand Down
8 changes: 8 additions & 0 deletions mullvad-api/src/ffi/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum MullvadApiErrorKind {

/// MullvadApiErrorKind contains a description and an error kind. If the error kind is
/// `MullvadApiErrorKind` is NoError, the pointer will be nil.
#[derive(Debug)]
#[repr(C)]
pub struct MullvadApiError {
description: *mut libc::c_char,
Expand Down Expand Up @@ -47,6 +48,13 @@ impl MullvadApiError {
}
}

pub fn unwrap(&self) {
if !matches!(self.kind, MullvadApiErrorKind::NoError) {
let desc = unsafe { std::ffi::CStr::from_ptr(self.description) };
panic!("API ERROR - {:?} - {}", self.kind, desc.to_str().unwrap());
}
}

pub fn drop(self) {
if self.description.is_null() {
return;
Expand Down
Loading
Loading