From 37d1775a2cc9eb8f132e94c80444e4c4db3bb195 Mon Sep 17 00:00:00 2001 From: Rodrigo Zhou <2068124+rodrigozhou@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:03:49 -0500 Subject: [PATCH] Support for mocking nexus operations (#1666) * Support for mocking nexus operations * address comments * add tests * add nexus events listeners * add test for nexus listeners * address comments * add assert nexus calls methods * address static checks * Support mocking Nexus operation with operation reference (#1683) * Support mocking Nexus operation with operation reference * address comments * address comments --- internal/internal_worker.go | 18 +- internal/internal_workflow_testsuite.go | 514 ++++++++++++++++++++++-- internal/workflow_testsuite.go | 220 +++++++++- test/nexus_test.go | 302 ++++++++++++++ 4 files changed, 1002 insertions(+), 52 deletions(-) diff --git a/internal/internal_worker.go b/internal/internal_worker.go index 91dfe8c28..23beb0c59 100644 --- a/internal/internal_worker.go +++ b/internal/internal_worker.go @@ -768,6 +768,22 @@ func (r *registry) getWorkflowDefinition(wt WorkflowType) (WorkflowDefinition, e return newSyncWorkflowDefinition(executor), nil } +func (r *registry) getNexusService(service string) *nexus.Service { + r.Lock() + defer r.Unlock() + return r.nexusServices[service] +} + +func (r *registry) getRegisteredNexusServices() []*nexus.Service { + r.Lock() + defer r.Unlock() + result := make([]*nexus.Service, 0, len(r.nexusServices)) + for _, s := range r.nexusServices { + result = append(result, s) + } + return result +} + // Validate function parameters. func validateFnFormat(fnType reflect.Type, isWorkflow bool) error { if fnType.Kind() != reflect.Func { @@ -1058,7 +1074,7 @@ func (aw *AggregatedWorker) start() error { return err } } - nexusServices := aw.registry.nexusServices + nexusServices := aw.registry.getRegisteredNexusServices() if len(nexusServices) > 0 { reg := nexus.NewServiceRegistry() for _, service := range nexusServices { diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 3b9742f73..7cbf77552 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -112,6 +112,20 @@ type ( done bool onCompleted func(*commonpb.Payload, error) onStarted func(opID string, e error) + isMocked bool + } + + testNexusAsyncOperationHandle struct { + result *commonpb.Payload + err error + delay time.Duration + } + + // Interface for nexus.OperationReference without the types as generics. + testNexusOperationReference interface { + Name() string + InputType() reflect.Type + OutputType() reflect.Type } testCallbackHandle struct { @@ -152,6 +166,7 @@ type ( workflowMock *mock.Mock activityMock *mock.Mock + nexusMock *mock.Mock service workflowservice.WorkflowServiceClient logger log.Logger metricsHandler metrics.Handler @@ -172,25 +187,31 @@ type ( timers map[string]*testTimerHandle runningWorkflows map[string]*testWorkflowHandle runningNexusOperations map[int64]*testNexusOperationHandle + nexusAsyncOpHandle map[string]*testNexusAsyncOperationHandle + nexusOperationRefs map[string]map[string]testNexusOperationReference runningCount int expectedWorkflowMockCalls map[string]struct{} expectedActivityMockCalls map[string]struct{} - - onActivityStartedListener func(activityInfo *ActivityInfo, ctx context.Context, args converter.EncodedValues) - onActivityCompletedListener func(activityInfo *ActivityInfo, result converter.EncodedValue, err error) - onActivityCanceledListener func(activityInfo *ActivityInfo) - onLocalActivityStartedListener func(activityInfo *ActivityInfo, ctx context.Context, args []interface{}) - onLocalActivityCompletedListener func(activityInfo *ActivityInfo, result converter.EncodedValue, err error) - onLocalActivityCanceledListener func(activityInfo *ActivityInfo) - onActivityHeartbeatListener func(activityInfo *ActivityInfo, details converter.EncodedValues) - onChildWorkflowStartedListener func(workflowInfo *WorkflowInfo, ctx Context, args converter.EncodedValues) - onChildWorkflowCompletedListener func(workflowInfo *WorkflowInfo, result converter.EncodedValue, err error) - onChildWorkflowCanceledListener func(workflowInfo *WorkflowInfo) - onTimerScheduledListener func(timerID string, duration time.Duration) - onTimerFiredListener func(timerID string) - onTimerCanceledListener func(timerID string) + expectedNexusMockCalls map[string]struct{} + + onActivityStartedListener func(activityInfo *ActivityInfo, ctx context.Context, args converter.EncodedValues) + onActivityCompletedListener func(activityInfo *ActivityInfo, result converter.EncodedValue, err error) + onActivityCanceledListener func(activityInfo *ActivityInfo) + onLocalActivityStartedListener func(activityInfo *ActivityInfo, ctx context.Context, args []interface{}) + onLocalActivityCompletedListener func(activityInfo *ActivityInfo, result converter.EncodedValue, err error) + onLocalActivityCanceledListener func(activityInfo *ActivityInfo) + onActivityHeartbeatListener func(activityInfo *ActivityInfo, details converter.EncodedValues) + onChildWorkflowStartedListener func(workflowInfo *WorkflowInfo, ctx Context, args converter.EncodedValues) + onChildWorkflowCompletedListener func(workflowInfo *WorkflowInfo, result converter.EncodedValue, err error) + onChildWorkflowCanceledListener func(workflowInfo *WorkflowInfo) + onTimerScheduledListener func(timerID string, duration time.Duration) + onTimerFiredListener func(timerID string) + onTimerCanceledListener func(timerID string) + onNexusOperationStartedListener func(service string, operation string, args converter.EncodedValue) + onNexusOperationCompletedListener func(service string, operation string, result converter.EncodedValue, err error) + onNexusOperationCanceledListener func(service string, operation string) } // testWorkflowEnvironmentImpl is the environment that runs the workflow/activity unit tests. @@ -259,10 +280,13 @@ func newTestWorkflowEnvironmentImpl(s *WorkflowTestSuite, parentRegistry *regist localActivities: make(map[string]*localActivityTask), runningWorkflows: make(map[string]*testWorkflowHandle), runningNexusOperations: make(map[int64]*testNexusOperationHandle), + nexusAsyncOpHandle: make(map[string]*testNexusAsyncOperationHandle), + nexusOperationRefs: make(map[string]map[string]testNexusOperationReference), callbackChannel: make(chan testCallbackHandle, 1000), testTimeout: 3 * time.Second, expectedWorkflowMockCalls: make(map[string]struct{}), expectedActivityMockCalls: make(map[string]struct{}), + expectedNexusMockCalls: make(map[string]struct{}), }, workflowInfo: &WorkflowInfo{ @@ -1845,6 +1869,9 @@ func (w *workflowExecutorWrapper) Execute(ctx Context, input *commonpb.Payloads) } func (m *mockWrapper) getCtxArg(ctx interface{}) []interface{} { + if m.fn == nil { + return nil + } fnType := reflect.TypeOf(m.fn) if fnType.NumIn() > 0 { if (!m.isWorkflow && isActivityContext(fnType.In(0))) || @@ -1873,6 +1900,23 @@ func (m *mockWrapper) getWorkflowMockReturn(ctx interface{}, input *commonpb.Pay return m.getMockReturn(ctx, input, m.env.workflowMock) } +func (m *mockWrapper) getNexusMockReturn( + ctx interface{}, + operation string, + input interface{}, + options interface{}, +) (retArgs mock.Arguments) { + if _, ok := m.env.expectedNexusMockCalls[m.name]; !ok { + // no mock + return nil + } + return m.getMockReturnWithActualArgs( + ctx, + []interface{}{operation, input, options}, + m.env.nexusMock, + ) +} + func (m *mockWrapper) getMockReturn(ctx interface{}, input *commonpb.Payloads, envMock *mock.Mock) (retArgs mock.Arguments) { fnType := reflect.TypeOf(m.fn) reflectArgs, err := decodeArgs(m.dataConverter, fnType, input) @@ -2323,18 +2367,10 @@ func (env *testWorkflowEnvironmentImpl) executeChildWorkflowWithDelay(delayStart go childEnv.executeWorkflowInternal(delayStart, params.WorkflowType.Name, params.Input) } -func (env *testWorkflowEnvironmentImpl) newTestNexusTaskHandler() *nexusTaskHandler { - if len(env.registry.nexusServices) == 0 { - panic(fmt.Errorf("no nexus services registered")) - } - - reg := nexus.NewServiceRegistry() - for _, service := range env.registry.nexusServices { - if err := reg.Register(service); err != nil { - panic(fmt.Errorf("failed to register nexus service '%v': %w", service, err)) - } - } - handler, err := reg.NewHandler() +func (env *testWorkflowEnvironmentImpl) newTestNexusTaskHandler( + opHandle *testNexusOperationHandle, +) *nexusTaskHandler { + handler, err := newTestNexusHandler(env, opHandle) if err != nil { panic(fmt.Errorf("failed to create nexus handler: %w", err)) } @@ -2357,7 +2393,6 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation( startedHandler func(opID string, e error), ) int64 { seq := env.nextID() - taskHandler := env.newTestNexusTaskHandler() // Use lower case header values to simulate how the Nexus SDK (used internally by the "real" server) would transmit // these headers over the wire. nexusHeader := make(map[string]string, len(params.nexusHeader)) @@ -2376,7 +2411,8 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation( onCompleted: callback, onStarted: startedHandler, } - env.runningNexusOperations[seq] = handle + taskHandler := env.newTestNexusTaskHandler(handle) + env.setNexusOperationHandle(seq, handle) var opID string if params.options.ScheduleToCloseTimeout > 0 { @@ -2442,9 +2478,11 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation( case *nexuspb.StartOperationResponse_AsyncSuccess: env.postCallback(func() { opID = v.AsyncSuccess.GetOperationId() - handle.startedCallback(v.AsyncSuccess.GetOperationId(), nil) + handle.startedCallback(opID, nil) if handle.cancelRequested { handle.cancel() + } else if handle.isMocked { + env.scheduleNexusAsyncOperationCompletion(handle) } }, true) case *nexuspb.StartOperationResponse_OperationError: @@ -2463,7 +2501,7 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation( } func (env *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) { - handle, ok := env.runningNexusOperations[seq] + handle, ok := env.getNexusOperationHandle(seq) if !ok { panic(fmt.Errorf("no running operation found for sequence: %d", seq)) } @@ -2483,9 +2521,123 @@ func (env *testWorkflowEnvironmentImpl) RequestCancelNexusOperation(seq int64) { } } +func (env *testWorkflowEnvironmentImpl) RegisterNexusAsyncOperationCompletion( + service string, + operation string, + operationID string, + result any, + err error, + delay time.Duration, +) error { + opRef := env.nexusOperationRefs[service][operation] + if opRef == nil { + return fmt.Errorf("nexus service %q operation %q not mocked", service, operation) + } + if reflect.TypeOf(result) != opRef.OutputType() { + return fmt.Errorf( + "nexus service %q operation %q expected result type %s, got %T", + service, + operation, + opRef.OutputType(), + result, + ) + } + + var data *commonpb.Payload + if result != nil { + var encodeErr error + data, encodeErr = env.GetDataConverter().ToPayload(result) + if encodeErr != nil { + return encodeErr + } + } + + // Getting the locker to prevent race condition if this function is called while + // the test env is already running. + env.locker.Lock() + defer env.locker.Unlock() + env.setNexusAsyncOperationCompletionHandle( + service, + operation, + operationID, + &testNexusAsyncOperationHandle{ + result: data, + err: err, + delay: delay, + }, + ) + return nil +} + +func (env *testWorkflowEnvironmentImpl) getNexusAsyncOperationCompletionHandle( + service string, + operation string, + operationID string, +) *testNexusAsyncOperationHandle { + uniqueOpID := env.makeUniqueNexusOperationID(service, operation, operationID) + return env.nexusAsyncOpHandle[uniqueOpID] +} + +func (env *testWorkflowEnvironmentImpl) setNexusAsyncOperationCompletionHandle( + service string, + operation string, + operationID string, + handle *testNexusAsyncOperationHandle, +) { + uniqueOpID := env.makeUniqueNexusOperationID(service, operation, operationID) + env.nexusAsyncOpHandle[uniqueOpID] = handle +} + +func (env *testWorkflowEnvironmentImpl) deleteNexusAsyncOperationCompletionHandle( + service string, + operation string, + operationID string, +) { + uniqueOpID := env.makeUniqueNexusOperationID(service, operation, operationID) + delete(env.nexusAsyncOpHandle, uniqueOpID) +} + +func (env *testWorkflowEnvironmentImpl) scheduleNexusAsyncOperationCompletion( + handle *testNexusOperationHandle, +) { + completionHandle := env.getNexusAsyncOperationCompletionHandle( + handle.params.client.Service(), + handle.params.operation, + handle.operationID, + ) + if completionHandle == nil { + return + } + env.deleteNexusAsyncOperationCompletionHandle( + handle.params.client.Service(), + handle.params.operation, + handle.operationID, + ) + var nexusErr error + if completionHandle.err != nil { + nexusErr = env.failureConverter.FailureToError(nexusOperationFailure( + handle.params, + handle.operationID, + &failurepb.Failure{ + Message: completionHandle.err.Error(), + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + }, + }, + }, + )) + } + env.registerDelayedCallback(func() { + env.postCallback(func() { + handle.completedCallback(completionHandle.result, nexusErr) + }, true) + }, completionHandle.delay) +} + func (env *testWorkflowEnvironmentImpl) resolveNexusOperation(seq int64, result *commonpb.Payload, err error) { env.postCallback(func() { - handle, ok := env.runningNexusOperations[seq] + handle, ok := env.getNexusOperationHandle(seq) if !ok { panic(fmt.Errorf("no running operation found for sequence: %d", seq)) } @@ -2499,6 +2651,32 @@ func (env *testWorkflowEnvironmentImpl) resolveNexusOperation(seq int64, result }, true) } +func (env *testWorkflowEnvironmentImpl) getNexusOperationHandle( + seqID int64, +) (*testNexusOperationHandle, bool) { + handle, ok := env.runningNexusOperations[seqID] + return handle, ok +} + +func (env *testWorkflowEnvironmentImpl) setNexusOperationHandle( + seqID int64, + handle *testNexusOperationHandle, +) { + env.runningNexusOperations[seqID] = handle +} + +func (env *testWorkflowEnvironmentImpl) deleteNexusOperationHandle(seqID int64) { + delete(env.runningNexusOperations, seqID) +} + +func (env *testWorkflowEnvironmentImpl) makeUniqueNexusOperationID( + service string, + operation string, + operationID string, +) string { + return fmt.Sprintf("%s_%s_%s", service, operation, operationID) +} + func (env *testWorkflowEnvironmentImpl) SideEffect(f func() (*commonpb.Payloads, error), callback ResultHandler) { callback(f()) } @@ -2792,6 +2970,18 @@ func (env *testWorkflowEnvironmentImpl) getActivityMockRunFn(callWrapper *MockCa } } +func (env *testWorkflowEnvironmentImpl) getNexusOperationMockRunFn( + callWrapper *MockCallWrapper, +) func(args mock.Arguments) { + env.locker.Lock() + defer env.locker.Unlock() + + env.expectedNexusMockCalls[callWrapper.call.Method] = struct{}{} + return func(args mock.Arguments) { + env.runBeforeMockCallReturns(callWrapper, args) + } +} + func (env *testWorkflowEnvironmentImpl) setLastCompletionResult(result interface{}) { data, err := encodeArg(env.GetDataConverter(), result) if err != nil { @@ -2927,8 +3117,19 @@ func (h *testNexusOperationHandle) completedCallback(result *commonpb.Payload, e return } h.done = true - delete(h.env.runningNexusOperations, h.seq) + h.env.deleteNexusOperationHandle(h.seq) h.onCompleted(result, err) + if h.env.onNexusOperationCompletedListener != nil { + h.env.onNexusOperationCompletedListener( + h.params.client.Service(), + h.params.operation, + newEncodedValue( + &commonpb.Payloads{Payloads: []*commonpb.Payload{result}}, + h.env.GetDataConverter(), + ), + err, + ) + } } // startedCallback is a callback registered to handle operation start. @@ -2954,7 +3155,7 @@ func (h *testNexusOperationHandle) cancel() { } h.env.runningCount++ task := h.newCancelTask() - taskHandler := h.env.newTestNexusTaskHandler() + taskHandler := h.env.newTestNexusTaskHandler(h) go func() { _, failure, err := taskHandler.Execute(task) @@ -2967,6 +3168,253 @@ func (h *testNexusOperationHandle) cancel() { h.completedCallback(nil, fmt.Errorf("operation cancelation handler failed: %v", failure.GetError().GetFailure().GetMessage())) } h.env.runningCount-- + if h.env.onNexusOperationCanceledListener != nil { + h.env.onNexusOperationCanceledListener(h.params.client.Service(), h.params.operation) + } }, false) }() } + +type testNexusHandler struct { + nexus.UnimplementedHandler + + env *testWorkflowEnvironmentImpl + opHandle *testNexusOperationHandle + handler nexus.Handler +} + +func newTestNexusHandler( + env *testWorkflowEnvironmentImpl, + opHandle *testNexusOperationHandle, +) (nexus.Handler, error) { + nexusServices := env.registry.getRegisteredNexusServices() + if len(nexusServices) == 0 { + panic(fmt.Errorf("no nexus services registered")) + } + + reg := nexus.NewServiceRegistry() + for _, service := range nexusServices { + if err := reg.Register(service); err != nil { + return nil, fmt.Errorf("failed to register nexus service '%v': %w", service, err) + } + } + handler, err := reg.NewHandler() + if err != nil { + return nil, fmt.Errorf("failed to create nexus handler: %w", err) + } + return &testNexusHandler{ + env: env, + opHandle: opHandle, + handler: handler, + }, nil +} + +func (r *testNexusHandler) StartOperation( + ctx context.Context, + service string, + operation string, + input *nexus.LazyValue, + options nexus.StartOperationOptions, +) (nexus.HandlerStartOperationResult[any], error) { + s := r.env.registry.getNexusService(service) + if s == nil { + panic(fmt.Sprintf( + "nexus service %q is not registered with the TestWorkflowEnvironment", + service, + )) + } + + opRef := r.env.nexusOperationRefs[service][operation] + op := s.Operation(operation) + if opRef == nil { + if op == nil { + panic(fmt.Sprintf( + "nexus service %q operation %q not registered and not mocked", + service, + operation, + )) + } + opRef = op.(testNexusOperationReference) + } + + inputPtr := reflect.New(opRef.InputType()) + err := input.Consume(inputPtr.Interface()) + if err != nil { + panic("mock of ExecuteNexusOperation failed to deserialize input") + } + + // rebuild the input as *nexus.LazyValue + payload, err := r.env.dataConverter.ToPayload(inputPtr.Elem().Interface()) + if err != nil { + // this should not be possible + panic("mock of ExecuteNexusOperation failed to convert input to payload") + } + serializer := &payloadSerializer{ + converter: r.env.dataConverter, + payload: payload, + } + input = nexus.NewLazyValue( + serializer, + &nexus.Reader{ + ReadCloser: emptyReaderNopCloser, + }, + ) + + if r.env.onNexusOperationStartedListener != nil { + waitCh := make(chan struct{}) + r.env.postCallback(func() { + r.env.onNexusOperationStartedListener( + service, + operation, + newEncodedValue( + &commonpb.Payloads{Payloads: []*commonpb.Payload{payload}}, + r.env.GetDataConverter(), + ), + ) + close(waitCh) + }, false) + <-waitCh // wait until listener returns + } + + m := &mockWrapper{ + env: r.env, + name: service, + fn: nil, + isWorkflow: false, + dataConverter: r.env.dataConverter, + } + mockRet := m.getNexusMockReturn( + ctx, + operation, + inputPtr.Elem().Interface(), + r.opHandle.params.options, + ) + if mockRet != nil { + mockRetLen := len(mockRet) + if mockRetLen != 2 { + panic(fmt.Sprintf( + "mock of ExecuteNexusOperation has incorrect number of return values, expected 2, got %d", + mockRetLen, + )) + } + + // we already verified function has 2 return values (result, error) + mockErr := mockRet[1] // last mock return must be error + if mockErr != nil { + if err, ok := mockErr.(error); ok { + return nil, err + } + panic(fmt.Sprintf( + "mock of ExecuteNexusOperation has incorrect return type, expected error, got %T", + mockErr, + )) + } + + mockResult := mockRet[0] + result, ok := mockResult.(nexus.HandlerStartOperationResult[any]) + if mockResult != nil && !ok { + panic(fmt.Sprintf( + "mock of ExecuteNexusOperation has incorrect return type, expected nexus.HandlerStartOperationResult[T], but actual is %T", + mockResult, + )) + } + + // If the result is nexus.HandlerStartOperationResultSync, check the result value type + // matches the operation return type. + value := reflect.ValueOf(result).Elem().FieldByName("Value") + if (value != reflect.Value{}) { + if value.Type() != opRef.OutputType() { + panic(fmt.Sprintf( + "mock of ExecuteNexusOperation has incorrect return type, operation expects to return %s, got %s", + opRef.OutputType(), + value.Type(), + )) + } + } + + r.opHandle.isMocked = true + return result, nil + } + + return r.handler.StartOperation(ctx, service, operation, input, options) +} + +func (r *testNexusHandler) CancelOperation( + ctx context.Context, + service string, + operation string, + operationID string, + options nexus.CancelOperationOptions, +) error { + if r.opHandle.isMocked { + // if the operation was mocked, then there's no workflow running + return nil + } + return r.handler.CancelOperation(ctx, service, operation, operationID, options) +} + +func (r *testNexusHandler) GetOperationInfo( + ctx context.Context, + service string, + operation string, + operationID string, + options nexus.GetOperationInfoOptions, +) (*nexus.OperationInfo, error) { + return r.handler.GetOperationInfo(ctx, service, operation, operationID, options) +} + +func (r *testNexusHandler) GetOperationResult( + ctx context.Context, + service string, + operation string, + operationID string, + options nexus.GetOperationResultOptions, +) (any, error) { + return r.handler.GetOperationResult(ctx, service, operation, operationID, options) +} + +func (env *testWorkflowEnvironmentImpl) registerNexusOperationReference( + service string, + opRef testNexusOperationReference, +) { + if service == "" { + panic("tried to register a service with no name") + } + if opRef.Name() == "" { + panic("tried to register an operation with no name") + } + m := env.nexusOperationRefs[service] + if m == nil { + m = make(map[string]testNexusOperationReference) + env.nexusOperationRefs[service] = m + } + m[opRef.Name()] = opRef +} + +// testNexusOperation implements nexus.RegisterableOperation and serves as dummy +// operation that can be created from a testNexusOperationReference, so that +// mocked Nexus operations can be registered in a Nexus service. +type testNexusOperation struct { + nexus.UnimplementedOperation[any, any] + testNexusOperationReference +} + +var _ nexus.RegisterableOperation = (*testNexusOperation)(nil) + +func (o *testNexusOperation) Name() string { + return o.testNexusOperationReference.Name() +} + +func (o *testNexusOperation) InputType() reflect.Type { + return o.testNexusOperationReference.InputType() +} + +func (o *testNexusOperation) OutputType() reflect.Type { + return o.testNexusOperationReference.OutputType() +} + +func newTestNexusOperation(opRef testNexusOperationReference) *testNexusOperation { + return &testNexusOperation{ + testNexusOperationReference: opRef, + } +} diff --git a/internal/workflow_testsuite.go b/internal/workflow_testsuite.go index 1b4c2ae9a..4775cf22f 100644 --- a/internal/workflow_testsuite.go +++ b/internal/workflow_testsuite.go @@ -66,6 +66,7 @@ type ( TestWorkflowEnvironment struct { workflowMock mock.Mock activityMock mock.Mock + nexusMock mock.Mock impl *testWorkflowEnvironmentImpl } @@ -543,6 +544,139 @@ func (e *TestWorkflowEnvironment) OnUpsertMemo(attributes interface{}) *MockCall return e.wrapWorkflowCall(call) } +// OnNexusOperation setup a mock call for Nexus operation. +// Parameter service must be Nexus service (*nexus.Service) or service name (string). +// Parameter operation must be Nexus operation (nexus.RegisterableOperation), Nexus operation +// reference (nexus.OperationReference), or operation name (string). +// You must call Return() with appropriate parameters on the returned *MockCallWrapper instance. +// The first parameter of Return() is the result of type nexus.HandlerStartOperationResult[T], ie., +// it must be *nexus.HandlerStartOperationResultSync[T] or *nexus.HandlerStartOperationResultAsync. +// The second parameter of Return() is an error. +// If your mock returns *nexus.HandlerStartOperationResultAsync, then you need to register the +// completion of the async operation by calling RegisterNexusAsyncOperationCompletion. +// Example: assume the Nexus operation input/output types are as follows: +// +// type ( +// HelloInput struct { +// Message string +// } +// HelloOutput struct { +// Message string +// } +// ) +// +// Then, you can mock workflow.NexusClient.ExecuteOperation as follows: +// +// t.OnNexusOperation( +// "my-service", +// nexus.NewOperationReference[HelloInput, HelloOutput]("hello-operation"), +// HelloInput{Message: "Temporal"}, +// mock.Anything, // NexusOperationOptions +// ).Return( +// &nexus.HandlerStartOperationResultAsync{ +// OperationID: "hello-operation-id", +// }, +// nil, +// ) +// t.RegisterNexusAsyncOperationCompletion( +// "service-name", +// "hello-operation", +// "hello-operation-id", +// HelloOutput{Message: "Hello Temporal"}, +// nil, +// 1*time.Second, +// ) +func (e *TestWorkflowEnvironment) OnNexusOperation( + service any, + operation any, + input any, + options any, +) *MockCallWrapper { + var s *nexus.Service + switch stp := service.(type) { + case *nexus.Service: + s = stp + if e.impl.registry.getNexusService(s.Name) == nil { + e.impl.RegisterNexusService(s) + } + case string: + s = e.impl.registry.getNexusService(stp) + if s == nil { + s = nexus.NewService(stp) + e.impl.RegisterNexusService(s) + } + default: + panic("service must be *nexus.Service or string") + } + + var opRef testNexusOperationReference + switch otp := operation.(type) { + case testNexusOperationReference: + // This case covers both nexus.RegisterableOperation and nexus.OperationReference. + // All nexus.RegisterableOperation embeds nexus.UnimplementedOperation which + // implements nexus.OperationReference. + opRef = otp + if s.Operation(opRef.Name()) == nil { + if err := s.Register(newTestNexusOperation(opRef)); err != nil { + panic(fmt.Sprintf("cannot register operation %q: %v", opRef.Name(), err.Error())) + } + } + case string: + if op := s.Operation(otp); op != nil { + opRef = op.(testNexusOperationReference) + } else { + panic(fmt.Sprintf("operation %q not registered in service %q", otp, s.Name)) + } + default: + panic("operation must be nexus.RegisterableOperation, nexus.OperationReference, or string") + } + e.impl.registerNexusOperationReference(s.Name, opRef) + + if input != mock.Anything { + if opRef.InputType() != reflect.TypeOf(input) { + panic(fmt.Sprintf( + "operation %q expects input type %s, got %T", + opRef.Name(), + opRef.InputType(), + input, + )) + } + } + + if options != mock.Anything { + if _, ok := options.(NexusOperationOptions); !ok { + panic(fmt.Sprintf( + "options must be an instance of NexusOperationOptions or mock.Anything, got %T", + options, + )) + } + } + + call := e.nexusMock.On(s.Name, opRef.Name(), input, options) + return e.wrapNexusOperationCall(call) +} + +// RegisterNexusAsyncOperationCompletion registers a delayed completion of an Nexus async operation. +// The delay is counted from the moment the Nexus async operation starts. See the documentation of +// OnNexusOperation for an example. +func (e *TestWorkflowEnvironment) RegisterNexusAsyncOperationCompletion( + service string, + operation string, + operationID string, + result any, + err error, + delay time.Duration, +) error { + return e.impl.RegisterNexusAsyncOperationCompletion( + service, + operation, + operationID, + result, + err, + delay, + ) +} + func (e *TestWorkflowEnvironment) wrapWorkflowCall(call *mock.Call) *MockCallWrapper { callWrapper := &MockCallWrapper{call: call, env: e} call.Run(e.impl.getWorkflowMockRunFn(callWrapper)) @@ -555,6 +689,12 @@ func (e *TestWorkflowEnvironment) wrapActivityCall(call *mock.Call) *MockCallWra return callWrapper } +func (e *TestWorkflowEnvironment) wrapNexusOperationCall(call *mock.Call) *MockCallWrapper { + callWrapper := &MockCallWrapper{call: call, env: e} + call.Run(e.impl.getNexusOperationMockRunFn(callWrapper)) + return callWrapper +} + // Once indicates that the mock should only return the value once. func (c *MockCallWrapper) Once() *MockCallWrapper { return c.Times(1) @@ -632,6 +772,7 @@ func (c *MockCallWrapper) NotBefore(calls ...*MockCallWrapper) *MockCallWrapper func (e *TestWorkflowEnvironment) ExecuteWorkflow(workflowFn interface{}, args ...interface{}) { e.impl.workflowMock = &e.workflowMock e.impl.activityMock = &e.activityMock + e.impl.nexusMock = &e.nexusMock e.impl.executeWorkflow(workflowFn, args...) } @@ -827,6 +968,27 @@ func (e *TestWorkflowEnvironment) SetOnLocalActivityCanceledListener( return e } +func (e *TestWorkflowEnvironment) SetOnNexusOperationStartedListener( + listener func(service string, operation string, input converter.EncodedValue), +) *TestWorkflowEnvironment { + e.impl.onNexusOperationStartedListener = listener + return e +} + +func (e *TestWorkflowEnvironment) SetOnNexusOperationCompletedListener( + listener func(service string, operation string, result converter.EncodedValue, err error), +) *TestWorkflowEnvironment { + e.impl.onNexusOperationCompletedListener = listener + return e +} + +func (e *TestWorkflowEnvironment) SetOnNexusOperationCanceledListener( + listener func(service string, operation string), +) *TestWorkflowEnvironment { + e.impl.onNexusOperationCanceledListener = listener + return e +} + // IsWorkflowCompleted check if test is completed or not func (e *TestWorkflowEnvironment) IsWorkflowCompleted() bool { return e.impl.isWorkflowCompleted @@ -982,13 +1144,15 @@ func (e *TestWorkflowEnvironment) SetTypedSearchAttributesOnStart(searchAttribut return nil } -// AssertExpectations asserts that everything specified with OnActivity -// in fact called as expected. Calls may have occurred in any order. +// AssertExpectations asserts that everything specified with OnWorkflow, OnActivity, OnNexusOperation +// was in fact called as expected. Calls may have occurred in any order. func (e *TestWorkflowEnvironment) AssertExpectations(t mock.TestingT) bool { - return e.workflowMock.AssertExpectations(t) && e.activityMock.AssertExpectations(t) + return e.workflowMock.AssertExpectations(t) && + e.activityMock.AssertExpectations(t) && + e.nexusMock.AssertExpectations(t) } -// AssertCalled asserts that the method was called with the supplied arguments. +// AssertCalled asserts that the method (workflow or activity) was called with the supplied arguments. // Useful to assert that an Activity was called from within a workflow with the expected arguments. // Since the first argument is a context, consider using mock.Anything for that argument. // @@ -999,10 +1163,10 @@ func (e *TestWorkflowEnvironment) AssertExpectations(t mock.TestingT) bool { // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. func (e *TestWorkflowEnvironment) AssertCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool { dummyT := &testing.T{} - if !(e.workflowMock.AssertCalled(dummyT, methodName, arguments...) || e.activityMock.AssertCalled(dummyT, methodName, arguments...)) { - return e.workflowMock.AssertCalled(t, methodName, arguments...) && e.activityMock.AssertCalled(t, methodName, arguments...) - } - return true + return e.AssertWorkflowCalled(dummyT, methodName, arguments...) || + e.AssertActivityCalled(dummyT, methodName, arguments...) || + e.AssertWorkflowCalled(t, methodName, arguments...) || + e.AssertActivityCalled(t, methodName, arguments...) } // AssertWorkflowCalled asserts that the workflow method was called with the supplied arguments. @@ -1017,14 +1181,15 @@ func (e *TestWorkflowEnvironment) AssertActivityCalled(t mock.TestingT, methodNa return e.activityMock.AssertCalled(t, methodName, arguments...) } -// AssertNotCalled asserts that the method was not called with the given arguments. +// AssertNotCalled asserts that the method (workflow or activity) was not called with the given arguments. // See AssertCalled for more info. func (e *TestWorkflowEnvironment) AssertNotCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool { dummyT := &testing.T{} - if !(e.workflowMock.AssertNotCalled(dummyT, methodName, arguments...) || e.activityMock.AssertNotCalled(dummyT, methodName, arguments...)) { - return e.workflowMock.AssertNotCalled(t, methodName, arguments...) && e.activityMock.AssertNotCalled(t, methodName, arguments...) - } - return true + // Calling the individual functions instead of negating AssertCalled so the error message is more clear. + return e.AssertWorkflowNotCalled(dummyT, methodName, arguments...) && + e.AssertActivityNotCalled(dummyT, methodName, arguments...) && + e.AssertWorkflowNotCalled(t, methodName, arguments...) && + e.AssertActivityNotCalled(t, methodName, arguments...) } // AssertWorkflowNotCalled asserts that the workflow method was not called with the given arguments. @@ -1041,13 +1206,13 @@ func (e *TestWorkflowEnvironment) AssertActivityNotCalled(t mock.TestingT, metho return e.activityMock.AssertNotCalled(t, methodName, arguments...) } -// AssertNumberOfCalls asserts that a method was called expectedCalls times. +// AssertNumberOfCalls asserts that a method (workflow or activity) was called expectedCalls times. func (e *TestWorkflowEnvironment) AssertNumberOfCalls(t mock.TestingT, methodName string, expectedCalls int) bool { dummyT := &testing.T{} - if !(e.workflowMock.AssertNumberOfCalls(dummyT, methodName, expectedCalls) || e.activityMock.AssertNumberOfCalls(dummyT, methodName, expectedCalls)) { - return e.workflowMock.AssertNumberOfCalls(t, methodName, expectedCalls) && e.activityMock.AssertNumberOfCalls(t, methodName, expectedCalls) - } - return true + return e.workflowMock.AssertNumberOfCalls(dummyT, methodName, expectedCalls) || + e.activityMock.AssertNumberOfCalls(dummyT, methodName, expectedCalls) || + e.workflowMock.AssertNumberOfCalls(t, methodName, expectedCalls) || + e.activityMock.AssertNumberOfCalls(t, methodName, expectedCalls) } // AssertWorkflowNumberOfCalls asserts that a workflow method was called expectedCalls times. @@ -1061,3 +1226,22 @@ func (e *TestWorkflowEnvironment) AssertWorkflowNumberOfCalls(t mock.TestingT, m func (e *TestWorkflowEnvironment) AssertActivityNumberOfCalls(t mock.TestingT, methodName string, expectedCalls int) bool { return e.activityMock.AssertNumberOfCalls(t, methodName, expectedCalls) } + +// AssertNexusOperationCalled asserts that the Nexus operation was called with the supplied arguments. +// Special method for Nexus operations only. +func (e *TestWorkflowEnvironment) AssertNexusOperationCalled(t mock.TestingT, service string, operation string, input any, options any) bool { + return e.nexusMock.AssertCalled(t, service, operation, input, options) +} + +// AssertNexusOperationNotCalled asserts that the Nexus operation was called with the supplied arguments. +// Special method for Nexus operations only. +// See AssertNexusOperationCalled for more info. +func (e *TestWorkflowEnvironment) AssertNexusOperationNotCalled(t mock.TestingT, service string, operation string, input any, options any) bool { + return e.nexusMock.AssertNotCalled(t, service, operation, input, options) +} + +// AssertNexusOperationNumberOfCalls asserts that a Nexus operation was called expectedCalls times. +// Special method for Nexus operation only. +func (e *TestWorkflowEnvironment) AssertNexusOperationNumberOfCalls(t mock.TestingT, service string, expectedCalls int) bool { + return e.nexusMock.AssertNumberOfCalls(t, service, expectedCalls) +} diff --git a/test/nexus_test.go b/test/nexus_test.go index bbe7c0bd1..1a66e0b1c 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -34,6 +34,7 @@ import ( "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.temporal.io/api/common/v1" "go.temporal.io/api/enums/v1" @@ -43,6 +44,7 @@ import ( "go.temporal.io/api/serviceerror" "go.temporal.io/sdk/client" + "go.temporal.io/sdk/converter" "go.temporal.io/sdk/interceptor" "go.temporal.io/sdk/internal/common/metrics" ilog "go.temporal.io/sdk/internal/log" @@ -1227,6 +1229,306 @@ func TestWorkflowTestSuite_NexusSyncOperation_ClientMethods_Panic(t *testing.T) require.Equal(t, "not implemented in the test environment", panicReason) } +func TestWorkflowTestSuite_MockNexusOperation(t *testing.T) { + serviceName := "test" + dummyOpName := "dummy-operation" + dummyOp := nexus.NewSyncOperation( + dummyOpName, + func(ctx context.Context, name string, opts nexus.StartOperationOptions) (string, error) { + return "Hello " + name, nil + }, + ) + + wf := func(ctx workflow.Context, name string) (string, error) { + client := workflow.NewNexusClient("endpoint", serviceName) + fut := client.ExecuteOperation( + ctx, + dummyOp, + name, + workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: 2 * time.Second, + }, + ) + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return "", err + } + var res string + if err := fut.Get(ctx, &res); err != nil { + return "", err + } + return res, nil + } + + service := nexus.NewService(serviceName) + service.Register(dummyOp) + + t.Run("mock result sync", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation( + service, + dummyOp, + "Temporal", + workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: 2 * time.Second, + }, + ).Return( + &nexus.HandlerStartOperationResultSync[string]{ + Value: "fake result", + }, + nil, + ) + + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var res string + require.NoError(t, env.GetWorkflowResult(&res)) + require.Equal(t, "fake result", res) + + env.AssertExpectations(t) + env.AssertNexusOperationNumberOfCalls(t, service.Name, 1) + env.AssertNexusOperationCalled(t, service.Name, dummyOp.Name(), "Temporal", mock.Anything) + env.AssertNexusOperationNotCalled(t, service.Name, dummyOp.Name(), "random", mock.Anything) + }) + + t.Run("mock result async", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation(service, dummyOp, "Temporal", mock.Anything).Return( + &nexus.HandlerStartOperationResultAsync{ + OperationID: "operation-id", + }, + nil, + ) + require.NoError(t, env.RegisterNexusAsyncOperationCompletion( + service.Name, + dummyOp.Name(), + "operation-id", + "fake result", + nil, + 0, + )) + + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var res string + require.NoError(t, env.GetWorkflowResult(&res)) + require.Equal(t, "fake result", res) + }) + + t.Run("mock operation reference", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.OnNexusOperation( + serviceName, + nexus.NewOperationReference[string, string](dummyOpName), + "Temporal", + mock.Anything, + ).Return( + &nexus.HandlerStartOperationResultSync[string]{ + Value: "fake result", + }, + nil, + ) + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var res string + require.NoError(t, env.GetWorkflowResult(&res)) + require.Equal(t, "fake result", res) + }) + + t.Run("mock operation reference existing service", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation( + serviceName, + nexus.NewOperationReference[string, string](dummyOpName), + "Temporal", + mock.Anything, + ).Return( + &nexus.HandlerStartOperationResultSync[string]{ + Value: "fake result", + }, + nil, + ) + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var res string + require.NoError(t, env.GetWorkflowResult(&res)) + require.Equal(t, "fake result", res) + }) + + t.Run("mock error operation", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation(service, dummyOp, "Temporal", mock.Anything).Return( + nil, + errors.New("workflow operation failed"), + ) + + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.ErrorContains(t, env.GetWorkflowError(), "workflow operation failed") + }) + + t.Run("mock error handler", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation(service, dummyOp, "Temporal", mock.Anything).Return( + &nexus.HandlerStartOperationResultAsync{ + OperationID: "operation-id", + }, + nil, + ) + require.NoError(t, env.RegisterNexusAsyncOperationCompletion( + serviceName, + dummyOpName, + "operation-id", + "", + errors.New("workflow handler failed"), + 1*time.Second, + )) + + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + var execErr *temporal.WorkflowExecutionError + err := env.GetWorkflowError() + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.ErrorContains(t, opErr, "workflow handler failed") + }) + + t.Run("mock after ok", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation( + service, + dummyOp, + "Temporal", + workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: 2 * time.Second, + }, + ).After(1*time.Second).Return( + &nexus.HandlerStartOperationResultSync[string]{ + Value: "fake result", + }, + nil, + ) + + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var res string + require.NoError(t, env.GetWorkflowResult(&res)) + require.Equal(t, "fake result", res) + }) + + t.Run("mock after timeout", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation( + service, + dummyOp, + "Temporal", + workflow.NexusOperationOptions{ + ScheduleToCloseTimeout: 2 * time.Second, + }, + ).After(3*time.Second).Return( + &nexus.HandlerStartOperationResultSync[string]{ + Value: "fake result", + }, + nil, + ) + + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + var execErr *temporal.WorkflowExecutionError + err := env.GetWorkflowError() + require.ErrorAs(t, err, &execErr) + var opErr *temporal.NexusOperationError + err = execErr.Unwrap() + require.ErrorAs(t, err, &opErr) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + err = opErr.Unwrap() + var timeoutErr *temporal.TimeoutError + require.ErrorAs(t, err, &timeoutErr) + require.Equal(t, "operation timed out", timeoutErr.Message()) + }) +} + +func TestWorkflowTestSuite_NexusListeners(t *testing.T) { + startedListenerCalled := false + completedListenerCalled := false + handlerWf := func(ctx workflow.Context, _ nexus.NoValue) (nexus.NoValue, error) { + require.True(t, startedListenerCalled) + require.False(t, completedListenerCalled) + return nil, nil + } + op := temporalnexus.NewWorkflowRunOperation( + "op", + handlerWf, + func( + ctx context.Context, + _ nexus.NoValue, + opts nexus.StartOperationOptions, + ) (client.StartWorkflowOptions, error) { + return client.StartWorkflowOptions{ID: opts.RequestID}, nil + }, + ) + + callerWf := func(ctx workflow.Context) error { + client := workflow.NewNexusClient("endpoint", "test") + fut := client.ExecuteOperation(ctx, op, nil, workflow.NexusOperationOptions{}) + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return err + } + err := fut.Get(ctx, nil) + require.True(t, completedListenerCalled) + return err + } + + service := nexus.NewService("test") + service.Register(op) + + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(handlerWf) + env.RegisterWorkflow(callerWf) + env.RegisterNexusService(service) + + env.SetOnNexusOperationStartedListener( + func(service, operation string, input converter.EncodedValue) { + startedListenerCalled = true + }, + ) + env.SetOnNexusOperationCompletedListener( + func(service, operation string, result converter.EncodedValue, err error) { + completedListenerCalled = true + }, + ) + + env.ExecuteWorkflow(callerWf) + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.True(t, startedListenerCalled) + require.True(t, completedListenerCalled) +} + type nexusInterceptor struct { interceptor.WorkerInterceptorBase interceptor.WorkflowInboundInterceptorBase