Skip to content

Commit

Permalink
Address Observation Bean Name Collisions
Browse files Browse the repository at this point in the history
Closes gh-16161
  • Loading branch information
jzheaux committed Nov 25, 2024
1 parent a550215 commit 2b5a2ee
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Copyright 2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.rsocket;

import java.util.ArrayList;
import java.util.List;

import io.rsocket.core.RSocketServer;
import io.rsocket.exceptions.RejectedSetupException;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.rsocket.core.SecuritySocketAcceptorInterceptor;
import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

/**
* @author Rob Winch
*/
@ContextConfiguration
@ExtendWith(SpringExtension.class)
public class HelloRSocketWithWebFluxITests {

@Autowired
RSocketMessageHandler handler;

@Autowired
SecuritySocketAcceptorInterceptor interceptor;

@Autowired
ServerController controller;

private CloseableChannel server;

private RSocketRequester requester;

@BeforeEach
public void setup() {
// @formatter:off
this.server = RSocketServer.create()
.payloadDecoder(PayloadDecoder.ZERO_COPY)
.interceptors((registry) ->
registry.forSocketAcceptor(this.interceptor)
)
.acceptor(this.handler.responder())
.bind(TcpServerTransport.create("localhost", 0))
.block();
// @formatter:on
}

@AfterEach
public void dispose() {
this.requester.rsocket().dispose();
this.server.dispose();
this.controller.payloads.clear();
}

// gh-16161
@Test
public void retrieveMonoWhenSecureThenDenied() {
// @formatter:off
this.requester = RSocketRequester.builder()
.rsocketStrategies(this.handler.getRSocketStrategies())
.connectTcp("localhost", this.server.address().getPort())
.block();
// @formatter:on
String data = "rob";
// @formatter:off
assertThatExceptionOfType(Exception.class).isThrownBy(
() -> this.requester.route("secure.retrieve-mono")
.data(data)
.retrieveMono(String.class)
.block()
)
.matches((ex) -> ex instanceof RejectedSetupException
|| ex.getClass().toString().contains("ReactiveException"));
// @formatter:on
assertThat(this.controller.payloads).isEmpty();
}

@Configuration
@EnableRSocketSecurity
@EnableWebFluxSecurity
static class Config {

@Bean
ServerController controller() {
return new ServerController();
}

@Bean
RSocketMessageHandler messageHandler() {
RSocketMessageHandler handler = new RSocketMessageHandler();
handler.setRSocketStrategies(rsocketStrategies());
return handler;
}

@Bean
RSocketStrategies rsocketStrategies() {
return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build();
}

@Bean
MapReactiveUserDetailsService uds() {
// @formatter:off
UserDetails rob = User.withDefaultPasswordEncoder()
.username("rob")
.password("password")
.roles("USER", "ADMIN")
.build();
// @formatter:on
return new MapReactiveUserDetailsService(rob);
}

}

@Controller
static class ServerController {

private List<String> payloads = new ArrayList<>();

@MessageMapping("**")
String retrieveMono(String payload) {
add(payload);
return "Hi " + payload;
}

private void add(String p) {
this.payloads.add(p);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.config.annotation.rsocket;

import java.util.Map;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
Expand Down Expand Up @@ -62,8 +64,12 @@ void setPasswordEncoder(PasswordEncoder passwordEncoder) {
}

@Autowired(required = false)
void setAuthenticationManagerPostProcessor(ObjectPostProcessor<ReactiveAuthenticationManager> postProcessor) {
this.postProcessor = postProcessor;
void setAuthenticationManagerPostProcessor(
Map<String, ObjectPostProcessor<ReactiveAuthenticationManager>> postProcessors) {
if (postProcessors.size() == 1) {
this.postProcessor = postProcessors.values().iterator().next();
}
this.postProcessor = postProcessors.get("rSocketAuthenticationManagerPostProcessor");
}

@Bean(name = RSOCKET_SECURITY_BEAN_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
import org.springframework.security.authorization.ReactiveAuthorizationManager;
import org.springframework.security.config.ObjectPostProcessor;
import org.springframework.security.config.observation.SecurityObservationSettings;
import org.springframework.security.web.server.ObservationWebFilterChainDecorator;
import org.springframework.security.web.server.WebFilterChainProxy.WebFilterChainDecorator;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.security.rsocket.api.PayloadExchange;

@Configuration(proxyBeanMethods = false)
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
Expand All @@ -45,7 +43,7 @@ class ReactiveObservationConfiguration {

@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
static ObjectPostProcessor<ReactiveAuthorizationManager<ServerWebExchange>> rSocketAuthorizationManagerPostProcessor(
static ObjectPostProcessor<ReactiveAuthorizationManager<PayloadExchange>> rSocketAuthorizationManagerPostProcessor(
ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
return new ObjectPostProcessor<>() {
@Override
Expand All @@ -71,18 +69,4 @@ public ReactiveAuthenticationManager postProcess(ReactiveAuthenticationManager o
};
}

@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
static ObjectPostProcessor<WebFilterChainDecorator> rSocketFilterChainDecoratorPostProcessor(
ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
return new ObjectPostProcessor<>() {
@Override
public WebFilterChainDecorator postProcess(WebFilterChainDecorator object) {
ObservationRegistry r = registry.getIfUnique(() -> ObservationRegistry.NOOP);
boolean active = !r.isNoop() && predicate.getIfUnique(() -> all).shouldObserveRequests();
return active ? new ObservationWebFilterChainDecorator(r) : object;
}
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ReactiveAuthorizationManager postProcess(ReactiveAuthorizationManager obj

@Bean
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
static ObjectPostProcessor<ReactiveAuthenticationManager> authenticationManagerPostProcessor(
static ObjectPostProcessor<ReactiveAuthenticationManager> reactiveAuthenticationManagerPostProcessor(
ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
return new ObjectPostProcessor<>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.config.annotation.web.reactive;

import java.util.Map;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.ObjectProvider;
Expand Down Expand Up @@ -96,8 +98,12 @@ void setUserDetailsPasswordService(ReactiveUserDetailsPasswordService userDetail
}

@Autowired(required = false)
void setAuthenticationManagerPostProcessor(ObjectPostProcessor<ReactiveAuthenticationManager> postProcessor) {
this.postProcessor = postProcessor;
void setAuthenticationManagerPostProcessor(
Map<String, ObjectPostProcessor<ReactiveAuthenticationManager>> postProcessors) {
if (postProcessors.size() == 1) {
this.postProcessor = postProcessors.values().iterator().next();
}
this.postProcessor = postProcessors.get("reactiveAuthenticationManagerPostProcessor");
}

@Autowired(required = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,31 @@ public void getWhenUsingObservationRegistryThenObservesRequest() {
assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
}

// gh-16161
@Test
public void getWhenUsingRSocketThenObservesRequest() {
this.spring.register(ObservationRegistryConfig.class, RSocketSecurityConfig.class).autowire();
// @formatter:off
this.webClient
.get()
.uri("/hello")
.headers((headers) -> headers.setBasicAuth("user", "password"))
.exchange()
.expectStatus()
.isNotFound();
// @formatter:on
ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(6)).onStart(captor.capture());
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
assertThat(contexts.next().getContextualName()).isEqualTo("http get");
assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain before");
assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications");
assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations");
assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests");
assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
}

@Configuration
static class SubclassConfig extends ServerHttpSecurityConfiguration {

Expand Down

0 comments on commit 2b5a2ee

Please sign in to comment.