Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor proxy.shouldTerminate function and move the functionality to Act.Registry #615

Merged
merged 6 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/rs/zerolog"
"github.com/spf13/cast"
)

type IRegistry interface {
Add(policy *sdkAct.Policy)
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
RunAll(result map[string]any) map[string]any
ShouldTerminate(result map[string]any) bool
}

// Registry keeps track of all policies and actions.
Expand Down Expand Up @@ -402,6 +405,70 @@ func runActionWithTimeout(
}
}

// RunAll run all the actions in the outputs and returns the end result.
func (r *Registry) RunAll(result map[string]any) map[string]any {
if _, exists := result[sdkAct.Outputs]; !exists {
r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is")
return result
}

var (
outputs []*sdkAct.Output
ok bool
)
if outputs, ok = result[sdkAct.Outputs].([]*sdkAct.Output); !ok || len(outputs) == 0 {
r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is")
// If the outputs are nil or empty, we should delete the key from the result.
delete(result, sdkAct.Outputs)
return result
}

endResult := make(map[string]any)
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
r.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
runResult, err := r.Run(output, WithResult(result), WithLogger(r.Logger))
// If the action is async and we received a sentinel error, don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
r.Logger.Error().Err(err).Msg("Error running policy")
}
// Each action should return a map.
if v, ok := runResult.(map[string]any); ok {
endResult = v
} else {
r.Logger.Debug().Msg("Run result is not a map, skipping merging into end result.")
}
mostafa marked this conversation as resolved.
Show resolved Hide resolved
}
return endResult
}

// ShouldTerminate checks if any of the actions are terminal, indicating that the request
// should be terminated.
// This is an optimization to avoid executing the actions' functions unnecessarily.
// The __terminal__ field is only set when an action intends to terminate the request.
func (r *Registry) ShouldTerminate(result map[string]any) bool {
terminalVal, exists := result[sdkAct.Terminal]
if !exists {
r.Logger.Debug().Msg("Terminal key not found, request will continue.")
return false
}

shouldTerminate, ok := terminalVal.(bool)
if !ok {
r.Logger.Debug().Msg("Terminal key exists but cannot be cast to a boolean.")
return false
}

if shouldTerminate {
r.Logger.Debug().Msg("Request is marked as terminal. Terminating.")
}

return shouldTerminate
}

// WithLogger returns a parameter with the Logger to be used by the action.
// This is automatically prepended to the parameters when running an action.
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {
Expand Down
84 changes: 84 additions & 0 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,87 @@ func Test_Run_Timeout(t *testing.T) {
})
}
}

// Test_RunAll_And_ShouldTerminate tests the RunAll function of the act registry
// with a terminal action (and signal).
func Test_RunAll_And_ShouldTerminate(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: BuiltinPolicies(),
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
*sdkAct.Log("info", "testing log via Act", map[string]any{"test": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)

// This is what the hook returns along with "request", "response" and other fields.
// These two keys and values should exist in the result after policy execution.
result := map[string]any{
sdkAct.Outputs: outputs,
sdkAct.Terminal: true,
}

assert.True(t, actRegistry.ShouldTerminate(result))

result = actRegistry.RunAll(result)

time.Sleep(time.Millisecond) // wait for async action to complete

assert.NotEmpty(t, result)
// Terminate action does nothing when run. It is just a signal to terminate.
assert.Contains(t, out.String(),
`{"level":"debug","action":"terminate","executionMode":"sync","message":"Running action"}`)
assert.Contains(t, out.String(),
`{"level":"debug","action":"log","executionMode":"async","message":"Running action"}`)
assert.Contains(t, out.String(), `{"level":"info","test":true,"message":"testing log via Act"}`)
}

// Test_RunAll_Empty_Result tests the RunAll function of the act registry with an empty result.
func Test_RunAll_Empty_Result(t *testing.T) {
out := bytes.Buffer{}
logger := zerolog.New(&out)
actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: BuiltinPolicies(),
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
})
assert.NotNil(t, actRegistry)

results := []map[string]any{
{},
{
sdkAct.Outputs: false, // This is invalid, hence it will be removed.
},
}

for _, result := range results {
assert.False(t, actRegistry.ShouldTerminate(result))

result = actRegistry.RunAll(result)

time.Sleep(time.Millisecond) // wait for async action to complete

assert.Empty(t, result)
}
}
37 changes: 2 additions & 35 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ import (
"errors"
"io"
"net"
"slices"
"time"

sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
"github.com/gatewayd-io/gatewayd/act"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
"github.com/gatewayd-io/gatewayd/metrics"
Expand All @@ -21,9 +18,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/go-co-op/gocron"
"github.com/rs/zerolog"
"github.com/spf13/cast"
"go.opentelemetry.io/otel"
"golang.org/x/exp/maps"
)

//nolint:interfacebloat
Expand Down Expand Up @@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) {
return false, result
}

outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output)
if !ok {
pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type")
return false, result
}

// This is a shortcut to avoid running the actions' functions.
// The Terminal field is only present if the action wants to terminate the request,
// that is the `__terminal__` field is set in one of the outputs.
keys := maps.Keys(result)
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
actionResult := make(map[string]any)
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
pr.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]any); ok {
actionResult = v
}
}
terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result)
actionResult := pr.PluginRegistry.ActRegistry.RunAll(result)
if terminate {
pr.Logger.Debug().Fields(
map[string]any{
Expand Down