diff --git a/expr_test.go b/expr_test.go index 4cb1902a5..507729947 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1204,6 +1204,21 @@ func TestExpr_fetch_from_func(t *testing.T) { assert.Contains(t, err.Error(), "cannot fetch Value from func()") } +func TestExpr_fetch_from_interface(t *testing.T) { + type FooBar struct { + Value string + } + foobar := &FooBar{"waldo"} + var foobarAny any = foobar + var foobarPtrAny any = &foobarAny + + res, err := expr.Eval("foo.Value", map[string]any{ + "foo": foobarPtrAny, + }) + assert.NoError(t, err) + assert.Equal(t, "waldo", res) +} + func TestExpr_map_default_values(t *testing.T) { env := map[string]any{ "foo": map[string]string{}, diff --git a/test/deref/deref_test.go b/test/deref/deref_test.go index 4af128bde..283631c0d 100644 --- a/test/deref/deref_test.go +++ b/test/deref/deref_test.go @@ -137,3 +137,66 @@ func TestDeref_multiple_pointers(t *testing.T) { require.Equal(t, 44, output) }) } + +func TestDeref_pointer_of_interface(t *testing.T) { + v := 42 + a := &v + b := any(a) + c := any(&b) + t.Run("returned as is", func(t *testing.T) { + output, err := expr.Eval(`c`, map[string]any{ + "c": c, + }) + require.NoError(t, err) + require.Equal(t, c, output) + require.IsType(t, (*interface{})(nil), output) + }) + t.Run("+ works", func(t *testing.T) { + output, err := expr.Eval(`c+2`, map[string]any{ + "c": c, + }) + require.NoError(t, err) + require.Equal(t, 44, output) + }) +} + +func TestDeref_nil(t *testing.T) { + var b *int = nil + c := &b + t.Run("returned as is", func(t *testing.T) { + output, err := expr.Eval(`c`, map[string]any{ + "c": c, + }) + require.NoError(t, err) + require.Equal(t, c, output) + require.IsType(t, (**int)(nil), output) + }) + t.Run("== nil works", func(t *testing.T) { + output, err := expr.Eval(`c == nil`, map[string]any{ + "c": c, + }) + require.NoError(t, err) + require.Equal(t, true, output) + }) +} + +func TestDeref_nil_in_pointer_of_interface(t *testing.T) { + var a *int32 = nil + b := any(a) + c := any(&b) + t.Run("returned as is", func(t *testing.T) { + output, err := expr.Eval(`c`, map[string]any{ + "c": c, + }) + require.NoError(t, err) + require.Equal(t, c, output) + require.IsType(t, (*interface{})(nil), output) + }) + t.Run("== nil works", func(t *testing.T) { + output, err := expr.Eval(`c == nil`, map[string]any{ + "c": c, + }) + require.NoError(t, err) + require.Equal(t, true, output) + }) +} diff --git a/vm/runtime/runtime.go b/vm/runtime/runtime.go index 98c34f5d8..406f85096 100644 --- a/vm/runtime/runtime.go +++ b/vm/runtime/runtime.go @@ -8,6 +8,14 @@ import ( "reflect" ) +func deref(kind reflect.Kind, value reflect.Value) (reflect.Kind, reflect.Value) { + for kind == reflect.Ptr || kind == reflect.Interface { + value = value.Elem() + kind = value.Kind() + } + return kind, value +} + func Fetch(from, i any) any { v := reflect.ValueOf(from) kind := v.Kind() @@ -28,10 +36,8 @@ func Fetch(from, i any) any { // Structs, maps, and slices can be access through a pointer or through // a value, when they are accessed through a pointer we don't want to // copy them to a value. - if kind == reflect.Ptr { - v = reflect.Indirect(v) - kind = v.Kind() - } + // De-reference everything if necessary (interface and pointers) + kind, v = deref(kind, v) // TODO: We can create separate opcodes for each of the cases below to make // the little bit faster. @@ -145,27 +151,13 @@ func Deref(i any) any { v := reflect.ValueOf(i) - if v.Kind() == reflect.Interface { + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { if v.IsNil() { - return i + return nil } v = v.Elem() } -loop: - for v.Kind() == reflect.Ptr { - if v.IsNil() { - return i - } - indirect := reflect.Indirect(v) - switch indirect.Kind() { - case reflect.Struct, reflect.Map, reflect.Array, reflect.Slice: - break loop - default: - v = v.Elem() - } - } - if v.IsValid() { return v.Interface() }