diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/WelcomeViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/WelcomeViewModelTest.kt index 88f6f0c9cb51..655c51f6dde6 100644 --- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/WelcomeViewModelTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/WelcomeViewModelTest.kt @@ -2,18 +2,30 @@ package net.mullvad.mullvadvpn.viewmodel import androidx.lifecycle.viewModelScope import app.cash.turbine.test +import io.mockk.Runs import io.mockk.coEvery +import io.mockk.coVerify import io.mockk.every +import io.mockk.just import io.mockk.mockk import io.mockk.mockkStatic import io.mockk.unmockkAll import kotlin.test.assertEquals import kotlin.test.assertIs import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.test.runTest +import net.mullvad.mullvadvpn.PaymentProvider +import net.mullvad.mullvadvpn.compose.state.PaymentState +import net.mullvad.mullvadvpn.compose.state.WelcomeDialogState import net.mullvad.mullvadvpn.compose.state.WelcomeUiState import net.mullvad.mullvadvpn.lib.common.test.TestCoroutineRule +import net.mullvad.mullvadvpn.lib.common.test.assertLists +import net.mullvad.mullvadvpn.lib.payment.PaymentRepository +import net.mullvad.mullvadvpn.lib.payment.model.PaymentAvailability +import net.mullvad.mullvadvpn.lib.payment.model.PaymentProduct +import net.mullvad.mullvadvpn.lib.payment.model.PurchaseResult import net.mullvad.mullvadvpn.model.AccountAndDevice import net.mullvad.mullvadvpn.model.AccountExpiry import net.mullvad.mullvadvpn.model.DeviceState @@ -41,6 +53,7 @@ class WelcomeViewModelTest { MutableStateFlow(ServiceConnectionState.Disconnected) private val deviceState = MutableStateFlow(DeviceState.Initial) private val accountExpiryState = MutableStateFlow(AccountExpiry.Missing) + private val purchaseResult = MutableSharedFlow(extraBufferCapacity = 1) // Service connections private val mockServiceConnectionContainer: ServiceConnectionContainer = mockk() @@ -51,6 +64,8 @@ class WelcomeViewModelTest { private val mockAccountRepository: AccountRepository = mockk() private val mockDeviceRepository: DeviceRepository = mockk() + private val mockPaymentRepository: PaymentRepository = mockk() + private val mockPaymentProvider: PaymentProvider = mockk() private val mockServiceConnectionManager: ServiceConnectionManager = mockk() private lateinit var viewModel: WelcomeViewModel @@ -69,11 +84,21 @@ class WelcomeViewModelTest { every { mockAccountRepository.accountExpiryState } returns accountExpiryState + coEvery { mockPaymentRepository.verifyPurchases() } just Runs + + coEvery { mockPaymentRepository.purchaseResult } returns purchaseResult + + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns + PaymentAvailability.ProductsUnavailable + + every { mockPaymentProvider.paymentRepository } returns mockPaymentRepository + viewModel = WelcomeViewModel( accountRepository = mockAccountRepository, deviceRepository = mockDeviceRepository, serviceConnectionManager = mockServiceConnectionManager, + paymentProvider = mockPaymentProvider, pollAccountExpiry = false ) } @@ -111,9 +136,9 @@ class WelcomeViewModelTest { // Act, Assert viewModel.uiState.test { assertEquals(WelcomeUiState(), awaitItem()) + eventNotifierTunnelUiState.notify(tunnelUiStateTestItem) serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - eventNotifierTunnelUiState.notify(tunnelUiStateTestItem) val result = awaitItem() assertEquals(tunnelUiStateTestItem, result.tunnelState) } @@ -128,8 +153,6 @@ class WelcomeViewModelTest { // Act, Assert viewModel.uiState.test { assertEquals(WelcomeUiState(), awaitItem()) - serviceConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) deviceState.value = DeviceState.LoggedIn( accountAndDevice = @@ -138,6 +161,8 @@ class WelcomeViewModelTest { device = mockk() ) ) + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) val result = awaitItem() assertEquals(expectedAccountNumber, result.accountNumber) } @@ -158,6 +183,151 @@ class WelcomeViewModelTest { } } + @Test + fun testVerifyPurchases() = runTest { + // Act + viewModel.verifyPurchases() + + // Assert + coVerify { mockPaymentRepository.verifyPurchases() } + } + + @Test + fun testBillingProductsUnavailableState() = runTest { + // Arrange + + // Act, Assert + viewModel.uiState.test { + // Default item + awaitItem() + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) + val result = awaitItem().billingPaymentState + assertIs(result) + } + } + + @Test + fun testBillingProductsGenericErrorState() = runTest { + // Arrange + val mockPaymentAvailability = PaymentAvailability.Error.Other(mockk()) + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns mockPaymentAvailability + + // Act, Assert + viewModel.uiState.test { + // Default item + assertIs(awaitItem().billingPaymentState) + viewModel.fetchPaymentAvailability() + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) + val result = awaitItem().billingPaymentState + assertIs(result) + } + } + + @Test + fun testBillingProductsBillingErrorState() = runTest { + // Arrange + val mockPaymentAvailability = PaymentAvailability.Error.BillingUnavailable + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns mockPaymentAvailability + + // Act, Assert + viewModel.uiState.test { + // Default item + assertIs(awaitItem().billingPaymentState) + viewModel.fetchPaymentAvailability() + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) + val result = awaitItem().billingPaymentState + assertIs(result) + } + } + + @Test + fun testBillingProductsPaymentAvailableState() = runTest { + // Arrange + val mockProduct: PaymentProduct = mockk() + val expectedProductList = listOf(mockProduct) + val mockPaymentAvailability = PaymentAvailability.ProductsAvailable(listOf(mockProduct)) + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns mockPaymentAvailability + + // Act, Assert + viewModel.uiState.test { + // Default item + assertIs(awaitItem().billingPaymentState) + viewModel.fetchPaymentAvailability() + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) + val result = awaitItem().billingPaymentState + assertIs(result) + assertLists(expectedProductList, result.products) + } + } + + @Test + fun testBillingVerificationError() = runTest { + // Arrange + val mockPaymentAvailability = PaymentAvailability.ProductsUnavailable + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns mockPaymentAvailability + + // Act, Assert + viewModel.uiState.test { + // Default item + assertIs(awaitItem().dialogState) + purchaseResult.tryEmit(PurchaseResult.VerificationError) + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) + val result = awaitItem().dialogState + assertIs(result) + } + } + + @Test + fun testBillingUserCancelled() = runTest { + // Arrange + val mockPaymentAvailability = PaymentAvailability.ProductsUnavailable + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns mockPaymentAvailability + + // Act, Assert + viewModel.uiState.test { + // Default item + assertIs(awaitItem().dialogState) + purchaseResult.tryEmit(PurchaseResult.PurchaseCancelled) + ensureAllEventsConsumed() + } + } + + @Test + fun testBillingPurchaseCompleted() = runTest { + // Arrange + val mockPaymentAvailability = PaymentAvailability.ProductsUnavailable + coEvery { mockPaymentRepository.queryPaymentAvailability() } returns mockPaymentAvailability + + // Act, Assert + viewModel.uiState.test { + // Default item + assertIs(awaitItem().dialogState) + purchaseResult.tryEmit(PurchaseResult.PurchaseCompleted) + serviceConnectionState.value = + ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) + val result = awaitItem().dialogState + assertIs(result) + } + } + + @Test + fun testStartBillingPayment() { + // Arrange + val mockProductId = "MOCK" + coEvery { mockPaymentRepository.purchaseBillingProduct(mockProductId) } just Runs + + // Act + viewModel.startBillingPayment(mockProductId) + + // Assert + coVerify { mockPaymentRepository.purchaseBillingProduct(mockProductId) } + } + companion object { private const val SERVICE_CONNECTION_MANAGER_EXTENSIONS = "net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManagerExtensionsKt"