From 993c14dad004adc2b2b7dfd0fc567f34af8abb91 Mon Sep 17 00:00:00 2001
From: Jonatan Rhodin <jonatan.rhodin@mullvad.net>
Date: Wed, 22 Nov 2023 17:39:02 +0100
Subject: [PATCH 1/2] Add payment verification to Connect Screen

---
 .../src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt  | 2 +-
 .../net/mullvad/mullvadvpn/usecase/PaymentUseCase.kt       | 7 ++++---
 .../net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt   | 7 ++++++-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt
index 56df6699de97..b39d16b0aaa5 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt
@@ -124,7 +124,7 @@ val uiModule = module {
     viewModel {
         ChangelogViewModel(get(), BuildConfig.VERSION_CODE, BuildConfig.ALWAYS_SHOW_CHANGELOG)
     }
-    viewModel { ConnectViewModel(get(), get(), get(), get(), get(), get()) }
+    viewModel { ConnectViewModel(get(), get(), get(), get(), get(), get(), get()) }
     viewModel { DeviceListViewModel(get(), get()) }
     viewModel { DeviceRevokedViewModel(get(), get()) }
     viewModel { LoginViewModel(get(), get(), get()) }
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PaymentUseCase.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PaymentUseCase.kt
index 151e2caec7cd..bda53bcaf29e 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PaymentUseCase.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PaymentUseCase.kt
@@ -20,7 +20,7 @@ interface PaymentUseCase {
 
     suspend fun resetPurchaseResult()
 
-    suspend fun verifyPurchases()
+    suspend fun verifyPurchases(onSuccessfulVerification: () -> Unit = {})
 }
 
 class PlayPaymentUseCase(private val paymentRepository: PaymentRepository) : PaymentUseCase {
@@ -42,11 +42,12 @@ class PlayPaymentUseCase(private val paymentRepository: PaymentRepository) : Pay
         _purchaseResult.emit(null)
     }
 
-    override suspend fun verifyPurchases() {
+    override suspend fun verifyPurchases(onSuccessfulVerification: () -> Unit) {
         paymentRepository.verifyPurchases().collect {
             if (it == VerificationResult.Success) {
                 // Update the payment availability after a successful verification.
                 queryPaymentAvailability()
+                onSuccessfulVerification()
             }
         }
     }
@@ -68,7 +69,7 @@ class EmptyPaymentUseCase : PaymentUseCase {
         // No op
     }
 
-    override suspend fun verifyPurchases() {
+    override suspend fun verifyPurchases(onSuccessfulVerification: () -> Unit) {
         // No op
     }
 }
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt
index 89b83d9f111e..76c290f439c0 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt
@@ -33,6 +33,7 @@ import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState
 import net.mullvad.mullvadvpn.ui.serviceconnection.authTokenCache
 import net.mullvad.mullvadvpn.ui.serviceconnection.connectionProxy
 import net.mullvad.mullvadvpn.usecase.NewDeviceNotificationUseCase
+import net.mullvad.mullvadvpn.usecase.PaymentUseCase
 import net.mullvad.mullvadvpn.usecase.RelayListUseCase
 import net.mullvad.mullvadvpn.util.callbackFlowFromNotifier
 import net.mullvad.mullvadvpn.util.combine
@@ -49,7 +50,8 @@ class ConnectViewModel(
     private val deviceRepository: DeviceRepository,
     private val inAppNotificationController: InAppNotificationController,
     private val newDeviceNotificationUseCase: NewDeviceNotificationUseCase,
-    private val relayListUseCase: RelayListUseCase
+    private val relayListUseCase: RelayListUseCase,
+    private val paymentUseCase: PaymentUseCase
 ) : ViewModel() {
     private val _uiSideEffect = MutableSharedFlow<UiSideEffect>(extraBufferCapacity = 1)
     val uiSideEffect = _uiSideEffect.asSharedFlow()
@@ -137,6 +139,9 @@ class ConnectViewModel(
         // The create account cache is no longer needed as we have successfully reached the connect
         // screen
         accountRepository.clearCreatedAccountCache()
+        viewModelScope.launch {
+            paymentUseCase.verifyPurchases { accountRepository.fetchAccountExpiry() }
+        }
     }
 
     private fun LocationInfoCache.locationCallbackFlow() = callbackFlow {

From 022952de592b17f1a8afc852dc99e38ed446aebb Mon Sep 17 00:00:00 2001
From: Jonatan Rhodin <jonatan.rhodin@mullvad.net>
Date: Wed, 22 Nov 2023 17:42:44 +0100
Subject: [PATCH 2/2] Fix tests

---
 .../mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt   | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt
index f6157431db69..345a57df8097 100644
--- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt
+++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt
@@ -39,6 +39,7 @@ import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager
 import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState
 import net.mullvad.mullvadvpn.ui.serviceconnection.authTokenCache
 import net.mullvad.mullvadvpn.ui.serviceconnection.connectionProxy
+import net.mullvad.mullvadvpn.usecase.PaymentUseCase
 import net.mullvad.mullvadvpn.usecase.RelayListUseCase
 import net.mullvad.mullvadvpn.util.appVersionCallbackFlow
 import net.mullvad.talpid.tunnel.ErrorState
@@ -89,6 +90,9 @@ class ConnectViewModelTest {
     // Relay list use case
     private val mockRelayListUseCase: RelayListUseCase = mockk()
 
+    // Payment use case
+    private val mockPaymentUseCase: PaymentUseCase = mockk(relaxed = true)
+
     // Captures
     private val locationSlot = slot<((GeoIpLocation?) -> Unit)>()
 
@@ -139,7 +143,8 @@ class ConnectViewModelTest {
                 deviceRepository = mockDeviceRepository,
                 inAppNotificationController = mockInAppNotificationController,
                 relayListUseCase = mockRelayListUseCase,
-                newDeviceNotificationUseCase = mockk()
+                newDeviceNotificationUseCase = mockk(),
+                paymentUseCase = mockPaymentUseCase
             )
     }