From 589416cd671ffdc9aff4e38cf25d8634a7cc24d2 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Apr 2024 13:16:48 +0900 Subject: [PATCH] support context propagation - context.Context instance passed in ContextEval can be propagated to binding function to cancel the process. --- cel/cel_test.go | 37 ++++- cel/decls.go | 18 +++ cel/decls_test.go | 5 +- cel/library.go | 13 +- cel/program.go | 16 +- common/decls/decls.go | 83 +++++++---- common/decls/decls_test.go | 61 ++++---- common/functions/functions.go | 24 ++- interpreter/attribute_patterns.go | 21 +-- interpreter/attribute_patterns_test.go | 22 +-- interpreter/attributes.go | 123 ++++++++-------- interpreter/attributes_test.go | 127 +++++++++------- interpreter/decorators.go | 22 +-- interpreter/interpretable.go | 195 +++++++++++++------------ interpreter/interpreter.go | 4 +- interpreter/interpreter_test.go | 24 +-- interpreter/optimizations.go | 3 +- interpreter/planner.go | 11 +- interpreter/prune_test.go | 3 +- interpreter/runtimecost_test.go | 3 +- 20 files changed, 481 insertions(+), 334 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index f386b13b..2c05ee1a 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1014,6 +1014,41 @@ func TestContextEval(t *testing.T) { } } +func TestContextEvalPropagation(t *testing.T) { + env, err := NewEnv(Function("test", + Overload("test_int", []*Type{}, IntType, + FunctionBindingContext(func(ctx context.Context, _ ...ref.Val) ref.Val { + md := ctx.Value("metadata") + if md == nil { + return types.NewErr("cannot find metadata value") + } + return types.Int(md.(int)) + }), + ), + )) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + ast, iss := env.Compile("test()") + if iss.Err() != nil { + t.Fatalf("env.Compile(expr) failed: %v", iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + + expected := 10 + ctx := context.WithValue(context.Background(), "metadata", expected) + out, _, err := prg.ContextEval(ctx, map[string]interface{}{}) + if err != nil { + t.Fatalf("prg.ContextEval() failed: %v", err) + } + if out != types.Int(expected) { + t.Errorf("prg.ContextEval() got %v, but wanted %d", out, expected) + } +} + func BenchmarkContextEval(b *testing.B) { env := testEnv(b, Variable("items", ListType(IntType)), @@ -1428,7 +1463,7 @@ func TestCustomInterpreterDecorator(t *testing.T) { if !lhsIsConst || !rhsIsConst { return i, nil } - val := call.Eval(interpreter.EmptyActivation()) + val := call.Eval(context.Background(), interpreter.EmptyActivation()) if types.IsError(val) { return nil, val.(*types.Err) } diff --git a/cel/decls.go b/cel/decls.go index b59e3708..06de1d8a 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -287,6 +287,24 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt { return decls.FunctionBinding(binding) } +// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt { + return decls.UnaryBindingContext(binding) +} + +// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt { + return decls.BinaryBindingContext(binding) +} + +// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt { + return decls.FunctionBindingContext(binding) +} + // OverloadIsNonStrict enables the function to be called with error and unknown argument values. // // Note: do not use this option unless absoluately necessary as it should be an uncommon feature. diff --git a/cel/decls_test.go b/cel/decls_test.go index f15862fa..e34e696b 100644 --- a/cel/decls_test.go +++ b/cel/decls_test.go @@ -15,6 +15,7 @@ package cel import ( + "context" "fmt" "math" "reflect" @@ -673,7 +674,7 @@ func TestExprDeclToDeclaration(t *testing.T) { } prg, err := e.Program(ast, Functions(&functions.Overload{ Operator: overloads.SizeString, - Unary: func(arg ref.Val) ref.Val { + Unary: func(ctx context.Context, arg ref.Val) ref.Val { str, ok := arg.(types.String) if !ok { return types.MaybeNoSuchOverloadErr(arg) @@ -682,7 +683,7 @@ func TestExprDeclToDeclaration(t *testing.T) { }, }, &functions.Overload{ Operator: overloads.SizeStringInst, - Unary: func(arg ref.Val) ref.Val { + Unary: func(ctx context.Context, arg ref.Val) ref.Val { str, ok := arg.(types.String) if !ok { return types.MaybeNoSuchOverloadErr(arg) diff --git a/cel/library.go b/cel/library.go index 754d91b9..0b75ee05 100644 --- a/cel/library.go +++ b/cel/library.go @@ -15,6 +15,7 @@ package cel import ( + "context" "math" "strconv" "strings" @@ -494,9 +495,9 @@ func (opt *evalOptionalOr) ID() int64 { // Eval evaluates the left-hand side optional to determine whether it contains a value, else // proceeds with the right-hand side evaluation. -func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val { +func (opt *evalOptionalOr) Eval(ctx context.Context, vars interpreter.Activation) ref.Val { // short-circuit lhs. - optLHS := opt.lhs.Eval(ctx) + optLHS := opt.lhs.Eval(ctx, vars) optVal, ok := optLHS.(*types.Optional) if !ok { return optLHS @@ -504,7 +505,7 @@ func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val { if optVal.HasValue() { return optVal } - return opt.rhs.Eval(ctx) + return opt.rhs.Eval(ctx, vars) } // evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value, @@ -522,9 +523,9 @@ func (opt *evalOptionalOrValue) ID() int64 { // Eval evaluates the left-hand side optional to determine whether it contains a value, else // proceeds with the right-hand side evaluation. -func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { +func (opt *evalOptionalOrValue) Eval(ctx context.Context, vars interpreter.Activation) ref.Val { // short-circuit lhs. - optLHS := opt.lhs.Eval(ctx) + optLHS := opt.lhs.Eval(ctx, vars) optVal, ok := optLHS.(*types.Optional) if !ok { return optLHS @@ -532,7 +533,7 @@ func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { if optVal.HasValue() { return optVal.GetValue() } - return opt.rhs.Eval(ctx) + return opt.rhs.Eval(ctx, vars) } type timeUTCLibrary struct{} diff --git a/cel/program.go b/cel/program.go index ece9fbda..455f54c2 100644 --- a/cel/program.go +++ b/cel/program.go @@ -264,6 +264,11 @@ func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorat // Eval implements the Program interface method. func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) { + return p.eval(context.Background(), input) +} + +// Eval implements the Program interface method. +func (p *prog) eval(ctx context.Context, input any) (v ref.Val, det *EvalDetails, err error) { // Configure error recovery for unexpected panics during evaluation. Note, the use of named // return values makes it possible to modify the error response during the recovery // function. @@ -291,7 +296,7 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) { if p.defaultVars != nil { vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars) } - v = p.interpretable.Eval(vars) + v = p.interpretable.Eval(ctx, vars) // The output of an internal Eval may have a value (`v`) that is a types.Err. This step // translates the CEL value to a Go error response. This interface does not quite match the // RPC signature which allows for multiple errors to be returned, but should be sufficient. @@ -321,7 +326,7 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail default: return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input) } - return p.Eval(vars) + return p.eval(ctx, vars) } // progFactory is a helper alias for marking a program creation factory function. @@ -349,6 +354,11 @@ func newProgGen(factory progFactory) (Program, error) { // Eval implements the Program interface method. func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) { + return gen.eval(context.Background(), input) +} + +// Eval implements the Program interface method. +func (gen *progGen) eval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) { // The factory based Eval() differs from the standard evaluation model in that it generates a // new EvalState instance for each call to ensure that unique evaluations yield unique stateful // results. @@ -368,7 +378,7 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) { } // Evaluate the input, returning the result and the 'state' within EvalDetails. - v, _, err := p.Eval(input) + v, _, err := p.ContextEval(ctx, input) if err != nil { return v, det, err } diff --git a/common/decls/decls.go b/common/decls/decls.go index 734ebe57..e49e8acf 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -16,6 +16,7 @@ package decls import ( + "context" "fmt" "strings" @@ -242,7 +243,7 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { // All of the defined overloads are wrapped into a top-level function which // performs dynamic dispatch to the proper overload based on the argument types. bindings := append([]*functions.Overload{}, overloads...) - funcDispatch := func(args ...ref.Val) ref.Val { + funcDispatch := func(ctx context.Context, args ...ref.Val) ref.Val { for _, oID := range f.overloadOrdinals { o := f.overloads[oID] // During dynamic dispatch over multiple functions, signature agreement checks @@ -250,15 +251,15 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { switch len(args) { case 1: if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) { - return o.unaryOp(args[0]) + return o.unaryOp(ctx, args[0]) } case 2: if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) { - return o.binaryOp(args[0], args[1]) + return o.binaryOp(ctx, args[0], args[1]) } } if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) { - return o.functionOp(args...) + return o.functionOp(ctx, args...) } // eventually this will fall through to the noSuchOverload below. } @@ -333,8 +334,10 @@ func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt { return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name()) } f.singleton = &functions.Overload{ - Operator: f.Name(), - Unary: fn, + Operator: f.Name(), + Unary: func(ctx context.Context, val ref.Val) ref.Val { + return fn(val) + }, OperandTrait: trait, } return f, nil @@ -355,8 +358,10 @@ func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt { return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name()) } f.singleton = &functions.Overload{ - Operator: f.Name(), - Binary: fn, + Operator: f.Name(), + Binary: func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val { + return fn(lhs, rhs) + }, OperandTrait: trait, } return f, nil @@ -377,8 +382,10 @@ func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOp return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name()) } f.singleton = &functions.Overload{ - Operator: f.Name(), - Function: fn, + Operator: f.Name(), + Function: func(ctx context.Context, values ...ref.Val) ref.Val { + return fn(values...) + }, OperandTrait: trait, } return f, nil @@ -460,11 +467,11 @@ type OverloadDecl struct { // Function implementation options. Optional, but encouraged. // unaryOp is a function binding that takes a single argument. - unaryOp functions.UnaryOp + unaryOp functions.UnaryContextOp // binaryOp is a function binding that takes two arguments. - binaryOp functions.BinaryOp + binaryOp functions.BinaryContextOp // functionOp is a catch-all for zero-arity and three-plus arity functions. - functionOp functions.FunctionOp + functionOp functions.FunctionContextOp } // ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker @@ -580,41 +587,41 @@ func (o *OverloadDecl) hasBinding() bool { } // guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined. -func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryOp { +func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryContextOp { if o.unaryOp == nil { return nil } - return func(arg ref.Val) ref.Val { + return func(ctx context.Context, arg ref.Val) ref.Val { if !o.matchesRuntimeUnarySignature(disableTypeGuards, arg) { return MaybeNoSuchOverload(funcName, arg) } - return o.unaryOp(arg) + return o.unaryOp(ctx, arg) } } // guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined. -func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryOp { +func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryContextOp { if o.binaryOp == nil { return nil } - return func(arg1, arg2 ref.Val) ref.Val { + return func(ctx context.Context, arg1, arg2 ref.Val) ref.Val { if !o.matchesRuntimeBinarySignature(disableTypeGuards, arg1, arg2) { return MaybeNoSuchOverload(funcName, arg1, arg2) } - return o.binaryOp(arg1, arg2) + return o.binaryOp(ctx, arg1, arg2) } } // guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided. -func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionOp { +func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionContextOp { if o.functionOp == nil { return nil } - return func(args ...ref.Val) ref.Val { + return func(ctx context.Context, args ...ref.Val) ref.Val { if !o.matchesRuntimeSignature(disableTypeGuards, args...) { return MaybeNoSuchOverload(funcName, args...) } - return o.functionOp(args...) + return o.functionOp(ctx, args...) } } @@ -667,6 +674,30 @@ type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error) // UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. func UnaryBinding(binding functions.UnaryOp) OverloadOpt { + return UnaryBindingContext(func(ctx context.Context, val ref.Val) ref.Val { + return binding(val) + }) +} + +// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func BinaryBinding(binding functions.BinaryOp) OverloadOpt { + return BinaryBindingContext(func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val { + return binding(lhs, rhs) + }) +} + +// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func FunctionBinding(binding functions.FunctionOp) OverloadOpt { + return FunctionBindingContext(func(ctx context.Context, values ...ref.Val) ref.Val { + return binding(values...) + }) +} + +// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func UnaryBindingContext(binding functions.UnaryContextOp) OverloadOpt { return func(o *OverloadDecl) (*OverloadDecl, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.ID()) @@ -679,9 +710,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt { } } -// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime +// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. -func BinaryBinding(binding functions.BinaryOp) OverloadOpt { +func BinaryBindingContext(binding functions.BinaryContextOp) OverloadOpt { return func(o *OverloadDecl) (*OverloadDecl, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.ID()) @@ -694,9 +725,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt { } } -// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime +// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime // type-guard which ensures runtime type agreement between the overload signature and runtime argument types. -func FunctionBinding(binding functions.FunctionOp) OverloadOpt { +func FunctionBindingContext(binding functions.FunctionContextOp) OverloadOpt { return func(o *OverloadDecl) (*OverloadDecl, error) { if o.hasBinding() { return nil, fmt.Errorf("overload already has a binding: %s", o.ID()) diff --git a/common/decls/decls_test.go b/common/decls/decls_test.go index 4a017025..77e5095e 100644 --- a/common/decls/decls_test.go +++ b/common/decls/decls_test.go @@ -15,6 +15,7 @@ package decls import ( + "context" "reflect" "strings" "testing" @@ -75,11 +76,11 @@ func TestFunctionBindings(t *testing.T) { t.Errorf("binding missing unary implementation: %v", binding.Operator) continue } - if binding.Unary(in) != types.Int(2) { - t.Errorf("binding invocation got %v, wanted 2", binding.Unary(in)) + if binding.Unary(context.Background(), in) != types.Int(2) { + t.Errorf("binding invocation got %v, wanted 2", binding.Unary(context.Background(), in)) } - if binding.Unary(empty) != types.IntZero { - t.Errorf("binding invocation got %v, wanted 0", binding.Unary(empty)) + if binding.Unary(context.Background(), empty) != types.IntZero { + t.Errorf("binding invocation got %v, wanted 0", binding.Unary(context.Background(), empty)) } } } @@ -121,48 +122,49 @@ func TestFunctionVariableArgBindings(t *testing.T) { if len(bindings) != 4 { t.Errorf("sizeFunc.Bindings() produced %d bindings, wanted 4", len(bindings)) } + ctx := context.Background() in := types.String("hi") sep := types.String("") out := types.DefaultTypeAdapter.NativeToValue([]string{"h", "i"}) for _, binding := range bindings { if binding.Unary != nil { - if binding.Unary(in).Equal(out) != types.True { - t.Errorf("binding invocation got %v, wanted %v", binding.Unary(in), out) + if binding.Unary(ctx, in).Equal(out) != types.True { + t.Errorf("binding invocation got %v, wanted %v", binding.Unary(ctx, in), out) } - celErr := binding.Unary(types.Bytes("hi")) + celErr := binding.Unary(ctx, types.Bytes("hi")) if !types.IsError(celErr) || !strings.Contains(celErr.(*types.Err).String(), "no such overload") { t.Errorf("binding.Unary(bytes) got %v, wanted no such overload", celErr) } } if binding.Binary != nil { - if binding.Binary(in, sep).Equal(out) != types.True { - t.Errorf("binding invocation got %v, wanted %v", binding.Binary(in, sep), out) + if binding.Binary(ctx, in, sep).Equal(out) != types.True { + t.Errorf("binding invocation got %v, wanted %v", binding.Binary(ctx, in, sep), out) } - celErr := binding.Binary(types.Bytes("hi"), sep) + celErr := binding.Binary(ctx, types.Bytes("hi"), sep) if !types.IsError(celErr) || !strings.Contains(celErr.(*types.Err).String(), "no such overload") { t.Errorf("binding.Binary(bytes, string) got %v, wanted no such overload", celErr) } - celUnk := binding.Binary(types.Bytes("hi"), types.NewUnknown(1, types.NewAttributeTrail("x"))) + celUnk := binding.Binary(ctx, types.Bytes("hi"), types.NewUnknown(1, types.NewAttributeTrail("x"))) if !types.IsUnknown(celUnk) { t.Errorf("binding.Binary(bytes, unk) got %v, wanted unknown{1}", celUnk) } } if binding.Function != nil { - if binding.Function(in, sep, types.IntNegOne).Equal(out) != types.True { - t.Errorf("binding invocation got %v, wanted %v", binding.Function(in, sep, types.IntNegOne), out) + if binding.Function(ctx, in, sep, types.IntNegOne).Equal(out) != types.True { + t.Errorf("binding invocation got %v, wanted %v", binding.Function(ctx, in, sep, types.IntNegOne), out) } - celErr := binding.Function(types.Bytes("hi"), sep, types.IntOne) + celErr := binding.Function(ctx, types.Bytes("hi"), sep, types.IntOne) if !types.IsError(celErr) || !strings.Contains(celErr.(*types.Err).String(), "no such overload") { t.Errorf("binding.Function(bytes, string, int) got %v, wanted no such overload", celErr) } if binding.Operator == "split" { - if binding.Function(in).Equal(out) != types.True { - t.Errorf("binding invocation got %v, wanted %v", binding.Function(in), out) + if binding.Function(ctx, in).Equal(out) != types.True { + t.Errorf("binding invocation got %v, wanted %v", binding.Function(ctx, in), out) } - if binding.Function(in, sep).Equal(out) != types.True { - t.Errorf("binding invocation got %v, wanted %v", binding.Function(in, sep), out) + if binding.Function(ctx, in, sep).Equal(out) != types.True { + t.Errorf("binding invocation got %v, wanted %v", binding.Function(ctx, in, sep), out) } - out := binding.Function() + out := binding.Function(ctx) if !types.IsError(out) || out.(*types.Err).String() != "no such overload: split()" { t.Fatalf("binding.Function() got %v, wanted error", out) } @@ -189,7 +191,7 @@ func TestFunctionZeroArityBinding(t *testing.T) { if len(bindings) != 1 { t.Errorf("nowFunc.Bindings() produced %d bindings, wanted one", len(bindings)) } - out := bindings[0].Function() + out := bindings[0].Function(context.Background()) if out != now { t.Errorf("now() got %v, wanted %v", out, now) } @@ -231,12 +233,13 @@ func TestFunctionSingletonBinding(t *testing.T) { if bindings[0].Unary == nil { t.Fatalf("size.Bindings() missing singleton unary binding") } - result := bindings[0].Unary(types.String("hello")) + ctx := context.Background() + result := bindings[0].Unary(ctx, types.String("hello")) if result.Equal(types.Int(5)) != types.True { t.Errorf("size('hello') got %v, wanted 5", result) } // Invalid at type-check, but valid since type guard checks have been disabled - result = bindings[0].Unary(types.Bytes("hello")) + result = bindings[0].Unary(ctx, types.Bytes("hello")) if result.Equal(types.Int(5)) != types.True { t.Errorf("size(b'hello') got %v, wanted 5", result) } @@ -603,16 +606,17 @@ func TestOverloadIsNonStrict(t *testing.T) { if err != nil { t.Fatalf("fn.Binding() failed: %v", err) } + ctx := context.Background() m := types.DefaultTypeAdapter.NativeToValue(map[string]string{"hello": "world"}) - out := bindings[0].Function(m, types.String("hello"), types.String("goodbye")) + out := bindings[0].Function(ctx, m, types.String("hello"), types.String("goodbye")) if out.Equal(types.String("world")) != types.True { t.Errorf("function got %v, wanted 'world'", out) } - out = bindings[0].Function(m, types.String("missing"), types.String("goodbye")) + out = bindings[0].Function(ctx, m, types.String("missing"), types.String("goodbye")) if out.Equal(types.String("goodbye")) != types.True { t.Errorf("function got %v, wanted 'goodbye'", out) } - out = bindings[0].Function(m, types.NewErr("no such key"), types.String("goodbye")) + out = bindings[0].Function(ctx, m, types.NewErr("no such key"), types.String("goodbye")) if out.Equal(types.String("goodbye")) != types.True { t.Errorf("function got %v, wanted 'goodbye'", out) } @@ -647,16 +651,17 @@ func TestOverloadOperandTrait(t *testing.T) { t.Fatalf("fn.Binding() failed: %v", err) } m := types.DefaultTypeAdapter.NativeToValue(map[string]string{"hello": "world"}) - out := bindings[0].Function(m, types.String("hello"), types.String("goodbye")) + ctx := context.Background() + out := bindings[0].Function(ctx, m, types.String("hello"), types.String("goodbye")) if out.Equal(types.String("world")) != types.True { t.Errorf("function got %v, wanted 'world'", out) } - out = bindings[0].Function(m, types.String("missing"), types.String("goodbye")) + out = bindings[0].Function(ctx, m, types.String("missing"), types.String("goodbye")) if out.Equal(types.String("goodbye")) != types.True { t.Errorf("function got %v, wanted 'goodbye'", out) } noSuchKey := types.NewErr("no such key") - out = bindings[0].Function(m, noSuchKey, types.String("goodbye")) + out = bindings[0].Function(ctx, m, noSuchKey, types.String("goodbye")) if out != noSuchKey { t.Errorf("function got %v, wanted 'no such key'", out) } diff --git a/common/functions/functions.go b/common/functions/functions.go index 67f4a594..49afb92c 100644 --- a/common/functions/functions.go +++ b/common/functions/functions.go @@ -15,7 +15,11 @@ // Package functions defines the standard builtin functions supported by the interpreter package functions -import "github.com/google/cel-go/common/types/ref" +import ( + "context" + + "github.com/google/cel-go/common/types/ref" +) // Overload defines a named overload of a function, indicating an operand trait // which must be present on the first argument to the overload as well as one @@ -36,14 +40,14 @@ type Overload struct { OperandTrait int // Unary defines the overload with a UnaryOp implementation. May be nil. - Unary UnaryOp + Unary UnaryContextOp // Binary defines the overload with a BinaryOp implementation. May be nil. - Binary BinaryOp + Binary BinaryContextOp // Function defines the overload with a FunctionOp implementation. May be // nil. - Function FunctionOp + Function FunctionContextOp // NonStrict specifies whether the Overload will tolerate arguments that // are types.Err or types.Unknown. @@ -51,7 +55,7 @@ type Overload struct { } // UnaryOp is a function that takes a single value and produces an output. -type UnaryOp func(value ref.Val) ref.Val +type UnaryOp func(ref.Val) ref.Val // BinaryOp is a function that takes two values and produces an output. type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val @@ -59,3 +63,13 @@ type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val // FunctionOp is a function with accepts zero or more arguments and produces // a value or error as a result. type FunctionOp func(values ...ref.Val) ref.Val + +// UnaryContextOp is a contextual function that takes a single value and produces an output. +type UnaryContextOp func(context.Context, ref.Val) ref.Val + +// BinaryContextOp is a contextual function that takes two values and produces an output. +type BinaryContextOp func(ctx context.Context, lhs ref.Val, rhs ref.Val) ref.Val + +// FunctionContextOp is a contextual function with accepts zero or more arguments and produces +// a value or error as a result. +type FunctionContextOp func(ctx context.Context, values ...ref.Val) ref.Val diff --git a/interpreter/attribute_patterns.go b/interpreter/attribute_patterns.go index 1fbaaf17..70c5e5eb 100644 --- a/interpreter/attribute_patterns.go +++ b/interpreter/attribute_patterns.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "fmt" "github.com/google/cel-go/common/containers" @@ -240,6 +241,7 @@ func (fac *partialAttributeFactory) MaybeAttribute(id int64, name string) Attrib // example, the expression id representing variable `a` would be listed in the Unknown result, // whereas in the other pattern examples, the qualifier `b` would be returned as the Unknown. func (fac *partialAttributeFactory) matchesUnknownPatterns( + ctx context.Context, vars PartialActivation, attrID int64, variableNames []string, @@ -267,7 +269,7 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns( for i, qual := range qualifiers { attr, isAttr := qual.(Attribute) if isAttr { - val, err := attr.Resolve(vars) + val, err := attr.Resolve(ctx, vars) if err != nil { return nil, err } @@ -338,11 +340,11 @@ type attributeMatcher struct { } // AddQualifier implements the Attribute interface method. -func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) { +func (m *attributeMatcher) AddQualifier(ctx context.Context, qual Qualifier) (Attribute, error) { // Add the qualifier to the embedded NamespacedAttribute. If the input to the Resolve // method is not a PartialActivation, or does not match an unknown attribute pattern, the // Resolve method is directly invoked on the underlying NamespacedAttribute. - _, err := m.NamespacedAttribute.AddQualifier(qual) + _, err := m.NamespacedAttribute.AddQualifier(ctx, qual) if err != nil { return nil, err } @@ -357,12 +359,13 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) { // Resolve is an implementation of the NamespacedAttribute interface method which tests // for matching unknown attribute patterns and returns types.Unknown if present. Otherwise, // the standard Resolve logic applies. -func (m *attributeMatcher) Resolve(vars Activation) (any, error) { +func (m *attributeMatcher) Resolve(ctx context.Context, vars Activation) (any, error) { id := m.NamespacedAttribute.ID() // Bug in how partial activation is resolved, should search parents as well. partial, isPartial := toPartialActivation(vars) if isPartial { unk, err := m.fac.matchesUnknownPatterns( + ctx, partial, id, m.CandidateVariableNames(), @@ -374,17 +377,17 @@ func (m *attributeMatcher) Resolve(vars Activation) (any, error) { return unk, nil } } - return m.NamespacedAttribute.Resolve(vars) + return m.NamespacedAttribute.Resolve(ctx, vars) } // Qualify is an implementation of the Qualifier interface method. -func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) { - return attrQualify(m.fac, vars, obj, m) +func (m *attributeMatcher) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + return attrQualify(ctx, m.fac, vars, obj, m) } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly) +func (m *attributeMatcher) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(ctx, m.fac, vars, obj, m, presenceOnly) } func toPartialActivation(vars Activation) (PartialActivation, bool) { diff --git a/interpreter/attribute_patterns_test.go b/interpreter/attribute_patterns_test.go index 67a93f64..845d4a0e 100644 --- a/interpreter/attribute_patterns_test.go +++ b/interpreter/attribute_patterns_test.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "fmt" "testing" @@ -201,7 +202,7 @@ func TestAttributePattern_UnknownResolution(t *testing.T) { fac := NewPartialAttributeFactory(cont, reg, reg) attr := genAttr(fac, m) partVars, _ := NewPartialActivation(EmptyActivation(), tst.pattern) - val, err := attr.Resolve(partVars) + val, err := attr.Resolve(context.Background(), partVars) if err != nil { t.Fatalf("Got error: %s, wanted unknown", err) } @@ -225,7 +226,7 @@ func TestAttributePattern_UnknownResolution(t *testing.T) { fac := NewPartialAttributeFactory(cont, reg, reg) attr := genAttr(fac, m) partVars, _ := NewPartialActivation(EmptyActivation(), tst.pattern) - val, err := attr.Resolve(partVars) + val, err := attr.Resolve(context.Background(), partVars) if err == nil { t.Fatalf("Got value: %s, wanted error", val) } @@ -236,18 +237,19 @@ func TestAttributePattern_UnknownResolution(t *testing.T) { } func TestAttributePattern_CrossReference(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) fac := NewPartialAttributeFactory(containers.DefaultContainer, reg, reg) a := fac.AbsoluteAttribute(1, "a") b := fac.AbsoluteAttribute(2, "b") - a.AddQualifier(b) + a.AddQualifier(ctx, b) // Ensure that var a[b], the dynamic index into var 'a' is the unknown value // returned from attribute resolution. partVars, _ := NewPartialActivation( map[string]any{"a": []int64{1, 2}}, NewAttributePattern("b")) - val, err := a.Resolve(partVars) + val, err := a.Resolve(ctx, partVars) if err != nil { t.Fatal(err) } @@ -263,7 +265,7 @@ func TestAttributePattern_CrossReference(t *testing.T) { map[string]any{"a": []int64{1, 2}}, NewAttributePattern("a").QualInt(0), NewAttributePattern("b")) - val, err = a.Resolve(partVars) + val, err = a.Resolve(ctx, partVars) if err != nil { t.Fatal(err) } @@ -277,7 +279,7 @@ func TestAttributePattern_CrossReference(t *testing.T) { partVars, _ = NewPartialActivation( map[string]any{"a": []int64{1, 2}, "b": 0}, NewAttributePattern("a").QualInt(0).QualString("c")) - val, err = a.Resolve(partVars) + val, err = a.Resolve(ctx, partVars) if err != nil { t.Fatal(err) } @@ -292,7 +294,7 @@ func TestAttributePattern_CrossReference(t *testing.T) { // is the partial attribute factory. partVars, _ = NewPartialActivation( map[string]any{"a": []int64{1, 2}, "b": 0}) - val, err = a.Resolve(partVars) + val, err = a.Resolve(ctx, partVars) if err != nil { t.Fatal(err) } @@ -306,9 +308,9 @@ func TestAttributePattern_CrossReference(t *testing.T) { NewAttributePattern("a").QualInt(0).QualString("c")) // Qualify a[b] with 'c', a[b].c c, _ := fac.NewQualifier(nil, 3, "c", false) - a.AddQualifier(c) + a.AddQualifier(ctx, c) // The resolve step should return unknown - val, err = a.Resolve(partVars) + val, err = a.Resolve(ctx, partVars) if err != nil { t.Fatal(err) } @@ -331,7 +333,7 @@ func genAttr(fac AttributeFactory, a attr) Attribute { } for _, q := range a.quals { qual, _ := fac.NewQualifier(nil, id, q, false) - attr.AddQualifier(qual) + attr.AddQualifier(context.Background(), qual) id++ } return attr diff --git a/interpreter/attributes.go b/interpreter/attributes.go index e505f85e..e81346f3 100644 --- a/interpreter/attributes.go +++ b/interpreter/attributes.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "fmt" "strings" @@ -77,12 +78,12 @@ type Qualifier interface { // Qualify performs a qualification, e.g. field selection, on the input object and returns // the value of the access and whether the value was set. A non-nil value with a false presence // test result indicates that the value being returned is the default value. - Qualify(vars Activation, obj any) (any, error) + Qualify(ctx context.Context, vars Activation, obj any) (any, error) // QualifyIfPresent qualifies the object if the qualifier is declared or defined on the object. // The 'presenceOnly' flag indicates that the value is not necessary, just a boolean status as // to whether the qualifier is present. - QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) + QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) } // ConstantQualifier interface embeds the Qualifier interface and provides an option to inspect the @@ -102,7 +103,7 @@ type Attribute interface { Qualifier // AddQualifier adds a qualifier on the Attribute or error if the qualification is not a valid qualifier type. - AddQualifier(Qualifier) (Attribute, error) + AddQualifier(context.Context, Qualifier) (Attribute, error) // Resolve returns the value of the Attribute and whether it was present given an Activation. // For objects which support safe traversal, the value may be non-nil and the presence flag be false. @@ -110,7 +111,7 @@ type Attribute interface { // If an error is encountered during attribute resolution, it will be returned immediately. // If the attribute cannot be resolved within the Activation, the result must be: `nil`, `error` // with the error indicating which variable was missing. - Resolve(Activation) (any, error) + Resolve(context.Context, Activation) (any, error) } // NamespacedAttribute values are a variable within a namespace, and an optional set of qualifiers @@ -245,7 +246,7 @@ func (a *absoluteAttribute) IsOptional() bool { } // AddQualifier implements the Attribute interface method. -func (a *absoluteAttribute) AddQualifier(qual Qualifier) (Attribute, error) { +func (a *absoluteAttribute) AddQualifier(ctx context.Context, qual Qualifier) (Attribute, error) { a.qualifiers = append(a.qualifiers, qual) return a, nil } @@ -261,13 +262,13 @@ func (a *absoluteAttribute) Qualifiers() []Qualifier { } // Qualify is an implementation of the Qualifier interface method. -func (a *absoluteAttribute) Qualify(vars Activation, obj any) (any, error) { - return attrQualify(a.fac, vars, obj, a) +func (a *absoluteAttribute) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + return attrQualify(ctx, a.fac, vars, obj, a) } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (a *absoluteAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +func (a *absoluteAttribute) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(ctx, a.fac, vars, obj, a, presenceOnly) } // String implements the Stringer interface method. @@ -281,7 +282,7 @@ func (a *absoluteAttribute) String() string { // If the variable name cannot be found as an Activation variable or in the TypeProvider as // a type, then the result is `nil`, `error` with the error indicating the name of the first // variable searched as missing. -func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { +func (a *absoluteAttribute) Resolve(ctx context.Context, vars Activation) (any, error) { for _, nm := range a.namespaceNames { // If the variable is found, process it. Otherwise, wait until the checks to // determine whether the type is unknown before returning. @@ -290,7 +291,7 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { if celErr, ok := obj.(*types.Err); ok { return nil, celErr.Unwrap() } - obj, isOpt, err := applyQualifiers(vars, obj, a.qualifiers) + obj, isOpt, err := applyQualifiers(ctx, vars, obj, a.qualifiers) if err != nil { return nil, err } @@ -349,12 +350,12 @@ func (a *conditionalAttribute) IsOptional() bool { // AddQualifier appends the same qualifier to both sides of the conditional, in effect managing // the qualification of alternate attributes. -func (a *conditionalAttribute) AddQualifier(qual Qualifier) (Attribute, error) { - _, err := a.truthy.AddQualifier(qual) +func (a *conditionalAttribute) AddQualifier(ctx context.Context, qual Qualifier) (Attribute, error) { + _, err := a.truthy.AddQualifier(ctx, qual) if err != nil { return nil, err } - _, err = a.falsy.AddQualifier(qual) + _, err = a.falsy.AddQualifier(ctx, qual) if err != nil { return nil, err } @@ -362,23 +363,23 @@ func (a *conditionalAttribute) AddQualifier(qual Qualifier) (Attribute, error) { } // Qualify is an implementation of the Qualifier interface method. -func (a *conditionalAttribute) Qualify(vars Activation, obj any) (any, error) { - return attrQualify(a.fac, vars, obj, a) +func (a *conditionalAttribute) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + return attrQualify(ctx, a.fac, vars, obj, a) } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (a *conditionalAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +func (a *conditionalAttribute) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(ctx, a.fac, vars, obj, a, presenceOnly) } // Resolve evaluates the condition, and then resolves the truthy or falsy branch accordingly. -func (a *conditionalAttribute) Resolve(vars Activation) (any, error) { - val := a.expr.Eval(vars) +func (a *conditionalAttribute) Resolve(ctx context.Context, vars Activation) (any, error) { + val := a.expr.Eval(ctx, vars) if val == types.True { - return a.truthy.Resolve(vars) + return a.truthy.Resolve(ctx, vars) } if val == types.False { - return a.falsy.Resolve(vars) + return a.falsy.Resolve(ctx, vars) } if types.IsUnknown(val) { return val, nil @@ -435,7 +436,7 @@ func (a *maybeAttribute) IsOptional() bool { // possible field selection -- ns.a['b'], a['b'] // // If none of the attributes within the maybe resolves a value, the result is an error. -func (a *maybeAttribute) AddQualifier(qual Qualifier) (Attribute, error) { +func (a *maybeAttribute) AddQualifier(ctx context.Context, qual Qualifier) (Attribute, error) { str := "" isStr := false cq, isConst := qual.(ConstantQualifier) @@ -452,7 +453,7 @@ func (a *maybeAttribute) AddQualifier(qual Qualifier) (Attribute, error) { augmentedNames[i] = fmt.Sprintf("%s.%s", name, str) } } - _, err := attr.AddQualifier(qual) + _, err := attr.AddQualifier(ctx, qual) if err != nil { return nil, err } @@ -465,21 +466,21 @@ func (a *maybeAttribute) AddQualifier(qual Qualifier) (Attribute, error) { } // Qualify is an implementation of the Qualifier interface method. -func (a *maybeAttribute) Qualify(vars Activation, obj any) (any, error) { - return attrQualify(a.fac, vars, obj, a) +func (a *maybeAttribute) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + return attrQualify(ctx, a.fac, vars, obj, a) } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (a *maybeAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +func (a *maybeAttribute) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(ctx, a.fac, vars, obj, a, presenceOnly) } // Resolve follows the variable resolution rules to determine whether the attribute is a variable // or a field selection. -func (a *maybeAttribute) Resolve(vars Activation) (any, error) { +func (a *maybeAttribute) Resolve(ctx context.Context, vars Activation) (any, error) { var maybeErr error for _, attr := range a.attrs { - obj, err := attr.Resolve(vars) + obj, err := attr.Resolve(ctx, vars) // Return an error if one is encountered. if err != nil { resErr, ok := err.(*resolutionError) @@ -533,32 +534,32 @@ func (a *relativeAttribute) IsOptional() bool { } // AddQualifier implements the Attribute interface method. -func (a *relativeAttribute) AddQualifier(qual Qualifier) (Attribute, error) { +func (a *relativeAttribute) AddQualifier(ctx context.Context, qual Qualifier) (Attribute, error) { a.qualifiers = append(a.qualifiers, qual) return a, nil } // Qualify is an implementation of the Qualifier interface method. -func (a *relativeAttribute) Qualify(vars Activation, obj any) (any, error) { - return attrQualify(a.fac, vars, obj, a) +func (a *relativeAttribute) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + return attrQualify(ctx, a.fac, vars, obj, a) } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (a *relativeAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +func (a *relativeAttribute) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(ctx, a.fac, vars, obj, a, presenceOnly) } // Resolve expression value and qualifier relative to the expression result. -func (a *relativeAttribute) Resolve(vars Activation) (any, error) { +func (a *relativeAttribute) Resolve(ctx context.Context, vars Activation) (any, error) { // First, evaluate the operand. - v := a.operand.Eval(vars) + v := a.operand.Eval(ctx, vars) if types.IsError(v) { return nil, v.(*types.Err) } if types.IsUnknown(v) { return v, nil } - obj, isOpt, err := applyQualifiers(vars, v, a.qualifiers) + obj, isOpt, err := applyQualifiers(ctx, vars, v, a.qualifiers) if err != nil { return nil, err } @@ -705,13 +706,13 @@ func (q *stringQualifier) IsOptional() bool { } // Qualify implements the Qualifier interface method. -func (q *stringQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *stringQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { val, _, err := q.qualifyInternal(vars, obj, false, false) return val, err } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (q *stringQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *stringQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { return q.qualifyInternal(vars, obj, true, presenceOnly) } @@ -806,13 +807,13 @@ func (q *intQualifier) IsOptional() bool { } // Qualify implements the Qualifier interface method. -func (q *intQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *intQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { val, _, err := q.qualifyInternal(vars, obj, false, false) return val, err } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (q *intQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *intQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { return q.qualifyInternal(vars, obj, true, presenceOnly) } @@ -933,13 +934,13 @@ func (q *uintQualifier) IsOptional() bool { } // Qualify implements the Qualifier interface method. -func (q *uintQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *uintQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { val, _, err := q.qualifyInternal(vars, obj, false, false) return val, err } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (q *uintQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *uintQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { return q.qualifyInternal(vars, obj, true, presenceOnly) } @@ -998,13 +999,13 @@ func (q *boolQualifier) IsOptional() bool { } // Qualify implements the Qualifier interface method. -func (q *boolQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *boolQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { val, _, err := q.qualifyInternal(vars, obj, false, false) return val, err } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (q *boolQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *boolQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { return q.qualifyInternal(vars, obj, true, presenceOnly) } @@ -1052,7 +1053,7 @@ func (q *fieldQualifier) IsOptional() bool { } // Qualify implements the Qualifier interface method. -func (q *fieldQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *fieldQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { if rv, ok := obj.(ref.Val); ok { obj = rv.Value() } @@ -1064,7 +1065,7 @@ func (q *fieldQualifier) Qualify(vars Activation, obj any) (any, error) { } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (q *fieldQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *fieldQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { if rv, ok := obj.(ref.Val); ok { obj = rv.Value() } @@ -1110,12 +1111,12 @@ func (q *doubleQualifier) IsOptional() bool { } // Qualify implements the Qualifier interface method. -func (q *doubleQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *doubleQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { val, _, err := q.qualifyInternal(vars, obj, false, false) return val, err } -func (q *doubleQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *doubleQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { return q.qualifyInternal(vars, obj, true, presenceOnly) } @@ -1146,12 +1147,12 @@ func (q *unknownQualifier) IsOptional() bool { } // Qualify returns the unknown value associated with this qualifier. -func (q *unknownQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *unknownQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { return q.value, nil } // QualifyIfPresent is an implementation of the Qualifier interface method. -func (q *unknownQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *unknownQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { return q.value, true, nil } @@ -1160,7 +1161,7 @@ func (q *unknownQualifier) Value() ref.Val { return q.value } -func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, bool, error) { +func applyQualifiers(ctx context.Context, vars Activation, obj any, qualifiers []Qualifier) (any, bool, error) { optObj, isOpt := obj.(*types.Optional) if isOpt { if !optObj.HasValue() { @@ -1175,7 +1176,7 @@ func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, boo isOpt = isOpt || qual.IsOptional() if isOpt { var present bool - qualObj, present, err = qual.QualifyIfPresent(vars, obj, false) + qualObj, present, err = qual.QualifyIfPresent(ctx, vars, obj, false) if err != nil { return nil, false, err } @@ -1186,7 +1187,7 @@ func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, boo return types.OptionalNone, false, nil } } else { - qualObj, err = qual.Qualify(vars, obj) + qualObj, err = qual.Qualify(ctx, vars, obj) if err != nil { return nil, false, err } @@ -1197,8 +1198,8 @@ func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, boo } // attrQualify performs a qualification using the result of an attribute evaluation. -func attrQualify(fac AttributeFactory, vars Activation, obj any, qualAttr Attribute) (any, error) { - val, err := qualAttr.Resolve(vars) +func attrQualify(ctx context.Context, fac AttributeFactory, vars Activation, obj any, qualAttr Attribute) (any, error) { + val, err := qualAttr.Resolve(ctx, vars) if err != nil { return nil, err } @@ -1206,14 +1207,14 @@ func attrQualify(fac AttributeFactory, vars Activation, obj any, qualAttr Attrib if err != nil { return nil, err } - return qual.Qualify(vars, obj) + return qual.Qualify(ctx, vars, obj) } // attrQualifyIfPresent conditionally performs the qualification of the result of attribute is present // on the target object. -func attrQualifyIfPresent(fac AttributeFactory, vars Activation, obj any, qualAttr Attribute, +func attrQualifyIfPresent(ctx context.Context, fac AttributeFactory, vars Activation, obj any, qualAttr Attribute, presenceOnly bool) (any, bool, error) { - val, err := qualAttr.Resolve(vars) + val, err := qualAttr.Resolve(ctx, vars) if err != nil { return nil, false, err } @@ -1221,7 +1222,7 @@ func attrQualifyIfPresent(fac AttributeFactory, vars Activation, obj any, qualAt if err != nil { return nil, false, err } - return qual.QualifyIfPresent(vars, obj, presenceOnly) + return qual.QualifyIfPresent(ctx, vars, obj, presenceOnly) } // refQualify attempts to convert the value to a CEL value and then uses reflection methods to try and diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index b89b2214..acd7f39a 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "errors" "fmt" "reflect" @@ -36,6 +37,7 @@ import ( ) func TestAttributesAbsoluteAttr(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) cont, err := containers.NewContainer(containers.Name("acme.ns")) if err != nil { @@ -57,10 +59,10 @@ func TestAttributesAbsoluteAttr(t *testing.T) { qualB := makeQualifier(t, attrs, nil, 2, "b") qual4 := makeQualifier(t, attrs, nil, 3, uint64(4)) qualFalse := makeQualifier(t, attrs, nil, 4, false) - attr.AddQualifier(qualB) - attr.AddQualifier(qual4) - attr.AddQualifier(qualFalse) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualB) + attr.AddQualifier(ctx, qual4) + attr.AddQualifier(ctx, qualFalse) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -75,7 +77,7 @@ func TestAttributesAbsoluteAttrType(t *testing.T) { // int attr := attrs.AbsoluteAttribute(1, "int") - out, err := attr.Resolve(EmptyActivation()) + out, err := attr.Resolve(context.Background(), EmptyActivation()) if err != nil { t.Fatal(err) } @@ -85,6 +87,7 @@ func TestAttributesAbsoluteAttrType(t *testing.T) { } func TestAttributesAbsoluteAttrError(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) vars, err := NewActivation(map[string]any{ @@ -97,14 +100,15 @@ func TestAttributesAbsoluteAttrError(t *testing.T) { // acme.a.b[4][false] attr := attrs.AbsoluteAttribute(1, "err") qualMsg := makeQualifier(t, attrs, nil, 2, "message") - attr.AddQualifier(qualMsg) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualMsg) + out, err := attr.Resolve(ctx, vars) if err == nil { t.Errorf("attr.Resolve('err') got %v, wanted error", out) } } func TestAttributesRelativeAttr(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) data := map[string]any{ @@ -126,10 +130,10 @@ func TestAttributesRelativeAttr(t *testing.T) { attr := attrs.RelativeAttribute(1, op) qualA := makeQualifier(t, attrs, nil, 2, "a") qualNeg1 := makeQualifier(t, attrs, nil, 3, int64(-1)) - attr.AddQualifier(qualA) - attr.AddQualifier(qualNeg1) - attr.AddQualifier(attrs.AbsoluteAttribute(4, "b")) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualA) + attr.AddQualifier(ctx, qualNeg1) + attr.AddQualifier(ctx, attrs.AbsoluteAttribute(4, "b")) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -139,6 +143,7 @@ func TestAttributesRelativeAttr(t *testing.T) { } func TestAttributesRelativeAttrOneOf(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) cont, err := containers.NewContainer(containers.Name("acme.ns")) if err != nil { @@ -171,10 +176,10 @@ func TestAttributesRelativeAttrOneOf(t *testing.T) { attr := attrs.RelativeAttribute(1, op) qualA := makeQualifier(t, attrs, nil, 2, "a") qualNeg1 := makeQualifier(t, attrs, nil, 3, int64(-1)) - attr.AddQualifier(qualA) - attr.AddQualifier(qualNeg1) - attr.AddQualifier(attrs.MaybeAttribute(4, "b")) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualA) + attr.AddQualifier(ctx, qualNeg1) + attr.AddQualifier(ctx, attrs.MaybeAttribute(4, "b")) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -184,6 +189,7 @@ func TestAttributesRelativeAttrOneOf(t *testing.T) { } func TestAttributesRelativeAttrConditional(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) data := map[string]any{ @@ -211,16 +217,16 @@ func TestAttributesRelativeAttrConditional(t *testing.T) { attrs.AbsoluteAttribute(5, "b"), attrs.AbsoluteAttribute(6, "c")) qual0 := makeQualifier(t, attrs, nil, 7, 0) - condAttr.AddQualifier(qual0) + condAttr.AddQualifier(ctx, qual0) obj := NewConstValue(1, reg.NativeToValue(data)) attr := attrs.RelativeAttribute(1, obj) qualA := makeQualifier(t, attrs, nil, 2, "a") qualNeg1 := makeQualifier(t, attrs, nil, 3, int64(-1)) - attr.AddQualifier(qualA) - attr.AddQualifier(qualNeg1) - attr.AddQualifier(condAttr) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualA) + attr.AddQualifier(ctx, qualNeg1) + attr.AddQualifier(ctx, condAttr) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -230,6 +236,7 @@ func TestAttributesRelativeAttrConditional(t *testing.T) { } func TestAttributesRelativeAttrRelativeQualifier(t *testing.T) { + ctx := context.Background() cont, err := containers.NewContainer(containers.Name("acme.ns")) if err != nil { t.Fatal(err) @@ -282,15 +289,15 @@ func TestAttributesRelativeAttrRelativeQualifier(t *testing.T) { })) relAttr := attrs.RelativeAttribute(4, mp) qualB := makeQualifier(t, attrs, nil, 5, attrs.AbsoluteAttribute(5, "b")) - relAttr.AddQualifier(qualB) + relAttr.AddQualifier(ctx, qualB) attr := attrs.RelativeAttribute(1, obj) qualA := makeQualifier(t, attrs, nil, 2, "a") qualNeg1 := makeQualifier(t, attrs, nil, 3, int64(-1)) - attr.AddQualifier(qualA) - attr.AddQualifier(qualNeg1) - attr.AddQualifier(relAttr) + attr.AddQualifier(ctx, qualA) + attr.AddQualifier(ctx, qualNeg1) + attr.AddQualifier(ctx, relAttr) - out, err := attr.Resolve(vars) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -300,6 +307,7 @@ func TestAttributesRelativeAttrRelativeQualifier(t *testing.T) { } func TestAttributesOneofAttr(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) cont, err := containers.NewContainer(containers.Name("acme.ns")) if err != nil { @@ -318,8 +326,8 @@ func TestAttributesOneofAttr(t *testing.T) { // a.b -> should resolve to acme.ns.a.b per namespace resolution rules. attr := attrs.MaybeAttribute(1, "a") qualB := makeQualifier(t, attrs, nil, 2, "b") - attr.AddQualifier(qualB) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualB) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -329,6 +337,7 @@ func TestAttributesOneofAttr(t *testing.T) { } func TestAttributesConditionalAttrTrueBranch(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) data := map[string]any{ @@ -347,13 +356,13 @@ func TestAttributesConditionalAttrTrueBranch(t *testing.T) { tv := attrs.AbsoluteAttribute(2, "a") fv := attrs.MaybeAttribute(3, "b") qualC := makeQualifier(t, attrs, nil, 4, "c") - fv.AddQualifier(qualC) + fv.AddQualifier(ctx, qualC) cond := attrs.ConditionalAttribute(1, NewConstValue(0, types.True), tv, fv) qualNeg1 := makeQualifier(t, attrs, nil, 5, int64(-1)) qual1 := makeQualifier(t, attrs, nil, 6, int64(1)) - cond.AddQualifier(qualNeg1) - cond.AddQualifier(qual1) - out, err := cond.Resolve(vars) + cond.AddQualifier(ctx, qualNeg1) + cond.AddQualifier(ctx, qual1) + out, err := cond.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -363,6 +372,7 @@ func TestAttributesConditionalAttrTrueBranch(t *testing.T) { } func TestAttributesConditionalAttrFalseBranch(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) data := map[string]any{ @@ -381,13 +391,13 @@ func TestAttributesConditionalAttrFalseBranch(t *testing.T) { tv := attrs.AbsoluteAttribute(2, "a") fv := attrs.MaybeAttribute(3, "b") qualC := makeQualifier(t, attrs, nil, 4, "c") - fv.AddQualifier(qualC) + fv.AddQualifier(ctx, qualC) cond := attrs.ConditionalAttribute(1, NewConstValue(0, types.False), tv, fv) qualNeg1 := makeQualifier(t, attrs, nil, 5, int64(-1)) qual1 := makeQualifier(t, attrs, nil, 6, int64(1)) - cond.AddQualifier(qualNeg1) - cond.AddQualifier(qual1) - out, err := cond.Resolve(vars) + cond.AddQualifier(ctx, qualNeg1) + cond.AddQualifier(ctx, qual1) + out, err := cond.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -715,21 +725,22 @@ func TestAttributesOptional(t *testing.T) { for i, tst := range tests { tc := tst t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ctx := context.Background() i := int64(1) attr := attrs.AbsoluteAttribute(i, tc.varName) for _, q := range tc.quals { i++ - attr.AddQualifier(makeQualifier(t, attrs, nil, i, q)) + attr.AddQualifier(ctx, makeQualifier(t, attrs, nil, i, q)) } for _, oq := range tc.optQuals { i++ - attr.AddQualifier(makeOptQualifier(t, attrs, nil, i, oq)) + attr.AddQualifier(ctx, makeOptQualifier(t, attrs, nil, i, oq)) } vars, err := NewActivation(tc.vars) if err != nil { t.Fatalf("NewActivation() failed: %v", err) } - out, err := attr.Resolve(vars) + out, err := attr.Resolve(ctx, vars) if err != nil { if tc.err != nil { if tc.err.Error() == err.Error() { @@ -747,6 +758,7 @@ func TestAttributesOptional(t *testing.T) { } func TestAttributesConditionalAttrErrorUnknown(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) @@ -754,14 +766,14 @@ func TestAttributesConditionalAttrErrorUnknown(t *testing.T) { tv := attrs.AbsoluteAttribute(2, "a") fv := attrs.MaybeAttribute(3, "b") cond := attrs.ConditionalAttribute(1, NewConstValue(0, types.NewErr("test error")), tv, fv) - out, err := cond.Resolve(EmptyActivation()) + out, err := cond.Resolve(ctx, EmptyActivation()) if err == nil { t.Errorf("Got %v, wanted error", out) } // unk ? a : b condUnk := attrs.ConditionalAttribute(1, NewConstValue(0, types.NewUnknown(1, nil)), tv, fv) - out, err = condUnk.Resolve(EmptyActivation()) + out, err = condUnk.Resolve(ctx, EmptyActivation()) if err != nil { t.Fatal(err) } @@ -771,6 +783,7 @@ func TestAttributesConditionalAttrErrorUnknown(t *testing.T) { } func BenchmarkResolverFieldQualifier(b *testing.B) { + ctx := context.Background() msg := &proto3pb.TestAllTypes{ NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{ SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{ @@ -792,11 +805,11 @@ func BenchmarkResolverFieldQualifier(b *testing.B) { if !found { b.Fatal("FindType() could not find NestedMessage") } - attr.AddQualifier(makeQualifier(b, attrs, testExprTypeToType(b, opType), 2, "single_nested_message")) - attr.AddQualifier(makeQualifier(b, attrs, testExprTypeToType(b, fieldType), 3, "bb")) + attr.AddQualifier(ctx, makeQualifier(b, attrs, testExprTypeToType(b, opType), 2, "single_nested_message")) + attr.AddQualifier(ctx, makeQualifier(b, attrs, testExprTypeToType(b, fieldType), 3, "bb")) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := attr.Resolve(vars) + _, err := attr.Resolve(ctx, vars) if err != nil { b.Fatal(err) } @@ -804,6 +817,7 @@ func BenchmarkResolverFieldQualifier(b *testing.B) { } func TestResolverCustomQualifier(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := &custAttrFactory{ AttributeFactory: NewAttributeFactory(containers.DefaultContainer, reg, reg), @@ -817,8 +831,8 @@ func TestResolverCustomQualifier(t *testing.T) { attr := attrs.AbsoluteAttribute(1, "msg") fieldType := types.NewObjectType("google.expr.proto3.test.TestAllTypes.NestedMessage") qualBB := makeQualifier(t, attrs, fieldType, 2, "bb") - attr.AddQualifier(qualBB) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, qualBB) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Error(err) } @@ -828,6 +842,7 @@ func TestResolverCustomQualifier(t *testing.T) { } func TestAttributesMissingMsg(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) anyPB, _ := anypb.New(&proto3pb.TestAllTypes{}) @@ -838,8 +853,8 @@ func TestAttributesMissingMsg(t *testing.T) { // missing_msg.field attr := attrs.AbsoluteAttribute(1, "missing_msg") field := makeQualifier(t, attrs, nil, 2, "field") - attr.AddQualifier(field) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, field) + out, err := attr.Resolve(ctx, vars) if err == nil { t.Fatalf("got %v, wanted error", out) } @@ -849,6 +864,7 @@ func TestAttributesMissingMsg(t *testing.T) { } func TestAttributeMissingMsgUnknownField(t *testing.T) { + ctx := context.Background() reg := newTestRegistry(t) attrs := NewPartialAttributeFactory(containers.DefaultContainer, reg, reg) anyPB, _ := anypb.New(&proto3pb.TestAllTypes{}) @@ -859,8 +875,8 @@ func TestAttributeMissingMsgUnknownField(t *testing.T) { // missing_msg.field attr := attrs.AbsoluteAttribute(1, "missing_msg") field := makeQualifier(t, attrs, nil, 2, "field") - attr.AddQualifier(field) - out, err := attr.Resolve(vars) + attr.AddQualifier(ctx, field) + out, err := attr.Resolve(ctx, vars) if err != nil { t.Fatal(err) } @@ -1158,7 +1174,7 @@ func TestAttributeStateTracking(t *testing.T) { if err != nil { t.Fatal(err) } - out := i.Eval(in) + out := i.Eval(context.Background(), in) if types.IsUnknown(tc.out) && types.IsUnknown(out) { if !reflect.DeepEqual(tc.out, out) { t.Errorf("got %v, wanted %v", out, tc.out) @@ -1186,6 +1202,7 @@ func TestAttributeStateTracking(t *testing.T) { } func BenchmarkResolverCustomQualifier(b *testing.B) { + ctx := context.Background() reg := newTestRegistry(b) attrs := &custAttrFactory{ AttributeFactory: NewAttributeFactory(containers.DefaultContainer, reg, reg), @@ -1199,9 +1216,9 @@ func BenchmarkResolverCustomQualifier(b *testing.B) { attr := attrs.AbsoluteAttribute(1, "msg") fieldType := types.NewObjectType("google.expr.proto3.test.TestAllTypes.NestedMessage") qualBB := makeQualifier(b, attrs, fieldType, 2, "bb") - attr.AddQualifier(qualBB) + attr.AddQualifier(ctx, qualBB) for i := 0; i < b.N; i++ { - attr.Resolve(vars) + attr.Resolve(ctx, vars) } } @@ -1235,12 +1252,12 @@ func (q *nestedMsgQualifier) IsOptional() bool { return q.opt } -func (q *nestedMsgQualifier) Qualify(vars Activation, obj any) (any, error) { +func (q *nestedMsgQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { pb := obj.(*proto3pb.TestAllTypes_NestedMessage) return pb.GetBb(), nil } -func (q *nestedMsgQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *nestedMsgQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { pb := obj.(*proto3pb.TestAllTypes_NestedMessage) if pb.GetBb() == 0 { return nil, false, nil @@ -1250,7 +1267,7 @@ func (q *nestedMsgQualifier) QualifyIfPresent(vars Activation, obj any, presence func addQualifier(t testing.TB, attr Attribute, qual Qualifier) Attribute { t.Helper() - _, err := attr.AddQualifier(qual) + _, err := attr.AddQualifier(context.Background(), qual) if err != nil { t.Fatalf("attr.AddQualifier(%v) failed: %v", qual, err) } diff --git a/interpreter/decorators.go b/interpreter/decorators.go index 502db35f..cc7f3365 100644 --- a/interpreter/decorators.go +++ b/interpreter/decorators.go @@ -15,6 +15,8 @@ package interpreter import ( + "context" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -104,19 +106,19 @@ func decDisableShortcircuits() InterpretableDecorator { // conditionally precomputing the result. // - build list and map values with constant elements. // - convert 'in' operations to set membership tests if possible. -func decOptimize() InterpretableDecorator { +func decOptimize(ctx context.Context) InterpretableDecorator { return func(i Interpretable) (Interpretable, error) { switch inst := i.(type) { case *evalList: - return maybeBuildListLiteral(i, inst) + return maybeBuildListLiteral(ctx, i, inst) case *evalMap: - return maybeBuildMapLiteral(i, inst) + return maybeBuildMapLiteral(ctx, i, inst) case InterpretableCall: if inst.OverloadID() == overloads.InList { return maybeOptimizeSetMembership(i, inst) } if overloads.IsTypeConversionFunction(inst.Function()) { - return maybeOptimizeConstUnary(i, inst) + return maybeOptimizeConstUnary(ctx, i, inst) } } return i, nil @@ -165,7 +167,7 @@ func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDe } } -func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpretable, error) { +func maybeOptimizeConstUnary(ctx context.Context, i Interpretable, call InterpretableCall) (Interpretable, error) { args := call.Args() if len(args) != 1 { return i, nil @@ -174,24 +176,24 @@ func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpret if !isConst { return i, nil } - val := call.Eval(EmptyActivation()) + val := call.Eval(ctx, EmptyActivation()) if types.IsError(val) { return nil, val.(*types.Err) } return NewConstValue(call.ID(), val), nil } -func maybeBuildListLiteral(i Interpretable, l *evalList) (Interpretable, error) { +func maybeBuildListLiteral(ctx context.Context, i Interpretable, l *evalList) (Interpretable, error) { for _, elem := range l.elems { _, isConst := elem.(InterpretableConst) if !isConst { return i, nil } } - return NewConstValue(l.ID(), l.Eval(EmptyActivation())), nil + return NewConstValue(l.ID(), l.Eval(ctx, EmptyActivation())), nil } -func maybeBuildMapLiteral(i Interpretable, mp *evalMap) (Interpretable, error) { +func maybeBuildMapLiteral(ctx context.Context, i Interpretable, mp *evalMap) (Interpretable, error) { for idx, key := range mp.keys { _, isConst := key.(InterpretableConst) if !isConst { @@ -202,7 +204,7 @@ func maybeBuildMapLiteral(i Interpretable, mp *evalMap) (Interpretable, error) { return i, nil } } - return NewConstValue(mp.ID(), mp.Eval(EmptyActivation())), nil + return NewConstValue(mp.ID(), mp.Eval(ctx, EmptyActivation())), nil } // maybeOptimizeSetMembership may convert an 'in' operation against a list to map key membership diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 56123840..be08f15e 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "fmt" "github.com/google/cel-go/common/functions" @@ -33,7 +34,7 @@ type Interpretable interface { ID() int64 // Eval an Activation to produce an output. - Eval(activation Activation) ref.Val + Eval(ctx context.Context, activation Activation) ref.Val } // InterpretableConst interface for tracking whether the Interpretable is a constant value. @@ -60,22 +61,22 @@ type InterpretableAttribute interface { // Attribute, the Attribute should first be copied before adding the qualifier. Attributes // are not copyable by default, so this is a capable that would need to be added to the // AttributeFactory or specifically to the underlying Attribute implementation. - AddQualifier(Qualifier) (Attribute, error) + AddQualifier(context.Context, Qualifier) (Attribute, error) // Qualify replicates the Attribute.Qualify method to permit extension and interception // of object qualification. - Qualify(vars Activation, obj any) (any, error) + Qualify(ctx context.Context, vars Activation, obj any) (any, error) // QualifyIfPresent qualifies the object if the qualifier is declared or defined on the object. // The 'presenceOnly' flag indicates that the value is not necessary, just a boolean status as // to whether the qualifier is present. - QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) + QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) // IsOptional indicates whether the resulting value is an optional type. IsOptional() bool // Resolve returns the value of the Attribute given the current Activation. - Resolve(Activation) (any, error) + Resolve(context.Context, Activation) (any, error) } // InterpretableCall interface for inspecting Interpretable instructions related to function calls. @@ -121,8 +122,8 @@ func (test *evalTestOnly) ID() int64 { } // Eval implements the Interpretable interface method. -func (test *evalTestOnly) Eval(ctx Activation) ref.Val { - val, err := test.Resolve(ctx) +func (test *evalTestOnly) Eval(ctx context.Context, vars Activation) ref.Val { + val, err := test.Resolve(ctx, vars) // Return an error if the resolve step fails if err != nil { return types.LabelErrNode(test.id, types.WrapErr(err)) @@ -134,12 +135,12 @@ func (test *evalTestOnly) Eval(ctx Activation) ref.Val { } // AddQualifier appends a qualifier that will always and only perform a presence test. -func (test *evalTestOnly) AddQualifier(q Qualifier) (Attribute, error) { +func (test *evalTestOnly) AddQualifier(ctx context.Context, q Qualifier) (Attribute, error) { cq, ok := q.(ConstantQualifier) if !ok { return nil, fmt.Errorf("test only expressions must have constant qualifiers: %v", q) } - return test.InterpretableAttribute.AddQualifier(&testOnlyQualifier{ConstantQualifier: cq}) + return test.InterpretableAttribute.AddQualifier(ctx, &testOnlyQualifier{ConstantQualifier: cq}) } type testOnlyQualifier struct { @@ -147,8 +148,8 @@ type testOnlyQualifier struct { } // Qualify determines whether the test-only qualifier is present on the input object. -func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) { - out, present, err := q.ConstantQualifier.QualifyIfPresent(vars, obj, true) +func (q *testOnlyQualifier) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + out, present, err := q.ConstantQualifier.QualifyIfPresent(ctx, vars, obj, true) if err != nil { return nil, err } @@ -162,9 +163,9 @@ func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) { } // QualifyIfPresent returns whether the target field in the test-only expression is present. -func (q *testOnlyQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { +func (q *testOnlyQualifier) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { // Only ever test for presence. - return q.ConstantQualifier.QualifyIfPresent(vars, obj, true) + return q.ConstantQualifier.QualifyIfPresent(ctx, vars, obj, true) } // QualifierValueEquals determines whether the test-only constant qualifier equals the input value. @@ -192,7 +193,7 @@ func (cons *evalConst) ID() int64 { } // Eval implements the Interpretable interface method. -func (cons *evalConst) Eval(ctx Activation) ref.Val { +func (cons *evalConst) Eval(ctx context.Context, vars Activation) ref.Val { return cons.val } @@ -212,11 +213,11 @@ func (or *evalOr) ID() int64 { } // Eval implements the Interpretable interface method. -func (or *evalOr) Eval(ctx Activation) ref.Val { +func (or *evalOr) Eval(ctx context.Context, vars Activation) ref.Val { var err ref.Val = nil var unk *types.Unknown for _, term := range or.terms { - val := term.Eval(ctx) + val := term.Eval(ctx, vars) boolVal, ok := val.(types.Bool) // short-circuit on true. if ok && boolVal == types.True { @@ -255,11 +256,11 @@ func (and *evalAnd) ID() int64 { } // Eval implements the Interpretable interface method. -func (and *evalAnd) Eval(ctx Activation) ref.Val { +func (and *evalAnd) Eval(ctx context.Context, vars Activation) ref.Val { var err ref.Val = nil var unk *types.Unknown for _, term := range and.terms { - val := term.Eval(ctx) + val := term.Eval(ctx, vars) boolVal, ok := val.(types.Bool) // short-circuit on false. if ok && boolVal == types.False { @@ -299,9 +300,9 @@ func (eq *evalEq) ID() int64 { } // Eval implements the Interpretable interface method. -func (eq *evalEq) Eval(ctx Activation) ref.Val { - lVal := eq.lhs.Eval(ctx) - rVal := eq.rhs.Eval(ctx) +func (eq *evalEq) Eval(ctx context.Context, vars Activation) ref.Val { + lVal := eq.lhs.Eval(ctx, vars) + rVal := eq.rhs.Eval(ctx, vars) if types.IsUnknownOrError(lVal) { return lVal } @@ -338,9 +339,9 @@ func (ne *evalNe) ID() int64 { } // Eval implements the Interpretable interface method. -func (ne *evalNe) Eval(ctx Activation) ref.Val { - lVal := ne.lhs.Eval(ctx) - rVal := ne.rhs.Eval(ctx) +func (ne *evalNe) Eval(ctx context.Context, vars Activation) ref.Val { + lVal := ne.lhs.Eval(ctx, vars) + rVal := ne.rhs.Eval(ctx, vars) if types.IsUnknownOrError(lVal) { return lVal } @@ -369,7 +370,7 @@ type evalZeroArity struct { id int64 function string overload string - impl functions.FunctionOp + impl functions.FunctionContextOp } // ID implements the Interpretable interface method. @@ -378,8 +379,8 @@ func (zero *evalZeroArity) ID() int64 { } // Eval implements the Interpretable interface method. -func (zero *evalZeroArity) Eval(ctx Activation) ref.Val { - return types.LabelErrNode(zero.id, zero.impl()) +func (zero *evalZeroArity) Eval(ctx context.Context, vars Activation) ref.Val { + return types.LabelErrNode(zero.id, zero.impl(ctx)) } // Function implements the InterpretableCall interface method. @@ -403,7 +404,7 @@ type evalUnary struct { overload string arg Interpretable trait int - impl functions.UnaryOp + impl functions.UnaryContextOp nonStrict bool } @@ -413,8 +414,8 @@ func (un *evalUnary) ID() int64 { } // Eval implements the Interpretable interface method. -func (un *evalUnary) Eval(ctx Activation) ref.Val { - argVal := un.arg.Eval(ctx) +func (un *evalUnary) Eval(ctx context.Context, vars Activation) ref.Val { + argVal := un.arg.Eval(ctx, vars) // Early return if the argument to the function is unknown or error. strict := !un.nonStrict if strict && types.IsUnknownOrError(argVal) { @@ -423,7 +424,7 @@ func (un *evalUnary) Eval(ctx Activation) ref.Val { // If the implementation is bound and the argument value has the right traits required to // invoke it, then call the implementation. if un.impl != nil && (un.trait == 0 || (!strict && types.IsUnknownOrError(argVal)) || argVal.Type().HasTrait(un.trait)) { - return types.LabelErrNode(un.id, un.impl(argVal)) + return types.LabelErrNode(un.id, un.impl(ctx, argVal)) } // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the // operand (arg0). @@ -455,7 +456,7 @@ type evalBinary struct { lhs Interpretable rhs Interpretable trait int - impl functions.BinaryOp + impl functions.BinaryContextOp nonStrict bool } @@ -465,9 +466,9 @@ func (bin *evalBinary) ID() int64 { } // Eval implements the Interpretable interface method. -func (bin *evalBinary) Eval(ctx Activation) ref.Val { - lVal := bin.lhs.Eval(ctx) - rVal := bin.rhs.Eval(ctx) +func (bin *evalBinary) Eval(ctx context.Context, vars Activation) ref.Val { + lVal := bin.lhs.Eval(ctx, vars) + rVal := bin.rhs.Eval(ctx, vars) // Early return if any argument to the function is unknown or error. strict := !bin.nonStrict if strict { @@ -481,7 +482,7 @@ func (bin *evalBinary) Eval(ctx Activation) ref.Val { // If the implementation is bound and the argument value has the right traits required to // invoke it, then call the implementation. if bin.impl != nil && (bin.trait == 0 || (!strict && types.IsUnknownOrError(lVal)) || lVal.Type().HasTrait(bin.trait)) { - return types.LabelErrNode(bin.id, bin.impl(lVal, rVal)) + return types.LabelErrNode(bin.id, bin.impl(ctx, lVal, rVal)) } // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the // operand (arg0). @@ -512,12 +513,12 @@ type evalVarArgs struct { overload string args []Interpretable trait int - impl functions.FunctionOp + impl functions.FunctionContextOp nonStrict bool } // NewCall creates a new call Interpretable. -func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionOp) InterpretableCall { +func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionContextOp) InterpretableCall { return &evalVarArgs{ id: id, function: function, @@ -533,12 +534,12 @@ func (fn *evalVarArgs) ID() int64 { } // Eval implements the Interpretable interface method. -func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { +func (fn *evalVarArgs) Eval(ctx context.Context, vars Activation) ref.Val { argVals := make([]ref.Val, len(fn.args)) // Early return if any argument to the function is unknown or error. strict := !fn.nonStrict for i, arg := range fn.args { - argVals[i] = arg.Eval(ctx) + argVals[i] = arg.Eval(ctx, vars) if strict && types.IsUnknownOrError(argVals[i]) { return argVals[i] } @@ -547,7 +548,7 @@ func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { // invoke it, then call the implementation. arg0 := argVals[0] if fn.impl != nil && (fn.trait == 0 || (!strict && types.IsUnknownOrError(arg0)) || arg0.Type().HasTrait(fn.trait)) { - return types.LabelErrNode(fn.id, fn.impl(argVals...)) + return types.LabelErrNode(fn.id, fn.impl(ctx, argVals...)) } // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the // operand (arg0). @@ -586,11 +587,11 @@ func (l *evalList) ID() int64 { } // Eval implements the Interpretable interface method. -func (l *evalList) Eval(ctx Activation) ref.Val { +func (l *evalList) Eval(ctx context.Context, vars Activation) ref.Val { elemVals := make([]ref.Val, 0, len(l.elems)) // If any argument is unknown or error early terminate. for i, elem := range l.elems { - elemVal := elem.Eval(ctx) + elemVal := elem.Eval(ctx, vars) if types.IsUnknownOrError(elemVal) { return elemVal } @@ -632,15 +633,15 @@ func (m *evalMap) ID() int64 { } // Eval implements the Interpretable interface method. -func (m *evalMap) Eval(ctx Activation) ref.Val { +func (m *evalMap) Eval(ctx context.Context, vars Activation) ref.Val { entries := make(map[ref.Val]ref.Val) // If any argument is unknown or error early terminate. for i, key := range m.keys { - keyVal := key.Eval(ctx) + keyVal := key.Eval(ctx, vars) if types.IsUnknownOrError(keyVal) { return keyVal } - valVal := m.vals[i].Eval(ctx) + valVal := m.vals[i].Eval(ctx, vars) if types.IsUnknownOrError(valVal) { return valVal } @@ -696,11 +697,11 @@ func (o *evalObj) ID() int64 { } // Eval implements the Interpretable interface method. -func (o *evalObj) Eval(ctx Activation) ref.Val { +func (o *evalObj) Eval(ctx context.Context, vars Activation) ref.Val { fieldVals := make(map[string]ref.Val) // If any argument is unknown or error early terminate. for i, field := range o.fields { - val := o.vals[i].Eval(ctx) + val := o.vals[i].Eval(ctx, vars) if types.IsUnknownOrError(val) { return val } @@ -748,16 +749,16 @@ func (fold *evalFold) ID() int64 { } // Eval implements the Interpretable interface method. -func (fold *evalFold) Eval(ctx Activation) ref.Val { - foldRange := fold.iterRange.Eval(ctx) +func (fold *evalFold) Eval(ctx context.Context, vars Activation) ref.Val { + foldRange := fold.iterRange.Eval(ctx, vars) if !foldRange.Type().HasTrait(traits.IterableType) { return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange) } // Configure the fold activation with the accumulator initial value. accuCtx := varActivationPool.Get().(*varActivation) - accuCtx.parent = ctx + accuCtx.parent = vars accuCtx.name = fold.accuVar - accuCtx.val = fold.accu.Eval(ctx) + accuCtx.val = fold.accu.Eval(ctx, vars) // If the accumulator starts as an empty list, then the comprehension will build a list // so create a mutable list to optimize the cost of the inner loop. l, ok := accuCtx.val.(traits.Lister) @@ -777,15 +778,15 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val { iterCtx.val = it.Next() // Evaluate the condition, terminate the loop if false. - cond := fold.cond.Eval(iterCtx) + cond := fold.cond.Eval(ctx, iterCtx) condBool, ok := cond.(types.Bool) if !fold.exhaustive && ok && condBool != types.True { break } // Evaluate the evaluation step into accu var. - accuCtx.val = fold.step.Eval(iterCtx) + accuCtx.val = fold.step.Eval(ctx, iterCtx) if fold.interruptable { - if stop, found := ctx.ResolveName("#interrupted"); found && stop == true { + if stop, found := vars.ResolveName("#interrupted"); found && stop == true { interrupted = true break } @@ -798,7 +799,7 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val { } // Compute the result. - res := fold.result.Eval(accuCtx) + res := fold.result.Eval(ctx, accuCtx) varActivationPool.Put(accuCtx) // Convert a mutable list to an immutable one, if the comprehension has generated a list as a result. if !types.IsUnknownOrError(res) && buildingList { @@ -826,8 +827,8 @@ func (e *evalSetMembership) ID() int64 { } // Eval implements the Interpretable interface method. -func (e *evalSetMembership) Eval(ctx Activation) ref.Val { - val := e.arg.Eval(ctx) +func (e *evalSetMembership) Eval(ctx context.Context, vars Activation) ref.Val { + val := e.arg.Eval(ctx, vars) if types.IsUnknownOrError(val) { return val } @@ -845,8 +846,8 @@ type evalWatch struct { } // Eval implements the Interpretable interface method. -func (e *evalWatch) Eval(ctx Activation) ref.Val { - val := e.Interpretable.Eval(ctx) +func (e *evalWatch) Eval(ctx context.Context, vars Activation) ref.Val { + val := e.Interpretable.Eval(ctx, vars) e.observer(e.ID(), e.Interpretable, val) return val } @@ -862,7 +863,7 @@ type evalWatchAttr struct { // AddQualifier creates a wrapper over the incoming qualifier which observes the qualification // result. -func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) { +func (e *evalWatchAttr) AddQualifier(ctx context.Context, q Qualifier) (Attribute, error) { switch qual := q.(type) { // By default, the qualifier is either a constant or an attribute // There may be some custom cases where the attribute is neither. @@ -899,13 +900,13 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) { adapter: e.Adapter(), } } - _, err := e.InterpretableAttribute.AddQualifier(q) + _, err := e.InterpretableAttribute.AddQualifier(ctx, q) return e, err } // Eval implements the Interpretable interface method. -func (e *evalWatchAttr) Eval(vars Activation) ref.Val { - val := e.InterpretableAttribute.Eval(vars) +func (e *evalWatchAttr) Eval(ctx context.Context, vars Activation) ref.Val { + val := e.InterpretableAttribute.Eval(ctx, vars) e.observer(e.ID(), e.InterpretableAttribute, val) return val } @@ -919,8 +920,8 @@ type evalWatchConstQual struct { } // Qualify observes the qualification of a object via a constant boolean, int, string, or uint. -func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) { - out, err := e.ConstantQualifier.Qualify(vars, obj) +func (e *evalWatchConstQual) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + out, err := e.ConstantQualifier.Qualify(ctx, vars, obj) var val ref.Val if err != nil { val = types.LabelErrNode(e.ID(), types.WrapErr(err)) @@ -932,8 +933,8 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) { } // QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. -func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly) +func (e *evalWatchConstQual) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.ConstantQualifier.QualifyIfPresent(ctx, vars, obj, presenceOnly) var val ref.Val if err != nil { val = types.LabelErrNode(e.ID(), types.WrapErr(err)) @@ -962,8 +963,8 @@ type evalWatchAttrQual struct { } // Qualify observes the qualification of a object via a value computed at runtime. -func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) { - out, err := e.Attribute.Qualify(vars, obj) +func (e *evalWatchAttrQual) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + out, err := e.Attribute.Qualify(ctx, vars, obj) var val ref.Val if err != nil { val = types.LabelErrNode(e.ID(), types.WrapErr(err)) @@ -975,8 +976,8 @@ func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) { } // QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. -func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - out, present, err := e.Attribute.QualifyIfPresent(vars, obj, presenceOnly) +func (e *evalWatchAttrQual) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.Attribute.QualifyIfPresent(ctx, vars, obj, presenceOnly) var val ref.Val if err != nil { val = types.LabelErrNode(e.ID(), types.WrapErr(err)) @@ -999,8 +1000,8 @@ type evalWatchQual struct { } // Qualify observes the qualification of a object via a value computed at runtime. -func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) { - out, err := e.Qualifier.Qualify(vars, obj) +func (e *evalWatchQual) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + out, err := e.Qualifier.Qualify(ctx, vars, obj) var val ref.Val if err != nil { val = types.LabelErrNode(e.ID(), types.WrapErr(err)) @@ -1012,8 +1013,8 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) { } // QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. -func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { - out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly) +func (e *evalWatchQual) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.Qualifier.QualifyIfPresent(ctx, vars, obj, presenceOnly) var val ref.Val if err != nil { val = types.LabelErrNode(e.ID(), types.WrapErr(err)) @@ -1035,7 +1036,7 @@ type evalWatchConst struct { } // Eval implements the Interpretable interface method. -func (e *evalWatchConst) Eval(vars Activation) ref.Val { +func (e *evalWatchConst) Eval(ctx context.Context, vars Activation) ref.Val { val := e.Value() e.observer(e.ID(), e.InterpretableConst, val) return val @@ -1053,12 +1054,12 @@ func (or *evalExhaustiveOr) ID() int64 { } // Eval implements the Interpretable interface method. -func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val { +func (or *evalExhaustiveOr) Eval(ctx context.Context, vars Activation) ref.Val { var err ref.Val = nil var unk *types.Unknown isTrue := false for _, term := range or.terms { - val := term.Eval(ctx) + val := term.Eval(ctx, vars) boolVal, ok := val.(types.Bool) // flag the result as true if ok && boolVal == types.True { @@ -1100,12 +1101,12 @@ func (and *evalExhaustiveAnd) ID() int64 { } // Eval implements the Interpretable interface method. -func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val { +func (and *evalExhaustiveAnd) Eval(ctx context.Context, vars Activation) ref.Val { var err ref.Val = nil var unk *types.Unknown isFalse := false for _, term := range and.terms { - val := term.Eval(ctx) + val := term.Eval(ctx, vars) boolVal, ok := val.(types.Bool) // short-circuit on false. if ok && boolVal == types.False { @@ -1149,10 +1150,10 @@ func (cond *evalExhaustiveConditional) ID() int64 { } // Eval implements the Interpretable interface method. -func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { - cVal := cond.attr.expr.Eval(ctx) - tVal, tErr := cond.attr.truthy.Resolve(ctx) - fVal, fErr := cond.attr.falsy.Resolve(ctx) +func (cond *evalExhaustiveConditional) Eval(ctx context.Context, vars Activation) ref.Val { + cVal := cond.attr.expr.Eval(ctx, vars) + tVal, tErr := cond.attr.truthy.Resolve(ctx, vars) + fVal, fErr := cond.attr.falsy.Resolve(ctx, vars) cBool, ok := cVal.(types.Bool) if !ok { return types.ValOrErr(cVal, "no such overload") @@ -1184,8 +1185,8 @@ func (a *evalAttr) ID() int64 { } // AddQualifier implements the InterpretableAttribute interface method. -func (a *evalAttr) AddQualifier(qual Qualifier) (Attribute, error) { - attr, err := a.attr.AddQualifier(qual) +func (a *evalAttr) AddQualifier(ctx context.Context, qual Qualifier) (Attribute, error) { + attr, err := a.attr.AddQualifier(ctx, qual) a.attr = attr return attr, err } @@ -1201,8 +1202,8 @@ func (a *evalAttr) Adapter() types.Adapter { } // Eval implements the Interpretable interface method. -func (a *evalAttr) Eval(ctx Activation) ref.Val { - v, err := a.attr.Resolve(ctx) +func (a *evalAttr) Eval(ctx context.Context, vars Activation) ref.Val { + v, err := a.attr.Resolve(ctx, vars) if err != nil { return types.LabelErrNode(a.ID(), types.WrapErr(err)) } @@ -1210,13 +1211,13 @@ func (a *evalAttr) Eval(ctx Activation) ref.Val { } // Qualify proxies to the Attribute's Qualify method. -func (a *evalAttr) Qualify(ctx Activation, obj any) (any, error) { - return a.attr.Qualify(ctx, obj) +func (a *evalAttr) Qualify(ctx context.Context, vars Activation, obj any) (any, error) { + return a.attr.Qualify(ctx, vars, obj) } // QualifyIfPresent proxies to the Attribute's QualifyIfPresent method. -func (a *evalAttr) QualifyIfPresent(ctx Activation, obj any, presenceOnly bool) (any, bool, error) { - return a.attr.QualifyIfPresent(ctx, obj, presenceOnly) +func (a *evalAttr) QualifyIfPresent(ctx context.Context, vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return a.attr.QualifyIfPresent(ctx, vars, obj, presenceOnly) } func (a *evalAttr) IsOptional() bool { @@ -1224,8 +1225,8 @@ func (a *evalAttr) IsOptional() bool { } // Resolve proxies to the Attribute's Resolve method. -func (a *evalAttr) Resolve(ctx Activation) (any, error) { - return a.attr.Resolve(ctx) +func (a *evalAttr) Resolve(ctx context.Context, vars Activation) (any, error) { + return a.attr.Resolve(ctx, vars) } type evalWatchConstructor struct { @@ -1249,8 +1250,8 @@ func (c *evalWatchConstructor) ID() int64 { } // Eval implements the Interpretable Eval function. -func (c *evalWatchConstructor) Eval(ctx Activation) ref.Val { - val := c.constructor.Eval(ctx) +func (c *evalWatchConstructor) Eval(ctx context.Context, vars Activation) ref.Val { + val := c.constructor.Eval(ctx, vars) c.observer(c.ID(), c.constructor, val) return val } diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 0aca74d8..a0f4bf9c 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -18,6 +18,8 @@ package interpreter import ( + "context" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/types" @@ -118,7 +120,7 @@ func InterruptableEval() InterpretableDecorator { // Optimize will pre-compute operations such as list and map construction and optimize // call arguments to set membership tests. The set of optimizations will increase over time. func Optimize() InterpretableDecorator { - return decOptimize() + return decOptimize(context.TODO()) } // RegexOptimization provides a way to replace an InterpretableCall for a regex function when the diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 1ca97e54..96fd0d0c 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -594,7 +594,7 @@ func testData(t testing.TB) []testCase { { name: "literal_bytes_string2", expr: `string(b"""Kim\t""")`, - out: `Kim `, + out: `Kim `, }, { name: "literal_pb3_msg", @@ -1472,7 +1472,7 @@ func BenchmarkInterpreter(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - prg.Eval(vars) + prg.Eval(context.Background(), vars) } }) } @@ -1492,7 +1492,7 @@ func BenchmarkInterpreterParallel(b *testing.B) { func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - prg.Eval(vars) + prg.Eval(context.Background(), vars) } }) }) @@ -1513,7 +1513,7 @@ func TestInterpreter(t *testing.T) { if tc.out != nil { want = tc.out.(ref.Val) } - got := prg.Eval(vars) + got := prg.Eval(context.Background(), vars) _, expectUnk := want.(*types.Unknown) if expectUnk { if !reflect.DeepEqual(got, want) { @@ -1549,7 +1549,7 @@ func TestInterpreter(t *testing.T) { t.Fatal(err) } t.Run(mode, func(t *testing.T) { - got := prg.Eval(vars) + got := prg.Eval(context.Background(), vars) _, expectUnk := want.(*types.Unknown) if expectUnk { if !reflect.DeepEqual(got, want) { @@ -1648,7 +1648,7 @@ func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { "a": types.True, "b": types.Double(0.999), "c": types.NewStringList(reg, []string{"hello"})}) - result := interpretable.Eval(vars) + result := interpretable.Eval(context.Background(), vars) // Operator "_==_" is at Expr 7, should be evaluated in exhaustive mode // even though "a" is true ev, _ := state.Value(7) @@ -1696,7 +1696,7 @@ func TestInterpreter_InterruptableEval(t *testing.T) { } }, } - out := prg.Eval(ctxVars) + out := prg.Eval(context.Background(), ctxVars) if !types.IsError(out) || out.(*types.Err).String() != "operation interrupted" { t.Errorf("Got %v, wanted operation interrupted error", out) } @@ -1731,7 +1731,7 @@ func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { "a": true, "b": "b", }) - result := i.Eval(vars) + result := i.Eval(context.Background(), vars) rhv, _ := state.Value(3) // "==" should be evaluated in exhaustive mode though unnecessary if rhv != types.True { @@ -1791,7 +1791,7 @@ func TestInterpreter_SetProto2PrimitiveFields(t *testing.T) { vars, _ := NewActivation(map[string]any{ "input": reg.NativeToValue(input), }) - result := eval.Eval(vars) + result := eval.Eval(context.Background(), vars) got, ok := result.Value().(bool) if !ok { t.Fatalf("Got '%v', wanted 'true'.", result) @@ -1826,12 +1826,12 @@ func TestInterpreter_MissingIdentInSelect(t *testing.T) { }, }, NewAttributePattern("a.b").QualString("c")) - result := i.Eval(vars) + result := i.Eval(context.Background(), vars) if !types.IsUnknown(result) { t.Errorf("Got %v, wanted unknown", result) } - result = i.Eval(EmptyActivation()) + result = i.Eval(context.Background(), EmptyActivation()) if !types.IsError(result) { t.Errorf("Got %v, wanted error", result) } @@ -1897,7 +1897,7 @@ func TestInterpreter_TypeConversionOpt(t *testing.T) { if err2 != nil { t.Fatalf("got error, wanted interpretable: %v", i2) } - errVal := i2.Eval(EmptyActivation()) + errVal := i2.Eval(context.Background(), EmptyActivation()) errValStr := errVal.(*types.Err).Error() if errValStr != err.Error() { t.Errorf("got error %s, wanted error %s", errValStr, err.Error()) diff --git a/interpreter/optimizations.go b/interpreter/optimizations.go index 2fc87e69..f40d17d3 100644 --- a/interpreter/optimizations.go +++ b/interpreter/optimizations.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "regexp" "github.com/google/cel-go/common/types" @@ -32,7 +33,7 @@ var MatchesRegexOptimization = &RegexOptimization{ if err != nil { return nil, err } - return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(values ...ref.Val) ref.Val { + return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(ctx context.Context, values ...ref.Val) ref.Val { if len(values) != 2 { return types.NoSuchOverloadErr() } diff --git a/interpreter/planner.go b/interpreter/planner.go index cf371f95..0696fde2 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "fmt" "strings" @@ -201,7 +202,7 @@ func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { } } // Append the qualifier on the attribute. - _, err = attr.AddQualifier(qual) + _, err = attr.AddQualifier(context.TODO(), qual) return attr, err } @@ -306,7 +307,7 @@ func (p *planner) planCallUnary(expr ast.Expr, overload string, impl *functions.Overload, args []Interpretable) (Interpretable, error) { - var fn functions.UnaryOp + var fn functions.UnaryContextOp var trait int var nonStrict bool if impl != nil { @@ -334,7 +335,7 @@ func (p *planner) planCallBinary(expr ast.Expr, overload string, impl *functions.Overload, args []Interpretable) (Interpretable, error) { - var fn functions.BinaryOp + var fn functions.BinaryContextOp var trait int var nonStrict bool if impl != nil { @@ -363,7 +364,7 @@ func (p *planner) planCallVarArgs(expr ast.Expr, overload string, impl *functions.Overload, args []Interpretable) (Interpretable, error) { - var fn functions.FunctionOp + var fn functions.FunctionContextOp var trait int var nonStrict bool if impl != nil { @@ -478,7 +479,7 @@ func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bo } // Add the qualifier to the attribute - _, err = attr.AddQualifier(qual) + _, err = attr.AddQualifier(context.TODO(), qual) return attr, err } diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index d7bfb750..cca216c5 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "testing" "github.com/google/cel-go/common" @@ -471,7 +472,7 @@ func TestPrune(t *testing.T) { if err != nil { t.Fatalf("NewUncheckedInterpretable() failed: %v", err) } - interpretable.Eval(testActivation(t, tst.in)) + interpretable.Eval(context.Background(), testActivation(t, tst.in)) newExpr := PruneAst(parsed.Expr(), parsed.SourceInfo().MacroCalls(), state) if tst.iterRange != "" { if newExpr.Expr().Kind() != ast.ComprehensionKind { diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 687a47b8..a69e7cb9 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -15,6 +15,7 @@ package interpreter import ( + "context" "fmt" "math" "math/rand" @@ -159,7 +160,7 @@ func computeCost(t *testing.T, expr string, vars []*decls.VariableDecl, ctx Acti } } }() - prg.Eval(ctx) + prg.Eval(context.Background(), ctx) // TODO: enable this once all attributes are properly pushed and popped from stack. //if len(costTracker.stack) != 1 { // t.Fatalf(`Expected resulting stack size to be 1 but got %d: %#+v`, len(costTracker.stack), costTracker.stack)