From 3f35e13c2268307ddbb296fc0332250e1de2ffca Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sun, 6 Oct 2024 23:27:05 +0200 Subject: [PATCH 1/6] Refactor proxy.shouldTerminate function and move the functionality to Act.Registry Simplify syntax and avoid unnecessary loops (using maps.Keys) Create RunAll and ShouldTerminate functions in Act.Registry --- act/registry.go | 40 ++++++++++++++++++++++++++++++++++++++++ network/proxy.go | 37 ++----------------------------------- 2 files changed, 42 insertions(+), 35 deletions(-) diff --git a/act/registry.go b/act/registry.go index c4cab57a..2e2e9cf1 100644 --- a/act/registry.go +++ b/act/registry.go @@ -12,12 +12,15 @@ import ( "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/rs/zerolog" + "github.com/spf13/cast" ) type IRegistry interface { Add(policy *sdkAct.Policy) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError) + RunAll(result map[string]any) map[string]any + ShouldTerminate(result map[string]any) bool } // Registry keeps track of all policies and actions. @@ -402,6 +405,43 @@ func runActionWithTimeout( } } +// RunAll run all the actions in the outputs and returns the end result. +func (r *Registry) RunAll(result map[string]any) map[string]any { + outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) + if !ok { + r.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") + return nil + } + + endResult := make(map[string]any) + for _, output := range outputs { + if !cast.ToBool(output.Verdict) { + r.Logger.Debug().Msg( + "Skipping the action, because the verdict of the policy execution is false") + continue + } + runResult, err := r.Run(output, WithResult(result)) + // If the action is async and we received a sentinel error, don't log the error. + if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { + r.Logger.Error().Err(err).Msg("Error running policy") + } + // Each action should return a map. + if v, ok := runResult.(map[string]any); ok { + endResult = v + } + } + return endResult +} + +// ShouldTerminate checks if any of the actions are terminal, indicating that the request +// should be terminated. +// This is an optimization to avoid executing the actions' functions unnecessarily. +// The __terminal__ field is only set when an action intends to terminate the request. +func (r *Registry) ShouldTerminate(result map[string]any) bool { + _, ok := result[sdkAct.Terminal] + return ok && cast.ToBool(result[sdkAct.Terminal]) +} + // WithLogger returns a parameter with the Logger to be used by the action. // This is automatically prepended to the parameters when running an action. func WithLogger(logger zerolog.Logger) sdkAct.Parameter { diff --git a/network/proxy.go b/network/proxy.go index 50ad47b4..cdf9b71f 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -6,13 +6,10 @@ import ( "errors" "io" "net" - "slices" "time" - sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" "github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" - "github.com/gatewayd-io/gatewayd/act" "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/metrics" @@ -21,9 +18,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" "github.com/rs/zerolog" - "github.com/spf13/cast" "go.opentelemetry.io/otel" - "golang.org/x/exp/maps" ) //nolint:interfacebloat @@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) { return false, result } - outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) - if !ok { - pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") - return false, result - } - - // This is a shortcut to avoid running the actions' functions. - // The Terminal field is only present if the action wants to terminate the request, - // that is the `__terminal__` field is set in one of the outputs. - keys := maps.Keys(result) - terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal]) - actionResult := make(map[string]any) - for _, output := range outputs { - if !cast.ToBool(output.Verdict) { - pr.Logger.Debug().Msg( - "Skipping the action, because the verdict of the policy execution is false") - continue - } - actRes, err := pr.PluginRegistry.ActRegistry.Run( - output, act.WithResult(result)) - // If the action is async and we received a sentinel error, - // don't log the error. - if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { - pr.Logger.Error().Err(err).Msg("Error running policy") - } - // The terminate action should return a map. - if v, ok := actRes.(map[string]any); ok { - actionResult = v - } - } + terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result) + actionResult := pr.PluginRegistry.ActRegistry.RunAll(result) if terminate { pr.Logger.Debug().Fields( map[string]any{ From 3efa57cc0d5562572f1d350974a3eec0a26c4c26 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 00:06:04 +0200 Subject: [PATCH 2/6] Add logger to RunAll --- act/registry.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/act/registry.go b/act/registry.go index 2e2e9cf1..06dc285b 100644 --- a/act/registry.go +++ b/act/registry.go @@ -420,7 +420,7 @@ func (r *Registry) RunAll(result map[string]any) map[string]any { "Skipping the action, because the verdict of the policy execution is false") continue } - runResult, err := r.Run(output, WithResult(result)) + runResult, err := r.Run(output, WithResult(result), WithLogger(r.Logger)) // If the action is async and we received a sentinel error, don't log the error. if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { r.Logger.Error().Err(err).Msg("Error running policy") From 0ba1ba98ac57fef32b825cac246ee54f9a55b72a Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 00:06:28 +0200 Subject: [PATCH 3/6] Add test for RunAll and ShouldTerminate --- act/registry_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/act/registry_test.go b/act/registry_test.go index d2f54bdb..220b8820 100644 --- a/act/registry_test.go +++ b/act/registry_test.go @@ -930,3 +930,51 @@ func Test_Run_Timeout(t *testing.T) { }) } } + +func Test_RunAll_And_ShouldTerminate(t *testing.T) { + out := bytes.Buffer{} + logger := zerolog.New(&out) + actRegistry := NewActRegistry( + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) + assert.NotNil(t, actRegistry) + + outputs := actRegistry.Apply([]sdkAct.Signal{ + *sdkAct.Terminate(), + *sdkAct.Log("info", "testing log via Act", map[string]any{"test": true}), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }) + assert.NotNil(t, outputs) + + // This is what the hook returns along with "request", "response" and other fields. + // These two keys and values should exist in the result after policy execution. + result := map[string]any{ + sdkAct.Outputs: outputs, + sdkAct.Terminal: true, + } + + assert.True(t, actRegistry.ShouldTerminate(result)) + + result = actRegistry.RunAll(result) + + time.Sleep(time.Millisecond) // wait for async action to complete + + assert.NotEmpty(t, result) + // Terminate action does nothing when run. It is just a signal to terminate. + assert.Contains(t, out.String(), + `{"level":"debug","action":"terminate","executionMode":"sync","message":"Running action"}`) + assert.Contains(t, out.String(), + `{"level":"debug","action":"log","executionMode":"async","message":"Running action"}`) + assert.Contains(t, out.String(), `{"level":"info","test":true,"message":"testing log via Act"}`) +} From 6ab0dd63edc19eca3d14ae0ddbd8e291f4bdede6 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 00:24:38 +0200 Subject: [PATCH 4/6] Separate the check for existence of 'outputs' key from type check --- act/registry.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/act/registry.go b/act/registry.go index 06dc285b..739a371f 100644 --- a/act/registry.go +++ b/act/registry.go @@ -407,10 +407,20 @@ func runActionWithTimeout( // RunAll run all the actions in the outputs and returns the end result. func (r *Registry) RunAll(result map[string]any) map[string]any { - outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output) - if !ok { - r.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type") - return nil + if _, exists := result[sdkAct.Outputs]; !exists { + r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is") + return result + } + + var ( + outputs []*sdkAct.Output + ok bool + ) + if outputs, ok = result[sdkAct.Outputs].([]*sdkAct.Output); !ok || len(outputs) == 0 { + // If the outputs are nil or empty, we should delete the key from the result. + delete(result, sdkAct.Outputs) + r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is") + return result } endResult := make(map[string]any) From 81afcee7c96a33f94b610f9264dc0adf43b6c3d5 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 00:24:52 +0200 Subject: [PATCH 5/6] Add test case for failures --- act/registry_test.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/act/registry_test.go b/act/registry_test.go index 220b8820..53665764 100644 --- a/act/registry_test.go +++ b/act/registry_test.go @@ -931,6 +931,8 @@ func Test_Run_Timeout(t *testing.T) { } } +// Test_RunAll_And_ShouldTerminate tests the RunAll function of the act registry +// with a terminal action (and signal). func Test_RunAll_And_ShouldTerminate(t *testing.T) { out := bytes.Buffer{} logger := zerolog.New(&out) @@ -978,3 +980,37 @@ func Test_RunAll_And_ShouldTerminate(t *testing.T) { `{"level":"debug","action":"log","executionMode":"async","message":"Running action"}`) assert.Contains(t, out.String(), `{"level":"info","test":true,"message":"testing log via Act"}`) } + +// Test_RunAll_Empty_Result tests the RunAll function of the act registry with an empty result. +func Test_RunAll_Empty_Result(t *testing.T) { + out := bytes.Buffer{} + logger := zerolog.New(&out) + actRegistry := NewActRegistry( + Registry{ + Signals: BuiltinSignals(), + Policies: BuiltinPolicies(), + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) + assert.NotNil(t, actRegistry) + + results := []map[string]any{ + {}, + { + sdkAct.Outputs: false, // This is invalid, hence it will be removed. + }, + } + + for _, result := range results { + assert.False(t, actRegistry.ShouldTerminate(result)) + + result = actRegistry.RunAll(result) + + time.Sleep(time.Millisecond) // wait for async action to complete + + assert.Empty(t, result) + } +} From f76cd2b9cfcb0f0e9174376cf36dc24944ac4af7 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Mon, 7 Oct 2024 18:54:35 +0200 Subject: [PATCH 6/6] Address comments by @sinadarbouy --- act/registry.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/act/registry.go b/act/registry.go index 739a371f..68fe97ab 100644 --- a/act/registry.go +++ b/act/registry.go @@ -417,9 +417,9 @@ func (r *Registry) RunAll(result map[string]any) map[string]any { ok bool ) if outputs, ok = result[sdkAct.Outputs].([]*sdkAct.Output); !ok || len(outputs) == 0 { + r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is") // If the outputs are nil or empty, we should delete the key from the result. delete(result, sdkAct.Outputs) - r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is") return result } @@ -438,6 +438,8 @@ func (r *Registry) RunAll(result map[string]any) map[string]any { // Each action should return a map. if v, ok := runResult.(map[string]any); ok { endResult = v + } else { + r.Logger.Debug().Msg("Run result is not a map, skipping merging into end result.") } } return endResult @@ -448,8 +450,23 @@ func (r *Registry) RunAll(result map[string]any) map[string]any { // This is an optimization to avoid executing the actions' functions unnecessarily. // The __terminal__ field is only set when an action intends to terminate the request. func (r *Registry) ShouldTerminate(result map[string]any) bool { - _, ok := result[sdkAct.Terminal] - return ok && cast.ToBool(result[sdkAct.Terminal]) + terminalVal, exists := result[sdkAct.Terminal] + if !exists { + r.Logger.Debug().Msg("Terminal key not found, request will continue.") + return false + } + + shouldTerminate, ok := terminalVal.(bool) + if !ok { + r.Logger.Debug().Msg("Terminal key exists but cannot be cast to a boolean.") + return false + } + + if shouldTerminate { + r.Logger.Debug().Msg("Request is marked as terminal. Terminating.") + } + + return shouldTerminate } // WithLogger returns a parameter with the Logger to be used by the action.