diff --git a/runtime/sam/expr/function/ip.go b/runtime/sam/expr/function/ip.go index aa1b26f8c5..c91dfe33ca 100644 --- a/runtime/sam/expr/function/ip.go +++ b/runtime/sam/expr/function/ip.go @@ -83,8 +83,16 @@ var errMatch = errors.New("match") func (c *CIDRMatch) Call(_ super.Allocator, args []super.Value) super.Value { maskVal := args[0] - if maskVal.Type().ID() != super.IDNet { - return c.zctx.WrapError("cidr_match: not a net", maskVal) + if id := maskVal.Type().ID(); id != super.IDNet && id != super.IDNull { + val := c.zctx.WrapError("cidr_match: not a net", maskVal) + if maskVal.IsNull() { + val = super.NewValue(val.Type(), nil) + } + return val + + } + if maskVal.IsNull() || args[1].IsNull() { + return super.NewValue(super.TypeBool, nil) } prefix := super.DecodeNet(maskVal.Bytes()) err := args[1].Walk(func(typ super.Type, body zcode.Bytes) error { diff --git a/runtime/vam/expr/function/cidrmatch.go b/runtime/vam/expr/function/cidrmatch.go new file mode 100644 index 0000000000..811afee475 --- /dev/null +++ b/runtime/vam/expr/function/cidrmatch.go @@ -0,0 +1,48 @@ +package function + +import ( + "github.com/brimdata/super" + "github.com/brimdata/super/runtime/vam/expr" + "github.com/brimdata/super/vector" +) + +type CIDRMatch struct { + zctx *super.Context + pw *expr.PredicateWalk +} + +func NewCIDRMatch(zctx *super.Context) *CIDRMatch { + return &CIDRMatch{zctx, expr.NewPredicateWalk(cidrMatch)} +} + +func (c *CIDRMatch) Call(args ...vector.Any) vector.Any { + if id := args[0].Type().ID(); id != super.IDNet && id != super.IDNull { + out := vector.NewWrappedError(c.zctx, "cidr_match: not a net", args[0]) + out.Nulls = vector.Or(vector.NullsOf(args[0]), vector.NullsOf(args[1])) + return out + } + return c.pw.Eval(args...) +} + +func cidrMatch(vec ...vector.Any) vector.Any { + netVec, valVec := vec[0], vec[1] + nulls := vector.Or(vector.NullsOf(netVec), vector.NullsOf(valVec)) + if id := valVec.Type().ID(); id != super.IDIP { + return vector.NewConst(super.False, valVec.Len(), nulls) + } + out := vector.NewBoolEmpty(valVec.Len(), nulls) + for i := range netVec.Len() { + net, null := vector.NetValue(netVec, i) + if null { + continue + } + ip, null := vector.IPValue(valVec, i) + if null { + continue + } + if net.Contains(ip) { + out.Set(i) + } + } + return out +} diff --git a/runtime/vam/expr/function/function.go b/runtime/vam/expr/function/function.go index b90e3f3bb5..506a201e22 100644 --- a/runtime/vam/expr/function/function.go +++ b/runtime/vam/expr/function/function.go @@ -26,6 +26,10 @@ func New(zctx *super.Context, name string, narg int) (expr.Function, field.Path, f = &Bucket{zctx: zctx, name: name} case "ceil": f = &Ceil{zctx} + case "cidr_match": + argmin = 2 + argmax = 2 + f = NewCIDRMatch(zctx) case "coalesce": argmax = -1 f = &Coalesce{} diff --git a/runtime/vam/expr/logic.go b/runtime/vam/expr/logic.go index d472fc8bd1..2971cb45c1 100644 --- a/runtime/vam/expr/logic.go +++ b/runtime/vam/expr/logic.go @@ -210,11 +210,11 @@ type In struct { zctx *super.Context lhs Evaluator rhs Evaluator - eq *Compare + pw *PredicateWalk } func NewIn(zctx *super.Context, lhs, rhs Evaluator) *In { - return &In{zctx, lhs, rhs, NewCompare(zctx, nil, nil, "==")} + return &In{zctx, lhs, rhs, NewPredicateWalk(NewCompare(zctx, nil, nil, "==").eval)} } func (i *In) Eval(this vector.Any) vector.Any { @@ -229,10 +229,18 @@ func (i *In) eval(vecs ...vector.Any) vector.Any { if rhs.Type().Kind() == super.ErrorKind { return rhs } - return i.evalResursive(lhs, rhs) + return i.pw.Eval(lhs, rhs) } -func (i *In) evalResursive(vecs ...vector.Any) vector.Any { +type PredicateWalk struct { + pred func(...vector.Any) vector.Any +} + +func NewPredicateWalk(pred func(...vector.Any) vector.Any) *PredicateWalk { + return &PredicateWalk{pred} +} + +func (p *PredicateWalk) Eval(vecs ...vector.Any) vector.Any { lhs, rhs := vecs[0], vecs[1] rhs = vector.Under(rhs) rhsOrig := rhs @@ -248,32 +256,32 @@ func (i *In) evalResursive(vecs ...vector.Any) vector.Any { if index != nil { f = vector.NewView(f, index) } - out = vector.Or(out, toBool(i.evalResursive(lhs, f))) + out = vector.Or(out, toBool(p.Eval(lhs, f))) } return out case *vector.Array: - return i.evalForList(lhs, rhs.Values, rhs.Offsets, index) + return p.evalForList(lhs, rhs.Values, rhs.Offsets, index) case *vector.Set: - return i.evalForList(lhs, rhs.Values, rhs.Offsets, index) + return p.evalForList(lhs, rhs.Values, rhs.Offsets, index) case *vector.Map: - return vector.Or(i.evalForList(lhs, rhs.Keys, rhs.Offsets, index), - i.evalForList(lhs, rhs.Values, rhs.Offsets, index)) + return vector.Or(p.evalForList(lhs, rhs.Keys, rhs.Offsets, index), + p.evalForList(lhs, rhs.Values, rhs.Offsets, index)) case *vector.Union: if index != nil { panic("vector.Union unexpected in vector.View") } - return vector.Apply(true, i.evalResursive, lhs, rhs) + return vector.Apply(true, p.Eval, lhs, rhs) case *vector.Error: if index != nil { panic("vector.Error unexpected in vector.View") } - return i.evalResursive(lhs, rhs.Vals) + return p.Eval(lhs, rhs.Vals) default: - return i.eq.eval(lhs, rhsOrig) + return p.pred(lhs, rhsOrig) } } -func (i *In) evalForList(lhs, rhs vector.Any, offsets, index []uint32) *vector.Bool { +func (p *PredicateWalk) evalForList(lhs, rhs vector.Any, offsets, index []uint32) *vector.Bool { out := vector.NewBoolEmpty(lhs.Len(), nil) var lhsIndex, rhsIndex []uint32 for j := range lhs.Len() { @@ -293,7 +301,7 @@ func (i *In) evalForList(lhs, rhs vector.Any, offsets, index []uint32) *vector.B } lhsView := vector.NewView(lhs, lhsIndex) rhsView := vector.NewView(rhs, rhsIndex) - if toBool(i.evalResursive(lhsView, rhsView)).TrueCount() > 0 { + if toBool(p.Eval(lhsView, rhsView)).TrueCount() > 0 { out.Set(j) } } diff --git a/runtime/ztests/expr/function/cidr_match.yaml b/runtime/ztests/expr/function/cidr_match.yaml new file mode 100644 index 0000000000..b6fe581c9f --- /dev/null +++ b/runtime/ztests/expr/function/cidr_match.yaml @@ -0,0 +1,38 @@ +zed: | + yield cidr_match(this[0], this[1]) + +vector: true + +input: | + [1.1.0.0/16, 1.1.1.1] + [1.1.0.0/16, {a:1.1.1.1,b:2.2.2.2}] + [1.1.0.0/16, [2.2.2.2,1.1.1.1]] + [1.1.0.0/16, |[2.2.2.2,1.1.1.1]|] + [1.1.0.0/16, 2.2.2.2] + [1.1.0.0/16, {a:2.2.2.2,b:3.3.3.3}] + [1.1.0.0/16, [2.2.2.2,3.3.3.3]] + [1.1.0.0/16, null(ip)] + [1.1.0.0/16, null] + [1.1.0.0/16, null(string)] + [null(net), 1.1.1.1] + [null, null] + [null(string), 1.1.1.1] + [1.1.0.0/16, "s"] + ["s", 1.1.1.1] + +output: | + true + true + true + true + false + false + false + null(bool) + null(bool) + null(bool) + null(bool) + null(bool) + null(error({message:string,on:string})) + false + error({message:"cidr_match: not a net",on:"s"})