From 7e9f0adb8e633992cfbad39b990075e207f54061 Mon Sep 17 00:00:00 2001
From: Ryan Bullock <ryan@piratel.com>
Date: Fri, 27 Oct 2023 15:53:41 -0700
Subject: [PATCH] Add ExprNative interface. Allows a type to present an expr
 native/friendly value for evaluation.

---
 bench_test.go        | 87 ++++++++++++++++++++++++++++++++++++++++++++
 compiler/compiler.go |  1 +
 conf/config.go       |  3 +-
 conf/types_table.go  | 14 ++++++-
 docgen/docgen.go     |  2 +-
 expr.go              |  9 +++++
 expr_test.go         | 62 +++++++++++++++++++++++++++++++
 vm/program.go        | 19 +++++-----
 vm/vm.go             | 17 +++++++++
 9 files changed, 201 insertions(+), 13 deletions(-)

diff --git a/bench_test.go b/bench_test.go
index 8d1a272c3..eb29cef17 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -526,3 +526,90 @@ func Benchmark_reduce(b *testing.B) {
 
 	require.Equal(b, 5050, out.(int))
 }
+
+func Benchmark_nativeAdd(b *testing.B) {
+	env := make(map[string]any)
+
+	env["testOne"] = 1
+	env["testTwo"] = 2
+
+	program, err := expr.Compile("testOne + testTwo", expr.Env(env))
+	require.NoError(b, err)
+
+	var out any
+	v := vm.VM{}
+
+	b.ResetTimer()
+	for n := 0; n < b.N; n++ {
+		out, err = v.Run(program, env)
+	}
+	b.StopTimer()
+
+	require.NoError(b, err)
+	require.Equal(b, 3, out.(int))
+}
+
+func Benchmark_nativeEnabledAdd(b *testing.B) {
+	env := make(map[string]any)
+
+	env["testOne"] = 1
+	env["testTwo"] = 2
+
+	program, err := expr.Compile("testOne + testTwo", expr.ExprNative(true), expr.Env(env))
+	require.NoError(b, err)
+
+	var out any
+	v := vm.VM{}
+
+	b.ResetTimer()
+	for n := 0; n < b.N; n++ {
+		out, err = v.Run(program, env)
+	}
+	b.StopTimer()
+
+	require.NoError(b, err)
+	require.Equal(b, 3, out.(int))
+}
+
+func Benchmark_exprNativeAdd(b *testing.B) {
+	env := make(map[string]any)
+
+	env["testOne"] = &exprNativeInt{MyInt: 1}
+	env["testTwo"] = &exprNativeInt{MyInt: 2}
+
+	program, err := expr.Compile("testOne + testTwo", expr.ExprNative(true), expr.Env(env))
+	require.NoError(b, err)
+
+	var out any
+	v := vm.VM{}
+	b.ResetTimer()
+	for n := 0; n < b.N; n++ {
+		out, err = v.Run(program, env)
+	}
+	b.StopTimer()
+
+	require.NoError(b, err)
+	require.Equal(b, 3, out.(int))
+}
+
+func Benchmark_callAdd(b *testing.B) {
+	env := make(map[string]any)
+
+	env["testOne"] = &exprNativeInt{MyInt: 1}
+	env["testTwo"] = &exprNativeInt{MyInt: 2}
+
+	program, err := expr.Compile("testOne.Value() + testTwo.Value()", expr.Env(env))
+	require.NoError(b, err)
+
+	var out any
+	v := vm.VM{}
+
+	b.ResetTimer()
+	for n := 0; n < b.N; n++ {
+		out, err = v.Run(program, env)
+	}
+	b.StopTimer()
+
+	require.NoError(b, err)
+	require.Equal(b, 3, out.(int))
+}
diff --git a/compiler/compiler.go b/compiler/compiler.go
index 8e26d8788..6681d96e7 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -57,6 +57,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro
 		Arguments: c.arguments,
 		Functions: c.functions,
 		DebugInfo: c.debugInfo,
+		ExprNative: config.ExprNative,
 	}
 	return
 }
diff --git a/conf/config.go b/conf/config.go
index 5fb5e1194..48b81f9fe 100644
--- a/conf/config.go
+++ b/conf/config.go
@@ -19,6 +19,7 @@ type Config struct {
 	ExpectAny   bool
 	Optimize    bool
 	Strict      bool
+	ExprNative  bool
 	ConstFns    map[string]reflect.Value
 	Visitors    []ast.Visitor
 	Functions   map[string]*ast.Function
@@ -61,7 +62,7 @@ func (c *Config) WithEnv(env any) {
 	}
 
 	c.Env = env
-	c.Types = CreateTypesTable(env)
+	c.Types = CreateTypesTable(env, c.ExprNative)
 	c.MapEnv = mapEnv
 	c.DefaultType = mapValueType
 	c.Strict = true
diff --git a/conf/types_table.go b/conf/types_table.go
index 8ebb76c35..9aaf10eb1 100644
--- a/conf/types_table.go
+++ b/conf/types_table.go
@@ -1,6 +1,7 @@
 package conf
 
 import (
+	"github.com/antonmedv/expr/vm"
 	"reflect"
 )
 
@@ -20,7 +21,7 @@ type TypesTable map[string]Tag
 //
 // If map is passed, all items will be treated as variables
 // (key as name, value as type).
-func CreateTypesTable(i any) TypesTable {
+func CreateTypesTable(i any, exprNative bool) TypesTable {
 	if i == nil {
 		return nil
 	}
@@ -57,7 +58,16 @@ func CreateTypesTable(i any) TypesTable {
 				if key.String() == "$env" { // Could check for all keywords here
 					panic("attempt to misuse env keyword as env map key")
 				}
-				types[key.String()] = Tag{Type: reflect.TypeOf(value.Interface())}
+
+				v := value.Interface()
+				if exprNative {
+					if ev, ok := v.(vm.ExprNative); ok {
+						types[key.String()] = Tag{Type: ev.ExprNativeType()}
+						continue
+					}
+				}
+
+				types[key.String()] = Tag{Type: reflect.TypeOf(v)}
 			}
 		}
 
diff --git a/docgen/docgen.go b/docgen/docgen.go
index a1145586f..8eab554a8 100644
--- a/docgen/docgen.go
+++ b/docgen/docgen.go
@@ -83,7 +83,7 @@ func CreateDoc(i any) *Context {
 		PkgPath:   dereference(reflect.TypeOf(i)).PkgPath(),
 	}
 
-	for name, t := range conf.CreateTypesTable(i) {
+	for name, t := range conf.CreateTypesTable(i, false) {
 		if t.Ambiguous {
 			continue
 		}
diff --git a/expr.go b/expr.go
index eb9eb7683..53bfd4722 100644
--- a/expr.go
+++ b/expr.go
@@ -100,6 +100,15 @@ func Optimize(b bool) Option {
 	}
 }
 
+// ExprNative sets a flag in compiled program teling the runtime to use the value return by ExprNative interface instead of the direct variable
+// This option must be passed BEFORE the Env() option during the compile phase or type checking will most likely be broken
+// Only applies to map environments
+func ExprNative(b bool) Option {
+	return func(c *conf.Config) {
+		c.ExprNative = b
+	}
+}
+
 // Patch adds visitor to list of visitors what will be applied before compiling AST to bytecode.
 func Patch(visitor ast.Visitor) Option {
 	return func(c *conf.Config) {
diff --git a/expr_test.go b/expr_test.go
index 4cb1902a5..d489a2086 100644
--- a/expr_test.go
+++ b/expr_test.go
@@ -435,6 +435,68 @@ func ExampleAllowUndefinedVariables_zero_value_functions() {
 	// Output: [foo bar]
 }
 
+type exprNativeInt struct {
+	MyInt int
+}
+
+func (n *exprNativeInt) ExprNativeValue() any {
+	return n.MyInt
+}
+
+func (n *exprNativeInt) ExprNativeType() reflect.Type {
+	return reflect.TypeOf(n.MyInt)
+}
+
+func (n *exprNativeInt) Value() int {
+	return n.MyInt
+}
+
+func ExampleExprNative() {
+	env := make(map[string]any)
+
+	env["testOne"] = &exprNativeInt{MyInt: 1}
+	env["testTwo"] = &exprNativeInt{MyInt: 2}
+
+	program, err := expr.Compile("testOne + testTwo", expr.ExprNative(true), expr.Env(env))
+	if err != nil {
+		fmt.Printf("%v", err)
+		return
+	}
+
+	output, err := expr.Run(program, env)
+	if err != nil {
+		fmt.Printf("%v", err)
+		return
+	}
+
+	fmt.Printf("%T(%v)", output, output)
+
+	// Output: int(3)
+}
+
+func ExampleCallAdd() {
+	env := make(map[string]any)
+
+	env["testOne"] = &exprNativeInt{MyInt: 1}
+	env["testTwo"] = &exprNativeInt{MyInt: 2}
+
+	program, err := expr.Compile("testOne.Value() + testTwo.Value()", expr.Env(env))
+	if err != nil {
+		fmt.Printf("%v", err)
+		return
+	}
+
+	output, err := expr.Run(program, env)
+	if err != nil {
+		fmt.Printf("%v", err)
+		return
+	}
+
+	fmt.Printf("%T(%v)", output, output)
+
+	// Output: int(3)
+}
+
 type patcher struct{}
 
 func (p *patcher) Visit(node *ast.Node) {
diff --git a/vm/program.go b/vm/program.go
index c45a2bff2..f32b4a6bb 100644
--- a/vm/program.go
+++ b/vm/program.go
@@ -16,15 +16,16 @@ import (
 )
 
 type Program struct {
-	Node      ast.Node
-	Source    *file.Source
-	Locations []file.Location
-	Variables []any
-	Constants []any
-	Bytecode  []Opcode
-	Arguments []int
-	Functions []Function
-	DebugInfo map[string]string
+	Node       ast.Node
+	Source     *file.Source
+	Locations  []file.Location
+	Variables  []any
+	Constants  []any
+	Bytecode   []Opcode
+	Arguments  []int
+	Functions  []Function
+	DebugInfo  map[string]string
+	ExprNative bool
 }
 
 func (program *Program) Disassemble() string {
diff --git a/vm/vm.go b/vm/vm.go
index d020a31a8..4b1a93e22 100644
--- a/vm/vm.go
+++ b/vm/vm.go
@@ -34,6 +34,7 @@ type VM struct {
 	debug        bool
 	step         chan struct{}
 	curr         chan int
+	exprNative   bool
 	memory       uint
 	memoryBudget uint
 }
@@ -47,6 +48,14 @@ type Scope struct {
 	Acc     any
 }
 
+// ExprNative is an interface for a type to provide values that expr can use natively
+type ExprNative interface {
+	// ExprNativeValue returns a native value that expr can use directly
+	ExprNativeValue() any
+	// ExprNativeType returns the reflect.Type of the type that will be returned by ExprNativeValue
+	ExprNativeType() reflect.Type
+}
+
 func Debug() *VM {
 	vm := &VM{
 		debug: true,
@@ -83,6 +92,7 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
 	vm.memoryBudget = MemoryBudget
 	vm.memory = 0
 	vm.ip = 0
+	vm.exprNative = program.ExprNative
 
 	for vm.ip < len(program.Bytecode) {
 		if vm.debug {
@@ -514,6 +524,12 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
 }
 
 func (vm *VM) push(value any) {
+	if vm.exprNative {
+		if ev, ok := value.(ExprNative); ok {
+			value = ev.ExprNativeValue()
+		}
+	}
+
 	vm.stack = append(vm.stack, value)
 }
 
@@ -524,6 +540,7 @@ func (vm *VM) current() any {
 func (vm *VM) pop() any {
 	value := vm.stack[len(vm.stack)-1]
 	vm.stack = vm.stack[:len(vm.stack)-1]
+
 	return value
 }