diff --git a/core/src/main/java/org/springframework/security/authentication/ott/GenerateOneTimeTokenRequest.java b/core/src/main/java/org/springframework/security/authentication/ott/GenerateOneTimeTokenRequest.java index b03e65dd18..c0ff79dbde 100644 --- a/core/src/main/java/org/springframework/security/authentication/ott/GenerateOneTimeTokenRequest.java +++ b/core/src/main/java/org/springframework/security/authentication/ott/GenerateOneTimeTokenRequest.java @@ -17,6 +17,7 @@ package org.springframework.security.authentication.ott; import java.time.Duration; +import java.util.UUID; import org.springframework.util.Assert; @@ -24,6 +25,7 @@ * Class to store information related to an One-Time Token authentication request * * @author Marcus da Coregio + * @author Max Batiscev * @since 6.4 */ public class GenerateOneTimeTokenRequest { @@ -34,6 +36,8 @@ public class GenerateOneTimeTokenRequest { private final Duration expiresIn; + private final String tokenValue; + public GenerateOneTimeTokenRequest(String username) { this(username, DEFAULT_EXPIRES_IN); } @@ -43,6 +47,16 @@ public GenerateOneTimeTokenRequest(String username, Duration expiresIn) { Assert.notNull(expiresIn, "expiresIn cannot be null"); this.username = username; this.expiresIn = expiresIn; + this.tokenValue = UUID.randomUUID().toString(); + } + + public GenerateOneTimeTokenRequest(String username, Duration expiresIn, String tokenValue) { + Assert.hasText(username, "username cannot be empty"); + Assert.hasText(tokenValue, "tokenValue cannot be empty"); + Assert.notNull(expiresIn, "expiresIn cannot be null"); + this.username = username; + this.expiresIn = expiresIn; + this.tokenValue = tokenValue; } public String getUsername() { @@ -53,4 +67,8 @@ public Duration getExpiresIn() { return this.expiresIn; } + public String getTokenValue() { + return this.tokenValue; + } + } diff --git a/core/src/main/java/org/springframework/security/authentication/ott/InMemoryOneTimeTokenService.java b/core/src/main/java/org/springframework/security/authentication/ott/InMemoryOneTimeTokenService.java index 0d67961794..99571a9b12 100644 --- a/core/src/main/java/org/springframework/security/authentication/ott/InMemoryOneTimeTokenService.java +++ b/core/src/main/java/org/springframework/security/authentication/ott/InMemoryOneTimeTokenService.java @@ -32,6 +32,7 @@ * there is more or equal than 100 tokens stored in the map. * * @author Marcus da Coregio + * @author Max Batischev * @since 6.4 */ public final class InMemoryOneTimeTokenService implements OneTimeTokenService { @@ -43,10 +44,9 @@ public final class InMemoryOneTimeTokenService implements OneTimeTokenService { @Override @NonNull public OneTimeToken generate(GenerateOneTimeTokenRequest request) { - String token = UUID.randomUUID().toString(); Instant expiresAt = this.clock.instant().plus(request.getExpiresIn()); - OneTimeToken ott = new DefaultOneTimeToken(token, request.getUsername(), expiresAt); - this.oneTimeTokenByToken.put(token, ott); + OneTimeToken ott = new DefaultOneTimeToken(request.getTokenValue(), request.getUsername(), expiresAt); + this.oneTimeTokenByToken.put(request.getTokenValue(), ott); cleanExpiredTokensIfNeeded(); return ott; } diff --git a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java index a58665bd1e..4bcab640a4 100644 --- a/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java +++ b/core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java @@ -24,7 +24,6 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; -import java.util.UUID; import java.util.function.Function; import org.apache.commons.logging.Log; @@ -130,9 +129,8 @@ public void setCleanupCron(String cleanupCron) { @Override public OneTimeToken generate(GenerateOneTimeTokenRequest request) { Assert.notNull(request, "generateOneTimeTokenRequest cannot be null"); - String token = UUID.randomUUID().toString(); Instant expiresAt = this.clock.instant().plus(request.getExpiresIn()); - OneTimeToken oneTimeToken = new DefaultOneTimeToken(token, request.getUsername(), expiresAt); + OneTimeToken oneTimeToken = new DefaultOneTimeToken(request.getTokenValue(), request.getUsername(), expiresAt); insertOneTimeToken(oneTimeToken); return oneTimeToken; } diff --git a/web/src/main/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolver.java b/web/src/main/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolver.java index f8577c60a8..ad2e9ede52 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolver.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolver.java @@ -17,6 +17,8 @@ package org.springframework.security.web.authentication.ott; import java.time.Duration; +import java.util.UUID; +import java.util.function.Supplier; import jakarta.servlet.http.HttpServletRequest; @@ -37,13 +39,15 @@ public final class DefaultGenerateOneTimeTokenRequestResolver implements Generat private Duration expiresIn = DEFAULT_EXPIRES_IN; + private Supplier tokenValueFactory = () -> UUID.randomUUID().toString(); + @Override public GenerateOneTimeTokenRequest resolve(HttpServletRequest request) { String username = request.getParameter("username"); if (!StringUtils.hasText(username)) { return null; } - return new GenerateOneTimeTokenRequest(username, this.expiresIn); + return new GenerateOneTimeTokenRequest(username, this.expiresIn, this.tokenValueFactory.get()); } /** @@ -55,4 +59,14 @@ public void setExpiresIn(Duration expiresIn) { this.expiresIn = expiresIn; } + /** + * Sets factory for token value generation + * @param tokenValueFactory factory for token value generation + * @since 6.5 + */ + public void setTokenValueFactory(Supplier tokenValueFactory) { + Assert.notNull(tokenValueFactory, "tokenValueFactory cannot be null"); + this.tokenValueFactory = tokenValueFactory; + } + } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolver.java b/web/src/main/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolver.java index f89298f6d4..2d6bec3a5d 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolver.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolver.java @@ -17,6 +17,8 @@ package org.springframework.security.web.server.authentication.ott; import java.time.Duration; +import java.util.UUID; +import java.util.function.Supplier; import reactor.core.publisher.Mono; @@ -40,13 +42,15 @@ public final class DefaultServerGenerateOneTimeTokenRequestResolver private Duration expiresIn = DEFAULT_EXPIRES_IN; + private Supplier tokenValueFactory = () -> UUID.randomUUID().toString(); + @Override public Mono resolve(ServerWebExchange exchange) { // @formatter:off return exchange.getFormData() .mapNotNull((data) -> data.getFirst(USERNAME)) .switchIfEmpty(Mono.empty()) - .map((username) -> new GenerateOneTimeTokenRequest(username, this.expiresIn)); + .map((username) -> new GenerateOneTimeTokenRequest(username, this.expiresIn, this.tokenValueFactory.get())); // @formatter:on } @@ -59,4 +63,14 @@ public void setExpiresIn(Duration expiresIn) { this.expiresIn = expiresIn; } + /** + * Sets factory for token value generation + * @param tokenValueFactory factory for token value generation + * @since 6.5 + */ + public void setTokenValueFactory(Supplier tokenValueFactory) { + Assert.notNull(tokenValueFactory, "tokenValueFactory cannot be null"); + this.tokenValueFactory = tokenValueFactory; + } + } diff --git a/web/src/test/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolverTests.java b/web/src/test/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolverTests.java index 12a491230e..13be40cb81 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/ott/DefaultGenerateOneTimeTokenRequestResolverTests.java @@ -64,4 +64,15 @@ void resolveWhenExpiresInSetThenResolvesGenerateRequest() { assertThat(generateRequest.getExpiresIn()).isEqualTo(Duration.ofSeconds(600)); } + @Test + void resolveWhenTokenValueFactorySetThenResolvesGenerateRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("username", "test"); + this.requestResolver.setTokenValueFactory(() -> "tokenValue"); + + GenerateOneTimeTokenRequest generateRequest = this.requestResolver.resolve(request); + + assertThat(generateRequest.getTokenValue()).isEqualTo("tokenValue"); + } + } diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolverTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolverTests.java index c9bfc9eef1..2e032ec1c3 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/ott/DefaultServerGenerateOneTimeTokenRequestResolverTests.java @@ -71,4 +71,16 @@ void resolveWhenExpiresInSetThenResolvesGenerateRequest() { assertThat(generateRequest.getExpiresIn()).isEqualTo(Duration.ofSeconds(600)); } + @Test + void resolveWhenTokenValueFactorySetThenResolvesGenerateRequest() { + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/ott/generate") + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body("username=user")); + this.resolver.setTokenValueFactory(() -> "tokenValue"); + + GenerateOneTimeTokenRequest generateRequest = this.resolver.resolve(exchange).block(); + + assertThat(generateRequest.getTokenValue()).isEqualTo("tokenValue"); + } + }