Skip to content

Commit

Permalink
Improve type checking for $env
Browse files Browse the repository at this point in the history
Fixes #462
  • Loading branch information
antonmedv committed Nov 16, 2023
1 parent 0354d1b commit e9607c7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 17 deletions.
18 changes: 18 additions & 0 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ type IdentifierNode struct {
MethodIndex int // index of method, set only if Method is true
}

func (n *IdentifierNode) SetFieldIndex(field []int) {
n.FieldIndex = field
}

func (n *IdentifierNode) SetMethodIndex(methodIndex int) {
n.Method = true
n.MethodIndex = methodIndex
}

type IntegerNode struct {
base
Value int
Expand Down Expand Up @@ -111,6 +120,15 @@ type MemberNode struct {
MethodIndex int
}

func (n *MemberNode) SetFieldIndex(field []int) {
n.FieldIndex = field
}

func (n *MemberNode) SetMethodIndex(methodIndex int) {
n.Method = true
n.MethodIndex = methodIndex
}

type SliceNode struct {
base
Node Node
Expand Down
46 changes: 30 additions & 16 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,23 +157,34 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
if node.Value == "$env" {
return mapType, info{}
}
if fn, ok := v.config.Builtins[node.Value]; ok {
return v.env(node, node.Value, true)
}

type NodeWithIndexes interface {
ast.Node
SetFieldIndex(field []int)
SetMethodIndex(methodIndex int)
}

func (v *checker) env(node NodeWithIndexes, name string, strict bool) (reflect.Type, info) {
if fn, ok := v.config.Builtins[name]; ok {
return functionType, info{fn: fn}
}
if fn, ok := v.config.Functions[node.Value]; ok {
if fn, ok := v.config.Functions[name]; ok {
return functionType, info{fn: fn}
}
if t, ok := v.config.Types[node.Value]; ok {
if t, ok := v.config.Types[name]; ok {
if t.Ambiguous {
return v.error(node, "ambiguous identifier %v", node.Value)
return v.error(node, "ambiguous identifier %v", name)
}
node.SetFieldIndex(t.FieldIndex)
if t.Method {
node.SetMethodIndex(t.MethodIndex)
}
node.Method = t.Method
node.MethodIndex = t.MethodIndex
node.FieldIndex = t.FieldIndex
return t.Type, info{method: t.Method}
}
if v.config.Strict {
return v.error(node, "unknown name %v", node.Value)
if v.config.Strict && strict {
return v.error(node, "unknown name %v", name)
}
if v.config.DefaultType != nil {
return v.config.DefaultType, info{}
Expand Down Expand Up @@ -433,12 +444,16 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
prop, _ := v.visit(node.Property)

if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" {
// If the index is a constant string, can save some
// cycles later by finding the type of its referent
if name, ok := node.Property.(*ast.StringNode); ok {
if t, ok := v.config.Types[name.Value]; ok {
return t.Type, info{method: t.Method}
} // No error if no type found; it may be added to env between compile and run
strict := v.config.Strict
if node.Optional {
// If user explicitly set optional flag, then we should not
// throw error if field is not found (as user trying to handle
// this case). But if user did not set optional flag, then we
// should throw error if field is not found & v.config.Strict.
strict = false
}
return v.env(node, name.Value, strict)
}
return anyType, info{}
}
Expand All @@ -460,8 +475,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
// the same interface.
return m.Type, info{}
} else {
node.Method = true
node.MethodIndex = m.Index
node.SetMethodIndex(m.Index)
node.Name = name.Value
return m.Type, info{method: true}
}
Expand Down
12 changes: 11 additions & 1 deletion expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,7 @@ func TestEnv_keyword(t *testing.T) {
{"$env[red + irect]", 10},
{"$env['String Map']?.five", ""},
{"$env.red", "n"},
{"$env?.blue", nil},
{"$env?.unknown", nil},
{"$env.mylist[1]", 2},
{"$env?.OtherMap?.a", "b"},
{"$env?.OtherMap?.d", ""},
Expand Down Expand Up @@ -2102,3 +2102,13 @@ func TestIssue461(t *testing.T) {
})
}
}

func TestIssue462(t *testing.T) {
env := map[string]any{
"foo": func() (string, error) {
return "bar", nil
},
}
_, err := expr.Compile(`$env.unknown(int())`, expr.Env(env))
require.Error(t, err)
}

0 comments on commit e9607c7

Please sign in to comment.