Skip to content

Commit

Permalink
Refactor ConnectivityListener
Browse files Browse the repository at this point in the history
  • Loading branch information
Rawa authored and dlon committed Nov 22, 2024
1 parent df75a17 commit 1507503
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 63 deletions.
1 change: 1 addition & 0 deletions android/lib/talpid/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ android {
dependencies {
implementation(projects.lib.model)

implementation(libs.androidx.ktx)
implementation(libs.androidx.lifecycle.service)
implementation(libs.kermit)
implementation(libs.kotlin.stdlib)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,86 +1,95 @@
package net.mullvad.talpid

import android.content.Context
import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import co.touchlab.kermit.Logger
import java.net.InetAddress
import kotlin.properties.Delegates.observable

class ConnectivityListener {
private val availableNetworks = HashSet<Network>()

private val callback =
object : NetworkCallback() {
override fun onAvailable(network: Network) {
availableNetworks.add(network)
isConnected = true
}

override fun onLost(network: Network) {
availableNetworks.remove(network)
isConnected = availableNetworks.isNotEmpty()
}
}

private val defaultNetworkCallback =
object : NetworkCallback() {
override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) {
super.onLinkPropertiesChanged(network, linkProperties)
currentDnsServers = ArrayList(linkProperties.dnsServers)
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.scan
import kotlinx.coroutines.flow.stateIn
import net.mullvad.talpid.util.NetworkEvent
import net.mullvad.talpid.util.defaultNetworkFlow
import net.mullvad.talpid.util.networkFlow

class ConnectivityListener(val connectivityManager: ConnectivityManager) {
// Used by JNI
var senderAddress = 0L
set(value) {
if (value == 0L) {
destroySender(field)
}
field = value
}

private lateinit var connectivityManager: ConnectivityManager
private lateinit var _isConnected: StateFlow<Boolean>
// Used by JNI
val isConnected
get() = _isConnected.value

private lateinit var _currentDnsServers: StateFlow<List<InetAddress>>
// Used by JNI
var isConnected by
observable(false) { _, oldValue, newValue ->
if (newValue != oldValue) {
if (senderAddress != 0L) {
notifyConnectivityChange(newValue, senderAddress)
val currentDnsServers
get() = ArrayList(_currentDnsServers.value)

fun register(scope: CoroutineScope) {
_currentDnsServers =
dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers())

_isConnected =
hasInternetCapability()
.onEach {
if (senderAddress != 0L) {
notifyConnectivityChange(it, senderAddress)
}
}
}
}
.stateIn(scope, SharingStarted.Eagerly, false)
}

var currentDnsServers: ArrayList<InetAddress> = ArrayList()
private set(value) {
field = ArrayList(value.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER })
Logger.d("New currentDnsServers: $field")
}
fun unregister() {
senderAddress = 0L
}

var senderAddress = 0L
private fun dnsServerChanges(): Flow<List<InetAddress>> =
connectivityManager
.defaultNetworkFlow()
.filterIsInstance<NetworkEvent.LinkPropertiesChanged>()
.map { it.linkProperties.dnsServersWithoutFallback() }

private fun currentDnsServers(): List<InetAddress> =
connectivityManager
.getLinkProperties(connectivityManager.activeNetwork)
?.dnsServersWithoutFallback() ?: emptyList()

fun register(context: Context) {
private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> =
dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }

private fun hasInternetCapability(): Flow<Boolean> {
val request =
NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
.build()

connectivityManager =
context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager

connectivityManager.registerNetworkCallback(request, callback)
currentDnsServers =
connectivityManager.getLinkProperties(connectivityManager.activeNetwork)?.dnsServers?.let { ArrayList(it) }
?: ArrayList()
connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback)
}

fun unregister() {
connectivityManager.unregisterNetworkCallback(callback)
connectivityManager.unregisterNetworkCallback(defaultNetworkCallback)

if (senderAddress != 0L) {
var oldSender = senderAddress
senderAddress = 0L
destroySender(oldSender)
}
return connectivityManager
.networkFlow(request)
.scan(setOf<Network>()) { networks, event ->
when (event) {
is NetworkEvent.Available -> networks + event.network
is NetworkEvent.Lost -> networks - event.network
else -> networks
}
}
.map { it.isNotEmpty() }
.distinctUntilChanged()
}

private external fun notifyConnectivityChange(isConnected: Boolean, senderAddress: Long)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package net.mullvad.talpid

import android.net.ConnectivityManager
import android.os.ParcelFileDescriptor
import androidx.annotation.CallSuper
import androidx.core.content.getSystemService
import androidx.lifecycle.lifecycleScope
import co.touchlab.kermit.Logger
import java.net.Inet4Address
import java.net.Inet6Address
Expand Down Expand Up @@ -29,12 +32,13 @@ open class TalpidVpnService : LifecycleVpnService() {
private var currentTunConfig: TunConfig? = null

// Used by JNI
val connectivityListener = ConnectivityListener()
lateinit var connectivityListener: ConnectivityListener

@CallSuper
override fun onCreate() {
super.onCreate()
connectivityListener.register(this)
connectivityListener = ConnectivityListener(getSystemService<ConnectivityManager>()!!)
connectivityListener.register(lifecycleScope)
}

@CallSuper
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package net.mullvad.talpid.util

import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.channels.trySendBlocking
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow

fun ConnectivityManager.defaultNetworkFlow(): Flow<NetworkEvent> =
callbackFlow<NetworkEvent> {
val callback =
object : NetworkCallback() {
override fun onLinkPropertiesChanged(
network: Network,
linkProperties: LinkProperties,
) {
super.onLinkPropertiesChanged(network, linkProperties)
trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
}

override fun onAvailable(network: Network) {
super.onAvailable(network)
trySendBlocking(NetworkEvent.Available(network))
}

override fun onCapabilitiesChanged(
network: Network,
networkCapabilities: NetworkCapabilities,
) {
super.onCapabilitiesChanged(network, networkCapabilities)
trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities))
}

override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
super.onBlockedStatusChanged(network, blocked)
trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked))
}

override fun onLosing(network: Network, maxMsToLive: Int) {
super.onLosing(network, maxMsToLive)
trySendBlocking(NetworkEvent.Losing(network, maxMsToLive))
}

override fun onLost(network: Network) {
super.onLost(network)
trySendBlocking(NetworkEvent.Lost(network))
}

override fun onUnavailable() {
super.onUnavailable()
trySendBlocking(NetworkEvent.Unavailable)
}
}
registerDefaultNetworkCallback(callback)

awaitClose { unregisterNetworkCallback(callback) }
}

fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow<NetworkEvent> =
callbackFlow<NetworkEvent> {
val callback =
object : NetworkCallback() {
override fun onLinkPropertiesChanged(
network: Network,
linkProperties: LinkProperties,
) {
super.onLinkPropertiesChanged(network, linkProperties)
trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties))
}

override fun onAvailable(network: Network) {
super.onAvailable(network)
trySendBlocking(NetworkEvent.Available(network))
}

override fun onCapabilitiesChanged(
network: Network,
networkCapabilities: NetworkCapabilities,
) {
super.onCapabilitiesChanged(network, networkCapabilities)
trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities))
}

override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
super.onBlockedStatusChanged(network, blocked)
trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked))
}

override fun onLosing(network: Network, maxMsToLive: Int) {
super.onLosing(network, maxMsToLive)
trySendBlocking(NetworkEvent.Losing(network, maxMsToLive))
}

override fun onLost(network: Network) {
super.onLost(network)
trySendBlocking(NetworkEvent.Lost(network))
}

override fun onUnavailable() {
super.onUnavailable()
trySendBlocking(NetworkEvent.Unavailable)
}
}
registerNetworkCallback(networkRequest, callback)

awaitClose { unregisterNetworkCallback(callback) }
}

sealed interface NetworkEvent {
data class Available(val network: Network) : NetworkEvent

data object Unavailable : NetworkEvent

data class LinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) :
NetworkEvent

data class CapabilitiesChanged(
val network: Network,
val networkCapabilities: NetworkCapabilities,
) : NetworkEvent

data class BlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent

data class Losing(val network: Network, val maxMsToLive: Int) : NetworkEvent

data class Lost(val network: Network) : NetworkEvent
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class MullvadVpnService : TalpidVpnService() {
}

override fun onDestroy() {
super.onDestroy()
Logger.i("MullvadVpnService: onDestroy")
// Shutting down the daemon gracefully
managementService.stop()
Expand All @@ -214,7 +215,6 @@ class MullvadVpnService : TalpidVpnService() {
managementService.enterIdle()

Logger.i("Shutdown complete")
super.onDestroy()
}

// If an intent is from the system it is because of the OS starting/stopping the VPN.
Expand Down

0 comments on commit 1507503

Please sign in to comment.