Skip to content

Commit

Permalink
Merge branch 'refactor-relay-settings-update' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 31, 2023
2 parents 9b33f44 + 3cedaf5 commit f073ef8
Show file tree
Hide file tree
Showing 44 changed files with 567 additions and 1,068 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,13 @@ class RelayListListener(
var selectedRelayItem: RelayItem? = null
private set

var selectedRelayLocation: GeographicLocationConstraint?
get() {
val settings = relaySettings as? RelaySettings.Normal
val location = settings?.relayConstraints?.location as? Constraint.Only

return location?.value?.toGeographicLocationConstraint()
}
set(value) {
connection.send(Request.SetRelayLocation(value).message)
}

var selectedWireguardConstraints: WireguardConstraints?
get() {
val settings = relaySettings as? RelaySettings.Normal
fun updateSelectedRelayLocation(value: GeographicLocationConstraint) {
connection.send(Request.SetRelayLocation(value).message)
}

return settings?.relayConstraints?.wireguardConstraints?.port?.let { port ->
WireguardConstraints(port)
}
}
set(value) {
connection.send(Request.SetWireguardConstraints(value).message)
}
fun updateSelectedWireguardConstraints(value: WireguardConstraints) {
connection.send(Request.SetWireguardConstraints(value).message)
}

var onRelayCountriesChange: ((List<RelayCountry>, RelayItem?) -> Unit)? = null
set(value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ class SelectLocationViewModel(private val serviceConnectionManager: ServiceConne
@Suppress("konsist.ensure public properties use permitted names")
val enterTransitionEndAction = _enterTransitionEndAction.asSharedFlow()

fun selectRelay(relayItem: RelayItem?) {
serviceConnectionManager.relayListListener()?.selectedRelayLocation = relayItem?.location
fun selectRelay(relayItem: RelayItem) {
serviceConnectionManager
.relayListListener()
?.updateSelectedRelayLocation(relayItem.location)
serviceConnectionManager.connectionProxy()?.connect()
viewModelScope.launch { _closeAction.emit(Unit) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,9 @@ class VpnSettingsViewModel(

fun onWireguardPortSelected(port: Constraint<Port>) {
viewModelScope.launch(dispatcher) {
serviceConnectionManager.relayListListener()?.selectedWireguardConstraints =
WireguardConstraints(port = port)
serviceConnectionManager
.relayListListener()
?.updateSelectedWireguardConstraints(WireguardConstraints(port = port))
}
hideDialog()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class SelectLocationViewModelTest {
assertEquals(Unit, awaitItem())
verify {
connectionProxyMock.connect()
mockRelayListListener.selectedRelayLocation = mockLocation
mockRelayListListener.updateSelectedRelayLocation(mockLocation)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ class VpnSettingsViewModelTest {
// Arrange
val wireguardPort: Constraint<Port> = Constraint.Only(Port(99))
val wireguardConstraints = WireguardConstraints(port = wireguardPort)
every { mockRelayListListener.selectedWireguardConstraints = any() } returns Unit
every {
mockRelayListListener.updateSelectedWireguardConstraints(wireguardConstraints)
} returns Unit

// Act
mockConnectionState.value =
Expand All @@ -168,7 +170,7 @@ class VpnSettingsViewModelTest {

// Assert
verify(exactly = 1) {
mockRelayListListener.selectedWireguardConstraints = wireguardConstraints
mockRelayListListener.updateSelectedWireguardConstraints(wireguardConstraints)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ sealed class Request : Message.RequestMessage() {
@Parcelize data class SetEnableSplitTunneling(val enable: Boolean) : Request()

@Parcelize
data class SetRelayLocation(val relayLocation: GeographicLocationConstraint?) : Request()
data class SetRelayLocation(val relayLocation: GeographicLocationConstraint) : Request()

@Parcelize data class SetWireGuardMtu(val mtu: Int?) : Request()

Expand All @@ -89,7 +89,7 @@ sealed class Request : Message.RequestMessage() {
@Parcelize data class SetObfuscationSettings(val settings: ObfuscationSettings?) : Request()

@Parcelize
data class SetWireguardConstraints(val wireguardConstraints: WireguardConstraints?) : Request()
data class SetWireguardConstraints(val wireguardConstraints: WireguardConstraints) : Request()

@Parcelize
data class SetWireGuardQuantumResistant(val quantumResistant: QuantumResistantState) :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import android.os.Parcelable
import kotlinx.parcelize.Parcelize

sealed class RelaySettings : Parcelable {
@Parcelize object CustomTunnelEndpoint : RelaySettings()
@Parcelize data object CustomTunnelEndpoint : RelaySettings()

@Parcelize class Normal(val relayConstraints: RelayConstraints) : RelaySettings()
@Parcelize data class Normal(val relayConstraints: RelayConstraints) : RelaySettings()
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import net.mullvad.mullvadvpn.model.PlayPurchaseInitResult
import net.mullvad.mullvadvpn.model.PlayPurchaseVerifyResult
import net.mullvad.mullvadvpn.model.QuantumResistantState
import net.mullvad.mullvadvpn.model.RelayList
import net.mullvad.mullvadvpn.model.RelaySettingsUpdate
import net.mullvad.mullvadvpn.model.RelaySettings
import net.mullvad.mullvadvpn.model.RemoveDeviceEvent
import net.mullvad.mullvadvpn.model.RemoveDeviceResult
import net.mullvad.mullvadvpn.model.Settings
Expand Down Expand Up @@ -182,8 +182,8 @@ class MullvadDaemon(
return verifyPlayPurchase(daemonInterfaceAddress, playPurchase)
}

fun updateRelaySettings(update: RelaySettingsUpdate) {
updateRelaySettings(daemonInterfaceAddress, update)
fun setRelaySettings(update: RelaySettings) {
setRelaySettings(daemonInterfaceAddress, update)
}

fun setObfuscationSettings(settings: ObfuscationSettings?) {
Expand Down Expand Up @@ -289,10 +289,7 @@ class MullvadDaemon(
playPurchase: PlayPurchase,
): PlayPurchaseVerifyResult

private external fun updateRelaySettings(
daemonInterfaceAddress: Long,
update: RelaySettingsUpdate
)
private external fun setRelaySettings(daemonInterfaceAddress: Long, update: RelaySettings)

private external fun setObfuscationSettings(
daemonInterfaceAddress: Long,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import net.mullvad.mullvadvpn.lib.ipc.Request
import net.mullvad.mullvadvpn.model.Constraint
import net.mullvad.mullvadvpn.model.GeographicLocationConstraint
import net.mullvad.mullvadvpn.model.LocationConstraint
import net.mullvad.mullvadvpn.model.RelayConstraintsUpdate
import net.mullvad.mullvadvpn.model.RelayConstraints
import net.mullvad.mullvadvpn.model.RelayList
import net.mullvad.mullvadvpn.model.RelaySettingsUpdate
import net.mullvad.mullvadvpn.model.RelaySettings
import net.mullvad.mullvadvpn.model.WireguardConstraints
import net.mullvad.mullvadvpn.service.MullvadDaemon

Expand Down Expand Up @@ -87,26 +87,40 @@ class RelayListListener(endpoint: ServiceEndpoint) {
}

private suspend fun updateRelayConstraints() {
val currentRelayConstraints = getCurrentRelayConstraints()
val location: Constraint<LocationConstraint> =
selectedRelayLocation?.let { location ->
Constraint.Only(LocationConstraint.Location(location))
}
?: Constraint.Any()
val wireguardConstraints: WireguardConstraints? = selectedWireguardConstraints
?: currentRelayConstraints.location
val wireguardConstraints: WireguardConstraints =
selectedWireguardConstraints ?: currentRelayConstraints.wireguardConstraints

val update =
RelaySettingsUpdate.Normal(
RelayConstraintsUpdate(
RelaySettings.Normal(
RelayConstraints(
location = location,
wireguardConstraints = wireguardConstraints,
ownership = Constraint.Any(),
providers = Constraint.Any()
)
)

daemon.await().updateRelaySettings(update)
daemon.await().setRelaySettings(update)
}

private suspend fun getCurrentRelayConstraints(): RelayConstraints =
when (val relaySettings = daemon.await().getSettings()?.relaySettings) {
is RelaySettings.Normal -> relaySettings.relayConstraints
else ->
RelayConstraints(
location = Constraint.Any(),
providers = Constraint.Any(),
ownership = Constraint.Any(),
wireguardConstraints = WireguardConstraints(Constraint.Any())
)
}

companion object {
private enum class Command {
SetRelayLocation,
Expand Down
80 changes: 32 additions & 48 deletions gui/src/main/daemon-rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
IRelayListCountry,
IRelayListHostname,
IRelayListWithEndpointData,
IRelaySettingsNormal,
ISettings,
ITunnelOptions,
ITunnelStateRelayInfo,
Expand All @@ -58,12 +59,12 @@ import {
RelayLocationGeographical,
RelayProtocol,
RelaySettings,
RelaySettingsUpdate,
TunnelParameterError,
TunnelProtocol,
TunnelState,
TunnelType,
VoucherResponse,
wrapConstraint,
} from '../shared/daemon-rpc-types';
import log from '../shared/logging';
import { ManagementServiceClient } from './management_interface/management_interface_grpc_pb';
Expand Down Expand Up @@ -297,52 +298,14 @@ export class DaemonRpc {
}

// TODO: Custom tunnel configurations are not supported by the GUI.
public async updateRelaySettings(relaySettings: RelaySettingsUpdate): Promise<void> {
public async setRelaySettings(relaySettings: RelaySettings): Promise<void> {
if ('normal' in relaySettings) {
const settingsUpdate = relaySettings.normal;
const grpcRelaySettings = new grpcTypes.RelaySettingsUpdate();
const normalSettings = relaySettings.normal;
const grpcRelaySettings = new grpcTypes.RelaySettings();
grpcRelaySettings.setNormal(convertToRelayConstraints(normalSettings));

const normalUpdate = new grpcTypes.NormalRelaySettingsUpdate();

if (settingsUpdate.tunnelProtocol) {
const tunnelTypeUpdate = new grpcTypes.TunnelTypeUpdate();
if (settingsUpdate.tunnelProtocol !== 'any') {
tunnelTypeUpdate.setTunnelType(convertToTunnelType(settingsUpdate.tunnelProtocol.only));
}
normalUpdate.setTunnelType(tunnelTypeUpdate);
}

if (settingsUpdate.location) {
normalUpdate.setLocation(convertToLocation(liftConstraint(settingsUpdate.location)));
}

if (settingsUpdate.wireguardConstraints) {
normalUpdate.setWireguardConstraints(
convertToWireguardConstraints(settingsUpdate.wireguardConstraints),
);
}

if (settingsUpdate.openvpnConstraints) {
normalUpdate.setOpenvpnConstraints(
convertToOpenVpnConstraints(settingsUpdate.openvpnConstraints),
);
}

if (settingsUpdate.providers) {
const providerUpdate = new grpcTypes.ProviderUpdate();
providerUpdate.setProvidersList(settingsUpdate.providers);
normalUpdate.setProviders(providerUpdate);
}

if (settingsUpdate.ownership !== undefined) {
const ownershipUpdate = new grpcTypes.OwnershipUpdate();
ownershipUpdate.setOwnership(convertToOwnership(settingsUpdate.ownership));
normalUpdate.setOwnership(ownershipUpdate);
}

grpcRelaySettings.setNormal(normalUpdate);
await this.call<grpcTypes.RelaySettingsUpdate, Empty>(
this.client.updateRelaySettings,
await this.call<grpcTypes.RelaySettings, Empty>(
this.client.setRelaySettings,
grpcRelaySettings,
);
}
Expand Down Expand Up @@ -1148,7 +1111,7 @@ function convertFromRelaySettings(
case grpcTypes.RelaySettings.EndpointCase.NORMAL: {
const normal = relaySettings.getNormal()!;
const locationConstraint = convertFromLocationConstraint(normal.getLocation());
const location = locationConstraint ? { only: locationConstraint } : 'any';
const location = wrapConstraint(locationConstraint);
// `getTunnelType()` is not falsy if type is 'any'
const tunnelProtocol = convertFromTunnelTypeConstraint(
normal.hasTunnelType() ? normal.getTunnelType() : undefined,
Expand Down Expand Up @@ -1184,7 +1147,7 @@ function convertFromBridgeSettings(bridgeSettings: grpcTypes.BridgeSettings): Br
const locationConstraint = convertFromLocationConstraint(
bridgeSettings.getNormal()?.getLocation(),
);
const location = locationConstraint ? { only: locationConstraint } : 'any';
const location = wrapConstraint(locationConstraint);
const providers = normalSettings.providersList;
const ownership = convertFromOwnership(normalSettings.ownership);
return {
Expand Down Expand Up @@ -1475,7 +1438,7 @@ function convertFromWireguardConstraints(
const entryLocation = constraints.getEntryLocation();
if (entryLocation) {
const location = convertFromLocationConstraint(entryLocation);
result.entryLocation = location ? { only: location } : 'any';
result.entryLocation = wrapConstraint(location);
}

return result;
Expand Down Expand Up @@ -1505,6 +1468,27 @@ function convertFromConstraint<T>(value: T | undefined): Constraint<T> {
}
}

function convertToRelayConstraints(
constraints: IRelaySettingsNormal<IOpenVpnConstraints, IWireguardConstraints>,
): grpcTypes.NormalRelaySettings {
const relayConstraints = new grpcTypes.NormalRelaySettings();

if (constraints.tunnelProtocol !== 'any') {
relayConstraints.setTunnelType(convertToTunnelType(constraints.tunnelProtocol.only));
}
relayConstraints.setLocation(convertToLocation(liftConstraint(constraints.location)));
relayConstraints.setWireguardConstraints(
convertToWireguardConstraints(constraints.wireguardConstraints),
);
relayConstraints.setOpenvpnConstraints(
convertToOpenVpnConstraints(constraints.openvpnConstraints),
);
relayConstraints.setProvidersList(constraints.providers);
relayConstraints.setOwnership(convertToOwnership(constraints.ownership));

return relayConstraints;
}

function convertToNormalBridgeSettings(
constraints: IBridgeConstraints,
): grpcTypes.BridgeSettings.BridgeConstraints {
Expand Down
4 changes: 2 additions & 2 deletions gui/src/main/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ export default class Settings implements Readonly<ISettings> {
IpcMainEventChannel.settings.handleSetWireguardQuantumResistant((quantumResistant?: boolean) =>
this.daemonRpc.setWireguardQuantumResistant(quantumResistant),
);
IpcMainEventChannel.settings.handleUpdateRelaySettings((update) =>
this.daemonRpc.updateRelaySettings(update),
IpcMainEventChannel.settings.handleSetRelaySettings((relaySettings) =>
this.daemonRpc.setRelaySettings(relaySettings),
);
IpcMainEventChannel.settings.handleUpdateBridgeSettings((bridgeSettings) => {
return this.daemonRpc.setBridgeSettings(bridgeSettings);
Expand Down
Loading

0 comments on commit f073ef8

Please sign in to comment.