diff --git a/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/confirmaccountprovider/ConfirmAccountProviderPresenter.kt b/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/confirmaccountprovider/ConfirmAccountProviderPresenter.kt index c15c84744f..9b27c84e4c 100644 --- a/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/confirmaccountprovider/ConfirmAccountProviderPresenter.kt +++ b/features/login/impl/src/main/kotlin/io/element/android/features/login/impl/screens/confirmaccountprovider/ConfirmAccountProviderPresenter.kt @@ -15,6 +15,7 @@ import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope +import androidx.core.net.toUri import dagger.assisted.Assisted import dagger.assisted.AssistedFactory import dagger.assisted.AssistedInject @@ -23,6 +24,7 @@ import io.element.android.features.login.impl.accountprovider.AccountProviderDat import io.element.android.features.login.impl.error.ChangeServerError import io.element.android.features.login.impl.screens.createaccount.AccountCreationNotSupported import io.element.android.features.login.impl.web.WebClientUrlForAuthenticationRetriever +import io.element.android.libraries.androidutils.uri.setQueryParameter import io.element.android.libraries.architecture.AsyncData import io.element.android.libraries.architecture.Presenter import io.element.android.libraries.architecture.runCatchingUpdatingState @@ -92,7 +94,16 @@ class ConfirmAccountProviderPresenter @AssistedInject constructor( val matrixHomeServerDetails = authenticationService.getHomeserverDetails().value!! if (matrixHomeServerDetails.supportsOidcLogin) { // Retrieve the details right now - LoginFlow.OidcFlow(authenticationService.getOidcUrl().getOrThrow()) + val oidcDetails = authenticationService.getOidcUrl().getOrThrow() + if (params.isAccountCreation) { + // In this case, add or replace the "prompt" parameter to "create" + val newUrl = oidcDetails.url.toUri() + .setQueryParameter("prompt", "create") + .toString() + LoginFlow.OidcFlow(oidcDetails.copy(url = newUrl)) + } else { + LoginFlow.OidcFlow(oidcDetails) + } } else if (params.isAccountCreation) { val url = webClientUrlForAuthenticationRetriever.retrieve(homeserverUrl) LoginFlow.AccountCreationFlow(url) diff --git a/libraries/androidutils/src/main/kotlin/io/element/android/libraries/androidutils/uri/UriExtensions.kt b/libraries/androidutils/src/main/kotlin/io/element/android/libraries/androidutils/uri/UriExtensions.kt index c0883a52a7..42933281a7 100644 --- a/libraries/androidutils/src/main/kotlin/io/element/android/libraries/androidutils/uri/UriExtensions.kt +++ b/libraries/androidutils/src/main/kotlin/io/element/android/libraries/androidutils/uri/UriExtensions.kt @@ -12,3 +12,16 @@ import android.net.Uri const val IGNORED_SCHEMA = "ignored" fun createIgnoredUri(path: String): Uri = Uri.parse("$IGNORED_SCHEMA://$path") + +fun Uri.setQueryParameter(key: String, value: String): Uri { + val existingParams = queryParameterNames + return buildUpon().apply { + clearQuery() + existingParams.forEach { existingKey -> + if (existingKey != key) { + appendQueryParameter(existingKey, getQueryParameter(existingKey)) + } + } + appendQueryParameter(key, value) + }.build() +} diff --git a/libraries/androidutils/src/test/kotlin/io/element/android/libraries/androidutils/uri/UriExtensionTest.kt b/libraries/androidutils/src/test/kotlin/io/element/android/libraries/androidutils/uri/UriExtensionTest.kt new file mode 100644 index 0000000000..217e4ba4b1 --- /dev/null +++ b/libraries/androidutils/src/test/kotlin/io/element/android/libraries/androidutils/uri/UriExtensionTest.kt @@ -0,0 +1,74 @@ +/* + * Copyright 2024 New Vector Ltd. + * + * SPDX-License-Identifier: AGPL-3.0-only + * Please see LICENSE in the repository root for full details. + */ + +package io.element.android.libraries.androidutils.uri + +import androidx.core.net.toUri +import com.google.common.truth.Truth.assertThat +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner + +@RunWith(RobolectricTestRunner::class) +class UriExtensionTest { + @Test + fun `url with prompt parameter should replace existing value`() { + val url = "https://beta.element.io/account/authorize" + + "?response_type=code" + + "&client_id=01J9RB9MEJCMVHWNYYHTVDNBVJ" + + "&redirect_uri=io.element%3A%2Fcallback" + + "&scope=scope" + + "&state=x61ILblUF6BwOTUA" + + "&nonce=N7TdVfDhyVNF9PbH" + + "&prompt=consent" + + "&code_challenge_method=S256" + + "&code_challenge=bDV2DWX0j0U-QtewSUJeXr3DEmvFxlHfQN_1UxXpOUk" + val result = url.toUri() + .setQueryParameter("prompt", "create") + .toString() + assertThat(result).isEqualTo( + "https://beta.element.io/account/authorize" + + "?response_type=code" + + "&client_id=01J9RB9MEJCMVHWNYYHTVDNBVJ" + + "&redirect_uri=io.element%3A%2Fcallback" + + "&scope=scope" + + "&state=x61ILblUF6BwOTUA" + + "&nonce=N7TdVfDhyVNF9PbH" + + "&code_challenge_method=S256" + + "&code_challenge=bDV2DWX0j0U-QtewSUJeXr3DEmvFxlHfQN_1UxXpOUk" + + "&prompt=create" + ) + } + + @Test + fun `url without prompt parameter should add the parameter`() { + val url = "https://beta.element.io/account/authorize" + + "?response_type=code" + + "&client_id=01J9RB9MEJCMVHWNYYHTVDNBVJ" + + "&redirect_uri=io.element%3A%2Fcallback" + + "&scope=scope" + + "&state=x61ILblUF6BwOTUA" + + "&nonce=N7TdVfDhyVNF9PbH" + + "&code_challenge_method=S256" + + "&code_challenge=bDV2DWX0j0U-QtewSUJeXr3DEmvFxlHfQN_1UxXpOUk" + val result = url.toUri() + .setQueryParameter("prompt", "create") + .toString() + assertThat(result).isEqualTo( + "https://beta.element.io/account/authorize" + + "?response_type=code" + + "&client_id=01J9RB9MEJCMVHWNYYHTVDNBVJ" + + "&redirect_uri=io.element%3A%2Fcallback" + + "&scope=scope" + + "&state=x61ILblUF6BwOTUA" + + "&nonce=N7TdVfDhyVNF9PbH" + + "&code_challenge_method=S256" + + "&code_challenge=bDV2DWX0j0U-QtewSUJeXr3DEmvFxlHfQN_1UxXpOUk" + + "&prompt=create" + ) + } +}