Skip to content

Commit

Permalink
feat: single event handler for sealed interface
Browse files Browse the repository at this point in the history
  • Loading branch information
aludwiko committed Dec 23, 2023
1 parent 652242f commit 79cbfbb
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 184 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import kalix.javasdk.Metadata;
import kalix.javasdk.eventsourcedentity.EventSourcedEntity;
import kalix.javasdk.eventsourcedentity.EventSourcedEntityContext;
import kalix.javasdk.impl.MethodInvoker;
import kalix.javasdk.testkit.impl.EventSourcedEntityEffectsRunner;
import kalix.javasdk.testkit.impl.TestKitEventSourcedEntityContext;
import kalix.javasdk.impl.JsonMessageCodec;
import kalix.javasdk.impl.eventsourcedentity.EventSourceEntityHandlers;
import kalix.javasdk.impl.eventsourcedentity.EventSourcedHandlersExtractor;
import scala.collection.immutable.Map;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
Expand All @@ -44,7 +45,7 @@ public class EventSourcedTestKit<S, E, ES extends EventSourcedEntity<S, E>>
extends EventSourcedEntityEffectsRunner<S, E> {

private final ES entity;
private final EventSourceEntityHandlers eventHandlers;
private final Map<String, MethodInvoker> eventHandlers;

private final JsonMessageCodec messageCodec;

Expand Down Expand Up @@ -124,7 +125,7 @@ public <R> EventSourcedResult<R> call(Function<ES, EventSourcedEntity.Effect<R>>
@Override
protected final S handleEvent(S state, E event) {
try {
Method method = eventHandlers.handlers().apply(messageCodec.removeVersion(messageCodec.typeUrlFor(event.getClass()))).method();
Method method = eventHandlers.apply(messageCodec.removeVersion(messageCodec.typeUrlFor(event.getClass()))).method();
return (S) method.invoke(entity, event);
} catch (NoSuchElementException e) {
throw new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,16 @@ public Counter onValueSet(CounterEvent.ValueSet evt) {
public Counter onValueMultiplied(CounterEvent.ValueMultiplied evt) {
return new Counter(this.value * evt.value());
}

public Counter apply(CounterEvent counterEvent) {
if (counterEvent instanceof CounterEvent.ValueIncreased increased) {
return onValueIncreased(increased);
} else if (counterEvent instanceof CounterEvent.ValueSet set) {
return onValueSet(set);
} else if (counterEvent instanceof CounterEvent.ValueMultiplied multiplied) {
return onValueMultiplied(multiplied);
} else {
throw new RuntimeException("Unknown event type: " + counterEvent);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,7 @@ public Effect<Integer> restart() { // force entity restart, useful for testing
}

@EventHandler
public Counter handleIncrease(CounterEvent.ValueIncreased increased) {
return currentState().onValueIncreased(increased);
}

@EventHandler
public Counter handleSet(CounterEvent.ValueSet set) {
return currentState().onValueSet(set);
}

@EventHandler
public Counter handleMultiply(CounterEvent.ValueMultiplied multiplied) {
return currentState().onValueMultiplied(multiplied);
public Counter handle(CounterEvent counterEvent) {
return currentState().apply(counterEvent);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import kalix.javasdk.annotations.TypeName;

public interface CounterEvent {
public sealed interface CounterEvent {

@TypeName("increased")
record ValueIncreased(int value) implements CounterEvent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import kalix.javasdk.impl.ComponentDescriptorFactory$;
import kalix.javasdk.impl.JsonMessageCodec;
import kalix.javasdk.impl.MessageCodec;
import kalix.javasdk.impl.eventsourcedentity.EventSourceEntityHandlers;
import kalix.javasdk.impl.MethodInvoker;
import kalix.javasdk.impl.eventsourcedentity.EventSourcedEntityRouter;
import kalix.javasdk.impl.eventsourcedentity.EventSourcedHandlersExtractor;
import kalix.javasdk.impl.eventsourcedentity.ReflectiveEventSourcedEntityRouter;
import scala.collection.immutable.Map;

import java.util.Optional;
import java.util.function.Function;
Expand All @@ -46,7 +47,7 @@ public class ReflectiveEventSourcedEntityProvider<S, E, ES extends EventSourcedE

private final JsonMessageCodec messageCodec;

private final EventSourceEntityHandlers eventHandlers;
private final Map<String, MethodInvoker> eventHandlers;

public static <S, E, ES extends EventSourcedEntity<S, E>> ReflectiveEventSourcedEntityProvider<S, E, ES> of(
Class<ES> cls,
Expand All @@ -68,14 +69,6 @@ public ReflectiveEventSourcedEntityProvider(
"Event Sourced Entity [" + entityClass.getName() + "] is missing '@TypeId' annotation");

this.eventHandlers = EventSourcedHandlersExtractor.handlersFrom(entityClass, messageCodec);
if (this.eventHandlers.errors().nonEmpty()) {
throw new IllegalArgumentException(
"Event Sourced Entity ["
+ entityClass.getName()
+ "] has event handlers configured incorrectly: "
+ this.eventHandlers.errors());
}

this.entityType = typeId;
this.factory = factory;
this.options = options.withForwardHeaders(ForwardHeadersExtractor.extractFrom(entityClass));
Expand Down Expand Up @@ -104,7 +97,7 @@ public String entityType() {
public EventSourcedEntityRouter<S, E, ES> newRouter(EventSourcedEntityContext context) {
ES entity = factory.apply(context);
return new ReflectiveEventSourcedEntityRouter<>(
entity, componentDescriptor.commandHandlers(), eventHandlers.handlers(), messageCodec);
entity, componentDescriptor.commandHandlers(), eventHandlers, messageCodec);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ package kalix.javasdk.impl

import java.lang.reflect.AnnotatedElement
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType

import scala.reflect.ClassTag

import kalix.javasdk.action.Action
import kalix.javasdk.annotations.EventHandler
import kalix.javasdk.annotations.Publish
import kalix.javasdk.annotations.Query
import kalix.javasdk.annotations.Subscribe
Expand Down Expand Up @@ -148,6 +150,39 @@ object Validations {
validateEventSourcedEntity(component) ++
validateWorkflow(component)

private def validateEventHandlers(entityClass: Class[_]): Validation = {

val annotatedHandlers = entityClass.getDeclaredMethods
.filter(_.getAnnotation(classOf[EventHandler]) != null)
.toList

val genericTypeArguments = entityClass.getGenericSuperclass
.asInstanceOf[ParameterizedType]
.getActualTypeArguments

// the state type parameter from the ES entity defines the return type of each event handler
val stateType = genericTypeArguments.head
.asInstanceOf[Class[_]]

val eventType = genericTypeArguments(1).asInstanceOf[Class[_]]

val (invalidHandlers, validSignatureHandlers) = annotatedHandlers.partition((m: Method) =>
m.getParameterCount != 1 || !Modifier.isPublic(m.getModifiers) || (stateType != m.getReturnType))

val signatureValidation = Validation(
invalidHandlers
.sortBy(_.getName) //for tests
.map(method =>
errorMessage(
entityClass,
s"event handler [${method.getName}] must be public, with exactly one parameter and return type '${stateType.getTypeName}'.")))

val missingHandlerInputParams = validSignatureHandlers.sortBy(_.getName).map(_.getParameterTypes.head)

signatureValidation ++ ambiguousHandlersErrors(validSignatureHandlers, entityClass) ++
missingEventHandler(missingHandlerInputParams, eventType, entityClass)
}

private def validateCompoundIdsOrder(component: Class[_]): Validation = {
val restService = RestServiceIntrospector.inspectService(component)
component.getMethods.toIndexedSeq
Expand Down Expand Up @@ -187,7 +222,8 @@ object Validations {

private def validateEventSourcedEntity(component: Class[_]): Validation = {
when[EventSourcedEntity[_, _]](component) {
validateCompoundIdsOrder(component)
validateCompoundIdsOrder(component) ++
validateEventHandlers(component)
}
}

Expand Down Expand Up @@ -255,7 +291,7 @@ object Validations {
hasStreamSubscription(component),
hasTopicSubscription(component))

when(typeLevelSubs.filter(identity).size > 1) {
when(typeLevelSubs.count(identity) > 1) {
Validation(errorMessage(component, "Only one subscription type is allowed on a type level."))
}
}
Expand Down Expand Up @@ -287,12 +323,9 @@ object Validations {
val methods = component.getMethods.toIndexedSeq

if (hasSubscription(component)) {
val effectMethodsByInputParams: Map[Option[Class[_]], IndexedSeq[Method]] = methods
val effectMethods = methods
.filter(updateMethodPredicate)
.groupBy(_.getParameterTypes.lastOption)

Validation(ambiguousHandlersErrors(effectMethodsByInputParams, component))

ambiguousHandlersErrors(effectMethods, component)
} else {
val effectOutputMethodsGrouped = methods
.filter(hasSubscription)
Expand All @@ -301,30 +334,40 @@ object Validations {

effectOutputMethodsGrouped
.map { case (_, methods) =>
val effectMethodsByInputParams: Map[Option[Class[_]], IndexedSeq[Method]] =
methods.groupBy(_.getParameterTypes.lastOption)
Validation(ambiguousHandlersErrors(effectMethodsByInputParams, component))
ambiguousHandlersErrors(methods, component)
}
.fold(Valid)(_ ++ _)
}

}

private def ambiguousHandlersErrors(
effectMethodsInputParams: Map[Option[Class[_]], IndexedSeq[Method]],
component: Class[_]) = {
val errors = effectMethodsInputParams
private def ambiguousHandlersErrors(handlers: Seq[Method], component: Class[_]): Validation = {
val ambiguousHandlers = handlers
.groupBy(_.getParameterTypes.lastOption)
.filter(_._2.size > 1)
.map {
case (Some(inputType), methods) =>
errorMessage(
Validation(errorMessage(
component,
s"Ambiguous handlers for ${inputType.getCanonicalName}, methods: [${methods.sorted.map(_.getName).mkString(", ")}] consume the same type.")
s"Ambiguous handlers for ${inputType.getCanonicalName}, methods: [${methods.sorted.map(_.getName).mkString(", ")}] consume the same type."))
case (None, methods) => //only delete handlers
errorMessage(component, s"Ambiguous delete handlers: [${methods.sorted.map(_.getName).mkString(", ")}].")
Validation(
errorMessage(component, s"Ambiguous delete handlers: [${methods.sorted.map(_.getName).mkString(", ")}]."))
}
.toSeq
errors

val sealedHandler = handlers.find(_.getParameterTypes.lastOption.exists(_.isSealed))
val sealedHandlerMixedUsage = if (sealedHandler.nonEmpty && handlers.size > 1) {
val unexpectedHandlerNames = handlers.filterNot(m => m == sealedHandler.get).map(_.getName)
Validation(
errorMessage(
component,
s"Event handler accepting a sealed interface [${sealedHandler.get.getName}] cannot be mixed with handlers for specific events. Please remove following handlers: [${unexpectedHandlerNames
.mkString(", ")}]."))
} else {
Valid
}

ambiguousHandlers.fold(sealedHandlerMixedUsage)(_ ++ _)
}

private def missingEventHandlerValidations(
Expand All @@ -339,7 +382,7 @@ object Validations {
val effectMethodsInputParams: Seq[Class[_]] = methods
.filter(updateMethodPredicate)
.map(_.getParameterTypes.last) //last because it could be a view update methods with 2 params
Validation(missingErrors(effectMethodsInputParams, eventType, component))
missingEventHandler(effectMethodsInputParams, eventType, component)
} else {
Valid
}
Expand All @@ -350,23 +393,38 @@ object Validations {
.filter(updateMethodPredicate)
.groupBy(findEventSourcedEntityClass)

val errors = effectOutputMethodsGrouped.flatMap { case (entityClass, methods) =>
val eventType = getEventType(entityClass)
if (eventType.isSealed) {
missingErrors(methods.map(_.getParameterTypes.last), eventType, component)
} else {
List.empty
effectOutputMethodsGrouped
.map { case (entityClass, methods) =>
val eventType = getEventType(entityClass)
if (eventType.isSealed) {
missingEventHandler(methods.map(_.getParameterTypes.last), eventType, component)
} else {
Valid
}
}
}
Validation(errors.toSeq)
.fold(Valid)(_ ++ _)
}
}

private def missingErrors(effectOutputInputParams: Seq[Class[_]], eventType: Class[_], component: Class[_]) = {
eventType.getPermittedSubclasses
.filterNot(effectOutputInputParams.contains)
.map(clazz => s"Component '${component.getSimpleName}' is missing an event handler for '${clazz.getName}'")
.toList
private def missingEventHandler(
inputEventParams: Seq[Class[_]],
eventType: Class[_],
component: Class[_]): Validation = {
if (inputEventParams.exists(param => param.isSealed && param == eventType)) {
//single sealed interface handler
Valid
} else {
if (eventType.isSealed) {
//checking possible only for sealed interfaces
Validation(
eventType.getPermittedSubclasses
.filterNot(inputEventParams.contains)
.map(clazz => errorMessage(component, s"missing an event handler for '${clazz.getName}'."))
.toList)
} else {
Valid
}
}
}

private def topicSubscriptionValidations(component: Class[_]): Validation = {
Expand Down Expand Up @@ -662,7 +720,7 @@ object Validations {
val numParams = method.getParameters.length
errorMessage(
method,
s"Method annotated with '@Subscribe.ValueEntity' and handleDeletes=true must not have parameters. Found ${numParams} method parameters.")
s"Method annotated with '@Subscribe.ValueEntity' and handleDeletes=true must not have parameters. Found $numParams method parameters.")
}

Validation(messages)
Expand Down
Loading

0 comments on commit 79cbfbb

Please sign in to comment.