diff --git a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/AuthorizationProxyDataConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/AuthorizationProxyDataConfiguration.java index e446ee2736..c94f44ecaa 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/method/configuration/AuthorizationProxyDataConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/method/configuration/AuthorizationProxyDataConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,23 @@ package org.springframework.security.config.annotation.method.configuration; +import java.util.List; + import org.springframework.aop.framework.AopInfrastructureBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Role; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.data.domain.PageImpl; +import org.springframework.data.domain.SliceImpl; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResult; +import org.springframework.data.geo.GeoResults; import org.springframework.security.aot.hint.SecurityHintsRegistrar; import org.springframework.security.authorization.AuthorizationProxyFactory; +import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory; import org.springframework.security.data.aot.hint.AuthorizeReturnObjectDataHintsRegistrar; @Configuration(proxyBeanMethods = false) @@ -34,4 +44,45 @@ static SecurityHintsRegistrar authorizeReturnObjectDataHintsRegistrar(Authorizat return new AuthorizeReturnObjectDataHintsRegistrar(proxyFactory); } + @Bean + @Role(BeanDefinition.ROLE_INFRASTRUCTURE) + @Order(Ordered.HIGHEST_PRECEDENCE + 100) + DataTargetVisitor dataTargetVisitor() { + return new DataTargetVisitor(); + } + + private static final class DataTargetVisitor implements AuthorizationAdvisorProxyFactory.TargetVisitor { + + @Override + public Object visit(AuthorizationAdvisorProxyFactory proxyFactory, Object target) { + if (target instanceof GeoResults geoResults) { + return new GeoResults<>(proxyCast(proxyFactory, geoResults.getContent()), + geoResults.getAverageDistance()); + } + if (target instanceof GeoResult geoResult) { + return new GeoResult<>(proxyCast(proxyFactory, geoResult.getContent()), geoResult.getDistance()); + } + if (target instanceof GeoPage geoPage) { + GeoResults results = new GeoResults<>(proxyCast(proxyFactory, geoPage.getContent()), + geoPage.getAverageDistance()); + return new GeoPage<>(results, geoPage.getPageable(), geoPage.getTotalElements()); + } + if (target instanceof PageImpl page) { + List content = proxyCast(proxyFactory, page.getContent()); + return new PageImpl<>(content, page.getPageable(), page.getTotalElements()); + } + if (target instanceof SliceImpl slice) { + List content = proxyCast(proxyFactory, slice.getContent()); + return new SliceImpl<>(content, slice.getPageable(), slice.hasNext()); + } + return null; + } + + @SuppressWarnings("unchecked") + private T proxyCast(AuthorizationAdvisorProxyFactory proxyFactory, T target) { + return (T) proxyFactory.proxy(target); + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostMethodSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostMethodSecurityConfigurationTests.java index 8eec0f1bce..e5c87a5f2d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostMethodSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostMethodSecurityConfigurationTests.java @@ -62,6 +62,14 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.annotation.AnnotationConfigurationException; import org.springframework.core.annotation.Order; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageImpl; +import org.springframework.data.domain.Slice; +import org.springframework.data.domain.SliceImpl; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResult; +import org.springframework.data.geo.GeoResults; import org.springframework.http.HttpStatusCode; import org.springframework.http.ResponseEntity; import org.springframework.security.access.AccessDeniedException; @@ -733,6 +741,28 @@ public void findByIdWhenUnauthorizedResultThenDenies() { assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude); } + @Test + @WithMockUser(authorities = "airplane:read") + public void findGeoResultByIdWhenAuthorizedResultThenAuthorizes() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + GeoResult geoResultFlight = flights.findGeoResultFlightById("1"); + Flight flight = geoResultFlight.getContent(); + assertThatNoException().isThrownBy(flight::getAltitude); + assertThatNoException().isThrownBy(flight::getSeats); + } + + @Test + @WithMockUser(authorities = "seating:read") + public void findGeoResultByIdWhenUnauthorizedResultThenDenies() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + GeoResult geoResultFlight = flights.findGeoResultFlightById("1"); + Flight flight = geoResultFlight.getContent(); + assertThatNoException().isThrownBy(flight::getSeats); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude); + } + @Test @WithMockUser(authorities = "airplane:read") public void findByIdWhenAuthorizedResponseEntityThenAuthorizes() { @@ -804,6 +834,46 @@ public void findAllWhenPostFilterThenFilters() { .doesNotContain("Kevin Mitnick")); } + @Test + @WithMockUser(authorities = "airplane:read") + public void findPageWhenPostFilterThenFilters() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findPage() + .forEach((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName) + .doesNotContain("Kevin Mitnick")); + } + + @Test + @WithMockUser(authorities = "airplane:read") + public void findSliceWhenPostFilterThenFilters() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findSlice() + .forEach((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName) + .doesNotContain("Kevin Mitnick")); + } + + @Test + @WithMockUser(authorities = "airplane:read") + public void findGeoPageWhenPostFilterThenFilters() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findGeoPage() + .forEach((flight) -> assertThat(flight.getContent().getPassengers()).extracting(Passenger::getName) + .doesNotContain("Kevin Mitnick")); + } + + @Test + @WithMockUser(authorities = "airplane:read") + public void findGeoResultsWhenPostFilterThenFilters() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findGeoResults() + .forEach((flight) -> assertThat(flight.getContent().getPassengers()).extracting(Passenger::getName) + .doesNotContain("Kevin Mitnick")); + } + @Test @WithMockUser(authorities = "airplane:read") public void findAllWhenPreFilterThenFilters() { @@ -1688,10 +1758,39 @@ Iterator findAll() { return this.flights.values().iterator(); } + Page findPage() { + return new PageImpl<>(new ArrayList<>(this.flights.values())); + } + + Slice findSlice() { + return new SliceImpl<>(new ArrayList<>(this.flights.values())); + } + + GeoPage findGeoPage() { + List> results = new ArrayList<>(); + for (Flight flight : this.flights.values()) { + results.add(new GeoResult<>(flight, new Distance(flight.altitude))); + } + return new GeoPage<>(new GeoResults<>(results)); + } + + GeoResults findGeoResults() { + List> results = new ArrayList<>(); + for (Flight flight : this.flights.values()) { + results.add(new GeoResult<>(flight, new Distance(flight.altitude))); + } + return new GeoResults<>(results); + } + Flight findById(String id) { return this.flights.get(id); } + GeoResult findGeoResultFlightById(String id) { + Flight flight = this.flights.get(id); + return new GeoResult<>(flight, new Distance(flight.altitude)); + } + Flight save(Flight flight) { this.flights.put(flight.getId(), flight); return flight;