Skip to content

Commit

Permalink
feat: add --cost-limit flag (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
matheusfm authored Feb 21, 2025
1 parent 760c0c8 commit 7e5bd08
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
7 changes: 6 additions & 1 deletion pkg/cmd/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type ScanOptions struct {
SkipAnnotation *string
DisableAnnotationSkip *bool
DisableZoraBanner *bool
CostLimit *uint64

ctx context.Context
log logr.Logger
Expand All @@ -73,6 +74,7 @@ func NewScanOptions() *ScanOptions {
DisableAnnotationSkip: pointer.Bool(false),
DisableZoraBanner: pointer.Bool(false),
SkipAnnotation: pointer.String("marvin.undistro.io/skip"),
CostLimit: pointer.Uint64(1000000),
}
}

Expand Down Expand Up @@ -100,6 +102,9 @@ func (o *ScanOptions) AddFlags(flags *pflag.FlagSet) {
if o.DisableZoraBanner != nil {
flags.BoolVar(o.DisableZoraBanner, "disable-zora-banner", *o.DisableZoraBanner, "Disable Zora banner on output")
}
if o.CostLimit != nil {
flags.Uint64Var(o.CostLimit, "cost-limit", *o.CostLimit, "CEL cost limit. Set 0 to disable it.")
}
}

// Init initializes the kubernetes clients, get server version and API resources
Expand Down Expand Up @@ -213,7 +218,7 @@ func (o *ScanOptions) runCheck(check types.Check) *types.CheckResult {
log := o.log.WithValues("check", check.ID)
cr := types.NewCheckResult(check)
defer cr.UpdateStatus()
v, err := validator.Compile(check, o.apiResources, o.kubeVersion)
v, err := validator.Compile(check, o.apiResources, o.kubeVersion, *o.CostLimit)
if err != nil {
log.Error(err, "failed to compile check "+check.ID)
cr.AddError(fmt.Errorf("%s compile error: %s", check.Path, err.Error()))
Expand Down
23 changes: 13 additions & 10 deletions pkg/validator/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ var baseEnvOptions = []cel.EnvOption{

var programOptions = []cel.ProgramOption{
cel.EvalOptions(cel.OptOptimize),
cel.CostLimit(1000000),
cel.InterruptCheckFrequency(100),
}

Expand All @@ -60,7 +59,7 @@ var podSpecEnvOptions = []cel.EnvOption{
}

// Compile compiles variables and expressions of the given check and returns a Validator
func Compile(check types.Check, apiResources []*metav1.APIResourceList, kubeVersion *version.Info) (Validator, error) {
func Compile(check types.Check, apiResources []*metav1.APIResourceList, kubeVersion *version.Info, costLimit uint64) (Validator, error) {
if len(check.Validations) == 0 {
return nil, errors.New("invalid check: a check must have at least 1 validation")
}
Expand All @@ -69,12 +68,12 @@ func Compile(check types.Check, apiResources []*metav1.APIResourceList, kubeVers
return nil, fmt.Errorf("environment construction error %s", err.Error())
}

variables, err := compileVariables(env, check.Variables)
variables, err := compileVariables(env, check.Variables, costLimit)
if err != nil {
return nil, err
}

prgs, err := compileValidations(env, check.Validations)
prgs, err := compileValidations(env, check.Validations, costLimit)
if err != nil {
return nil, err
}
Expand All @@ -100,10 +99,10 @@ func newEnv(check types.Check) (*cel.Env, error) {
return cel.NewEnv(opts...)
}

func compileVariables(env *cel.Env, vars []types.Variable) ([]compiledVariable, error) {
func compileVariables(env *cel.Env, vars []types.Variable, costLimit uint64) ([]compiledVariable, error) {
variables := make([]compiledVariable, 0, len(vars))
for _, v := range vars {
prg, err := compileExpression(env, v.Expression, cel.AnyType)
prg, err := compileExpression(env, v.Expression, costLimit, cel.AnyType)
if err != nil {
return nil, fmt.Errorf("variables[%q].expression: %s", v.Name, err)
}
Expand All @@ -112,10 +111,10 @@ func compileVariables(env *cel.Env, vars []types.Variable) ([]compiledVariable,
return variables, nil
}

func compileValidations(env *cel.Env, vals []types.Validation) ([]cel.Program, error) {
func compileValidations(env *cel.Env, vals []types.Validation, costLimit uint64) ([]cel.Program, error) {
prgs := make([]cel.Program, 0, len(vals))
for i, v := range vals {
prg, err := compileExpression(env, v.Expression, cel.BoolType)
prg, err := compileExpression(env, v.Expression, costLimit, cel.BoolType)
if err != nil {
return nil, fmt.Errorf("validations[%d].expression: %s", i, err)
}
Expand All @@ -124,7 +123,7 @@ func compileValidations(env *cel.Env, vals []types.Validation) ([]cel.Program, e
return prgs, nil
}

func compileExpression(env *cel.Env, exp string, allowedTypes ...*cel.Type) (cel.Program, error) {
func compileExpression(env *cel.Env, exp string, costLimit uint64, allowedTypes ...*cel.Type) (cel.Program, error) {
ast, issues := env.Compile(exp)
if issues != nil && issues.Err() != nil {
return nil, fmt.Errorf("type-check error: %s", issues.Err())
Expand All @@ -142,7 +141,11 @@ func compileExpression(env *cel.Env, exp string, allowedTypes ...*cel.Type) (cel
}
return nil, fmt.Errorf("must evaluate to one of %v", allowedTypes)
}
prg, err := env.Program(ast, programOptions...)
opts := programOptions
if costLimit <= 0 {
opts = append(opts, cel.CostLimit(costLimit))
}
prg, err := env.Program(ast, opts...)
if err != nil {
return nil, fmt.Errorf("program construction error: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/validator/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestCompile(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.check.ID, func(t *testing.T) {
_, err := Compile(tt.check, apiResources, kubeVersion)
_, err := Compile(tt.check, apiResources, kubeVersion, 1000000)
if !tt.wantErr(t, err, fmt.Sprintf("Compile(%v, %v, %v)", tt.check, apiResources, kubeVersion)) {
return
}
Expand Down
2 changes: 1 addition & 1 deletion test/builtins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestBuiltinChecks(t *testing.T) {
assert.True(t, ok)
assert.NotNil(t, check)
assert.NotEmpty(t, check.ID)
v, err := validator.Compile(check, nil, nil)
v, err := validator.Compile(check, nil, nil, 1000000)
assert.NoError(t, err)
assert.NotNil(t, v)
for _, tt := range checkTests {
Expand Down

0 comments on commit 7e5bd08

Please sign in to comment.