Skip to content

Commit

Permalink
add check for at least one successful workflow invoke and test
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidOrchard committed Sep 30, 2024
1 parent 0149b98 commit c306788
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 6 deletions.
21 changes: 15 additions & 6 deletions core/capabilities/webapi/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (h *triggerConnectorHandler) processTrigger(ctx context.Context, gatewayID
// Pass on the payload with the expectation that it's in an acceptable format for the executor
wrappedPayload, err := values.WrapMap(payload)
if err != nil {
return fmt.Errorf("e rror wrapping payload %s", err)
return fmt.Errorf("error wrapping payload %s", err)
}
topics := payload.Topics

Expand All @@ -90,19 +90,24 @@ func (h *triggerConnectorHandler) processTrigger(ctx context.Context, gatewayID
return fmt.Errorf("empty Workflow Topics")
}

// workflows that have matched topics
matchedWorkflows := 0
// workflows that have matched topic and passed all checks
fullyMatchedWorkflows := 0
for _, trigger := range h.registeredWorkflows {
for _, topic := range topics {
if trigger.allowedTopics[topic] {
matchedWorkflows++

if !trigger.allowedSenders[sender.String()] {
return fmt.Errorf("unauthorized Sender %s, messageID %s", sender.String(), body.MessageId)
err = fmt.Errorf("unauthorized Sender %s, messageID %s", sender.String(), body.MessageId)
h.lggr.Debugw(err.Error())
continue
}
if !trigger.rateLimiter.Allow(body.Sender) {
return fmt.Errorf("request rate-limited for sender %s, messageID %s", sender.String(), body.MessageId)
err = fmt.Errorf("request rate-limited for sender %s, messageID %s", sender.String(), body.MessageId)
continue
}

fullyMatchedWorkflows++
TriggerEventID := body.Sender + payload.TriggerEventID
tr := capabilities.TriggerResponse{
Event: capabilities.TriggerEvent{
Expand All @@ -124,7 +129,11 @@ func (h *triggerConnectorHandler) processTrigger(ctx context.Context, gatewayID
if matchedWorkflows == 0 {
return fmt.Errorf("no Matching Workflow Topics")
}
return nil

if fullyMatchedWorkflows > 0 {
return nil
}
return err
}

func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) {
Expand Down
58 changes: 58 additions & 0 deletions core/capabilities/webapi/trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (
workflowExecutionID1 = "95ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0abbadeed"
owner1 = "0x00000000000000000000000000000000000000aa"
address1 = "0x853d51d5d9935964267a5050aC53aa63ECA39bc5"
address2 = "0x853d51d5d9935964267a5050aC53aa63ECA39bc6"
)

type testHarness struct {
Expand Down Expand Up @@ -299,6 +300,63 @@ func TestRegisterNoAllowedSenders(t *testing.T) {
gatewayRequest(t, privateKey1, `["daily_price_update"]`, "")
}

func TestTriggerExecute2WorkflowsSameTopicDifferentAllowLists(t *testing.T) {
th := setup(t)
ctx := testutils.Context(t)
ctx, cancelContext := context.WithDeadline(ctx, time.Now().Add(10*time.Second))
Config, _ := workflowTriggerConfig(th, []string{address2}, []string{"daily_price_update"})
triggerReq := capabilities.TriggerRegistrationRequest{
TriggerID: triggerID1,
Metadata: capabilities.RequestMetadata{
WorkflowID: workflowID1,
WorkflowOwner: owner1,
},
Config: Config,
}
channel, err := th.trigger.RegisterTrigger(ctx, triggerReq)
require.NoError(t, err)

Config2, err := workflowTriggerConfig(th, []string{address1}, []string{"daily_price_update"})
require.NoError(t, err)

triggerReq2 := capabilities.TriggerRegistrationRequest{
TriggerID: triggerID2,
Metadata: capabilities.RequestMetadata{
WorkflowID: workflowID1,
WorkflowOwner: owner1,
},
Config: Config2,
}
channel2, err := th.trigger.RegisterTrigger(ctx, triggerReq2)
require.NoError(t, err)

t.Run("happy case single topic to single workflow", func(t *testing.T) {
gatewayRequest := gatewayRequest(t, privateKey1, `["daily_price_update"]`, "")

th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
resp, _ := getResponseFromArg(args.Get(2))
require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp)
}).Return(nil).Once()

th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest)

requireNoChanMsg(t, channel)
received, chanErr := requireChanMsg(t, channel2)
require.Equal(t, received.Event.TriggerType, TriggerType)
require.NoError(t, chanErr)
data := received.Event.Outputs
var payload webapicapabilities.TriggerRequestPayload
unwrapErr := data.UnwrapTo(&payload)
require.NoError(t, unwrapErr)
require.Equal(t, payload.Topics, []string{"daily_price_update"})
})
err = th.trigger.UnregisterTrigger(ctx, triggerReq)
require.NoError(t, err)
err = th.trigger.UnregisterTrigger(ctx, triggerReq2)
require.NoError(t, err)
cancelContext()
}

func TestRegisterUnregister(t *testing.T) {
th := setup(t)
ctx := testutils.Context(t)
Expand Down

0 comments on commit c306788

Please sign in to comment.