Skip to content

Commit

Permalink
Merge branch '3.2.x'
Browse files Browse the repository at this point in the history
Closes gh-40474
  • Loading branch information
wilkinsona committed Apr 22, 2024
2 parents 9184448 + 1f06aa2 commit cde9166
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.StringWriter;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -32,6 +33,7 @@
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ResponseStatusException;
Expand All @@ -46,8 +48,8 @@
* <li>error - The error reason</li>
* <li>exception - The class name of the root exception (if configured)</li>
* <li>message - The exception message (if configured)</li>
* <li>errors - Any {@link ObjectError}s from a {@link BindingResult} exception (if
* configured)</li>
* <li>errors - Any {@link ObjectError}s from a {@link BindingResult} or
* {@link MethodValidationResult} exception (if configured)</li>
* <li>trace - The exception stack trace (if configured)</li>
* <li>path - The URL path when the exception was raised</li>
* <li>requestId - Unique ID associated with the current request</li>
Expand All @@ -58,6 +60,7 @@
* @author Michele Mancioppi
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
* @since 2.0.0
* @see ErrorAttributes
*/
Expand Down Expand Up @@ -97,9 +100,8 @@ private Map<String, Object> getErrorAttributes(ServerRequest request, boolean in
HttpStatus errorStatus = determineHttpStatus(error, responseStatusAnnotation);
errorAttributes.put("status", errorStatus.value());
errorAttributes.put("error", errorStatus.getReasonPhrase());
errorAttributes.put("message", determineMessage(error, responseStatusAnnotation));
errorAttributes.put("requestId", request.exchange().getRequest().getId());
handleException(errorAttributes, determineException(error), includeStackTrace);
handleException(errorAttributes, error, responseStatusAnnotation, includeStackTrace);
return errorAttributes;
}

Expand All @@ -113,44 +115,51 @@ private HttpStatus determineHttpStatus(Throwable error, MergedAnnotation<Respons
return responseStatusAnnotation.getValue("code", HttpStatus.class).orElse(HttpStatus.INTERNAL_SERVER_ERROR);
}

private String determineMessage(Throwable error, MergedAnnotation<ResponseStatus> responseStatusAnnotation) {
if (error instanceof BindingResult) {
return error.getMessage();
}
if (error instanceof ResponseStatusException responseStatusException) {
return responseStatusException.getReason();
}
String reason = responseStatusAnnotation.getValue("reason", String.class).orElse("");
if (StringUtils.hasText(reason)) {
return reason;
}
return (error.getMessage() != null) ? error.getMessage() : "";
}

private Throwable determineException(Throwable error) {
if (error instanceof ResponseStatusException) {
return (error.getCause() != null) ? error.getCause() : error;
}
return error;
}

private void addStackTrace(Map<String, Object> errorAttributes, Throwable error) {
StringWriter stackTrace = new StringWriter();
error.printStackTrace(new PrintWriter(stackTrace));
stackTrace.flush();
errorAttributes.put("trace", stackTrace.toString());
}

private void handleException(Map<String, Object> errorAttributes, Throwable error, boolean includeStackTrace) {
errorAttributes.put("exception", error.getClass().getName());
if (includeStackTrace) {
addStackTrace(errorAttributes, error);
private void handleException(Map<String, Object> errorAttributes, Throwable error,
MergedAnnotation<ResponseStatus> responseStatusAnnotation, boolean includeStackTrace) {
Throwable exception;
if (error instanceof BindingResult bindingResult) {
errorAttributes.put("message", error.getMessage());
errorAttributes.put("errors", bindingResult.getAllErrors());
exception = error;
}
if (error instanceof BindingResult result) {
if (result.hasErrors()) {
errorAttributes.put("errors", result.getAllErrors());
}
else if (error instanceof MethodValidationResult methodValidationResult) {
addMessageAndErrorsFromMethodValidationResult(errorAttributes, methodValidationResult);
exception = error;
}
else if (error instanceof ResponseStatusException responseStatusException) {
errorAttributes.put("message", responseStatusException.getReason());
exception = (responseStatusException.getCause() != null) ? responseStatusException.getCause() : error;
}
else {
exception = error;
String reason = responseStatusAnnotation.getValue("reason", String.class).orElse("");
String message = StringUtils.hasText(reason) ? reason : error.getMessage();
errorAttributes.put("message", (message != null) ? message : "");
}
errorAttributes.put("exception", exception.getClass().getName());
if (includeStackTrace) {
addStackTrace(errorAttributes, exception);
}
}

private void addMessageAndErrorsFromMethodValidationResult(Map<String, Object> errorAttributes,
MethodValidationResult result) {
List<ObjectError> errors = result.getAllErrors()
.stream()
.filter(ObjectError.class::isInstance)
.map(ObjectError.class::cast)
.toList();
errorAttributes.put("message",
"Validation failed for method='" + result.getMethod() + "'. Error count: " + errors.size());
errorAttributes.put("errors", errors);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.StringWriter;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jakarta.servlet.RequestDispatcher;
Expand All @@ -36,6 +37,7 @@
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.servlet.HandlerExceptionResolver;
Expand All @@ -50,8 +52,8 @@
* <li>error - The error reason</li>
* <li>exception - The class name of the root exception (if configured)</li>
* <li>message - The exception message (if configured)</li>
* <li>errors - Any {@link ObjectError}s from a {@link BindingResult} exception (if
* configured)</li>
* <li>errors - Any {@link ObjectError}s from a {@link BindingResult} or
* {@link MethodValidationResult} exception (if configured)</li>
* <li>trace - The exception stack trace (if configured)</li>
* <li>path - The URL path when the exception was raised</li>
* </ul>
Expand All @@ -62,6 +64,7 @@
* @author Vedran Pavic
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
* @since 2.0.0
* @see ErrorAttributes
*/
Expand Down Expand Up @@ -149,12 +152,18 @@ private void addErrorDetails(Map<String, Object> errorAttributes, WebRequest web
}

private void addErrorMessage(Map<String, Object> errorAttributes, WebRequest webRequest, Throwable error) {
BindingResult result = extractBindingResult(error);
if (result == null) {
addExceptionErrorMessage(errorAttributes, webRequest, error);
BindingResult bindingResult = extractBindingResult(error);
if (bindingResult != null) {
addMessageAndErrorsFromBindingResult(errorAttributes, bindingResult);
}
else {
addBindingResultErrorMessage(errorAttributes, result);
MethodValidationResult methodValidationResult = extractMethodValidationResult(error);
if (methodValidationResult != null) {
addMessageAndErrorsFromMethodValidationResult(errorAttributes, methodValidationResult);
}
else {
addExceptionErrorMessage(errorAttributes, webRequest, error);
}
}
}

Expand Down Expand Up @@ -187,10 +196,25 @@ protected String getMessage(WebRequest webRequest, Throwable error) {
return "No message available";
}

private void addBindingResultErrorMessage(Map<String, Object> errorAttributes, BindingResult result) {
errorAttributes.put("message", "Validation failed for object='" + result.getObjectName() + "'. "
+ "Error count: " + result.getErrorCount());
errorAttributes.put("errors", result.getAllErrors());
private void addMessageAndErrorsFromBindingResult(Map<String, Object> errorAttributes, BindingResult result) {
addMessageAndErrorsForValidationFailure(errorAttributes, "object='" + result.getObjectName() + "'",
result.getAllErrors());
}

private void addMessageAndErrorsFromMethodValidationResult(Map<String, Object> errorAttributes,
MethodValidationResult result) {
List<ObjectError> errors = result.getAllErrors()
.stream()
.filter(ObjectError.class::isInstance)
.map(ObjectError.class::cast)
.toList();
addMessageAndErrorsForValidationFailure(errorAttributes, "method='" + result.getMethod() + "'", errors);
}

private void addMessageAndErrorsForValidationFailure(Map<String, Object> errorAttributes, String validated,
List<ObjectError> errors) {
errorAttributes.put("message", "Validation failed for " + validated + ". Error count: " + errors.size());
errorAttributes.put("errors", errors);
}

private BindingResult extractBindingResult(Throwable error) {
Expand All @@ -200,6 +224,13 @@ private BindingResult extractBindingResult(Throwable error) {
return null;
}

private MethodValidationResult extractMethodValidationResult(Throwable error) {
if (error instanceof MethodValidationResult methodValidationResult) {
return methodValidationResult;
}
return null;
}

private void addStackTrace(Map<String, Object> errorAttributes, Throwable error) {
StringWriter stackTrace = new StringWriter();
error.printStackTrace(new PrintWriter(stackTrace));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@
import org.springframework.validation.BindingResult;
import org.springframework.validation.MapBindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.validation.method.ParameterValidationResult;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.support.WebExchangeBindException;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
Expand All @@ -51,6 +54,7 @@
* @author Stephane Nicoll
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
*/
class DefaultErrorAttributesTests {

Expand Down Expand Up @@ -271,6 +275,25 @@ void extractBindingResultErrors() throws Exception {
assertThat(attributes).containsEntry("errors", bindingResult.getAllErrors());
}

@Test
void extractMethodValidationResultErrors() throws Exception {
Object target = "test";
Method method = String.class.getMethod("substring", int.class);
MethodParameter parameter = new MethodParameter(method, 0);
MethodValidationResult methodValidationResult = MethodValidationResult.create(target, method,
List.of(new ParameterValidationResult(parameter, -1,
List.of(new ObjectError("beginIndex", "beginIndex is negative")), null, null, null)));
HandlerMethodValidationException ex = new HandlerMethodValidationException(methodValidationResult);
MockServerHttpRequest request = MockServerHttpRequest.get("/test").build();
Map<String, Object> attributes = this.errorAttributes.getErrorAttributes(buildServerRequest(request, ex),
ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
assertThat(attributes.get("message")).asString()
.isEqualTo(
"Validation failed for method='public java.lang.String java.lang.String.substring(int)'. Error count: 1");
assertThat(attributes).containsEntry("errors",
methodValidationResult.getAllErrors().stream().filter(ObjectError.class::isInstance).toList());
}

@Test
void extractBindingResultErrorsExcludeMessageAndErrors() throws Exception {
Method method = getClass().getDeclaredMethod("method", String.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;

import jakarta.servlet.ServletException;
import org.junit.jupiter.api.Test;

import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.context.MessageSourceResolvable;
import org.springframework.core.MethodParameter;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
Expand All @@ -34,9 +36,12 @@
import org.springframework.validation.BindingResult;
import org.springframework.validation.MapBindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.validation.method.ParameterValidationResult;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.WebRequest;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import org.springframework.web.servlet.ModelAndView;

import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -48,6 +53,7 @@
* @author Vedran Pavic
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
*/
class DefaultErrorAttributesTests {

Expand Down Expand Up @@ -202,18 +208,37 @@ void withMethodArgumentNotValidExceptionBindingErrors() {
testBindingResult(bindingResult, ex, ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
}

@Test
void withHandlerMethodValidationExceptionBindingErrors() {
Object target = "test";
Method method = ReflectionUtils.findMethod(String.class, "substring", int.class);
MethodParameter parameter = new MethodParameter(method, 0);
MethodValidationResult methodValidationResult = MethodValidationResult.create(target, method,
List.of(new ParameterValidationResult(parameter, -1,
List.of(new ObjectError("beginIndex", "beginIndex is negative")), null, null, null)));
HandlerMethodValidationException ex = new HandlerMethodValidationException(methodValidationResult);
testErrors(methodValidationResult.getAllErrors(),
"Validation failed for method='public java.lang.String java.lang.String.substring(int)'. Error count: 1",
ex, ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
}

private void testBindingResult(BindingResult bindingResult, Exception ex, ErrorAttributeOptions options) {
testErrors(bindingResult.getAllErrors(), "Validation failed for object='objectName'. Error count: 1", ex,
options);
}

private void testErrors(List<? extends MessageSourceResolvable> errors, String expectedMessage, Exception ex,
ErrorAttributeOptions options) {
this.request.setAttribute("jakarta.servlet.error.exception", ex);
Map<String, Object> attributes = this.errorAttributes.getErrorAttributes(this.webRequest, options);
if (options.isIncluded(Include.MESSAGE)) {
assertThat(attributes).containsEntry("message",
"Validation failed for object='objectName'. Error count: 1");
assertThat(attributes).containsEntry("message", expectedMessage);
}
else {
assertThat(attributes).doesNotContainKey("message");
}
if (options.isIncluded(Include.BINDING_ERRORS)) {
assertThat(attributes).containsEntry("errors", bindingResult.getAllErrors());
assertThat(attributes).containsEntry("errors", errors);
}
else {
assertThat(attributes).doesNotContainKey("errors");
Expand Down

0 comments on commit cde9166

Please sign in to comment.