Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:customFuncs add ctx support #219

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package cli

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -242,8 +243,8 @@ Usage:
gojq.WithFunction("debug", 0, 0, cli.funcDebug),
gojq.WithFunction("stderr", 0, 0, cli.funcStderr),
gojq.WithFunction("input_filename", 0, 0,
func(iter inputIter) func(any, []any) any {
return func(any, []any) any {
func(iter inputIter) func(context.Context, any, []any) any {
return func(context.Context, any, []any) any {
if fname := iter.Name(); fname != "" && (len(args) > 0 || !opts.InputNull) {
return fname
}
Expand Down Expand Up @@ -408,7 +409,7 @@ func (cli *cli) createMarshaler() marshaler {
return f
}

func (cli *cli) funcDebug(v any, _ []any) any {
func (cli *cli) funcDebug(_ context.Context, v any, _ []any) any {
if err := newEncoder(false, 0).marshal([]any{"DEBUG:", v}, cli.errStream); err != nil {
return err
}
Expand All @@ -418,7 +419,7 @@ func (cli *cli) funcDebug(v any, _ []any) any {
return v
}

func (cli *cli) funcStderr(v any, _ []any) any {
func (cli *cli) funcStderr(_ context.Context, v any, _ []any) any {
if err := newEncoder(false, 0).marshal(v, cli.errStream); err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,8 @@ func (c *compiler) compileBreak(label string) error {
return nil
}

func funcBreak(label string) func(any, []any) any {
return func(v any, _ []any) any {
func funcBreak(label string) func(context.Context, any, []any) any {
return func(_ context.Context, v any, _ []any) any {
return &breakError{label, v}
}
}
Expand Down
2 changes: 1 addition & 1 deletion execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ loop:
for i := 0; i < argcnt; i++ {
args[i] = env.pop()
}
w := v[0].(func(any, []any) any)(x, args)
w := v[0].(func(context.Context, any, []any) any)(env.ctx, x, args)
if e, ok := w.(error); ok {
if er, ok := e.(*exitCodeError); !ok || er.value != nil || er.halt {
err = e
Expand Down
23 changes: 12 additions & 11 deletions func.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gojq

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
Expand Down Expand Up @@ -33,7 +34,7 @@ const (
type function struct {
argcount int
iter bool
callback func(any, []any) any
callback func(context.Context, any, []any) any
}

func (fn function) accept(cnt int) bool {
Expand Down Expand Up @@ -202,31 +203,31 @@ func init() {

func argFunc0(f func(any) any) function {
return function{
argcount0, false, func(v any, _ []any) any {
argcount0, false, func(_ context.Context, v any, _ []any) any {
return f(v)
},
}
}

func argFunc1(f func(_, _ any) any) function {
return function{
argcount1, false, func(v any, args []any) any {
argcount1, false, func(_ context.Context, v any, args []any) any {
return f(v, args[0])
},
}
}

func argFunc2(f func(_, _, _ any) any) function {
return function{
argcount2, false, func(v any, args []any) any {
argcount2, false, func(_ context.Context, v any, args []any) any {
return f(v, args[0], args[1])
},
}
}

func argFunc3(f func(_, _, _, _ any) any) function {
return function{
argcount3, false, func(v any, args []any) any {
argcount3, false, func(_ context.Context, v any, args []any) any {
return f(v, args[0], args[1], args[2])
},
}
Expand Down Expand Up @@ -718,7 +719,7 @@ func funcImplode(v any) any {
return sb.String()
}

func funcSplit(v any, args []any) any {
func funcSplit(_ context.Context, v any, args []any) any {
s, ok := v.(string)
if !ok {
return &func0TypeError{"split", v}
Expand Down Expand Up @@ -809,7 +810,7 @@ func funcFormat(v, x any) any {
if f == nil {
return &formatNotFoundError{format}
}
return internalFuncs[f.Name].callback(v, nil)
return internalFuncs[f.Name].callback(context.Background(), v, nil)
}

var htmlEscaper = strings.NewReplacer(
Expand Down Expand Up @@ -1101,7 +1102,7 @@ func clampIndex(i, min, max int) int {
}
}

func funcFlatten(v any, args []any) any {
func funcFlatten(_ context.Context, v any, args []any) any {
vs, ok := values(v)
if !ok {
return &func0TypeError{"flatten", v}
Expand Down Expand Up @@ -1145,7 +1146,7 @@ func (iter *rangeIter) Next() (any, bool) {
return v, true
}

func funcRange(_ any, xs []any) any {
func funcRange(_ context.Context, _ any, xs []any) any {
for _, x := range xs {
switch x.(type) {
case int, float64, *big.Int:
Expand Down Expand Up @@ -2048,7 +2049,7 @@ func funcCapture(v any) any {
return w
}

func funcError(v any, args []any) any {
func funcError(_ context.Context, v any, args []any) any {
if len(args) > 0 {
v = args[0]
}
Expand All @@ -2063,7 +2064,7 @@ func funcHalt(any) any {
return &exitCodeError{nil, 0, true}
}

func funcHaltError(v any, args []any) any {
func funcHaltError(_ context.Context, v any, args []any) any {
code := 5
if len(args) > 0 {
var ok bool
Expand Down
21 changes: 12 additions & 9 deletions option.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package gojq

import "fmt"
import (
"context"
"fmt"
)

// CompilerOption is a compiler option.
type CompilerOption func(*compiler)
Expand Down Expand Up @@ -39,7 +42,7 @@ func WithVariables(variables []string) CompilerOption {
// function. If you want to emit multiple values, call the empty function,
// accept a filter for its argument, or call another built-in function, then
// use LoadInitModules of the module loader.
func WithFunction(name string, minarity, maxarity int, f func(any, []any) any) CompilerOption {
func WithFunction(name string, minarity, maxarity int, f func(context.Context, any, []any) any) CompilerOption {
return withFunction(name, minarity, maxarity, false, f)
}

Expand All @@ -48,15 +51,15 @@ func WithFunction(name string, minarity, maxarity int, f func(any, []any) any) C
// returns an Iter to emit multiple values. You cannot define both iterator and
// non-iterator functions of the same name (with possibly different arities).
// See also [NewIter], which can be used to convert values or an error to an Iter.
func WithIterFunction(name string, minarity, maxarity int, f func(any, []any) Iter) CompilerOption {
func WithIterFunction(name string, minarity, maxarity int, f func(context.Context, any, []any) Iter) CompilerOption {
return withFunction(name, minarity, maxarity, true,
func(v any, args []any) any {
return f(v, args)
func(ctx context.Context, v any, args []any) any {
return f(ctx, v, args)
},
)
}

func withFunction(name string, minarity, maxarity int, iter bool, f func(any, []any) any) CompilerOption {
func withFunction(name string, minarity, maxarity int, iter bool, f func(context.Context, any, []any) any) CompilerOption {
if !(0 <= minarity && minarity <= maxarity && maxarity <= 30) {
panic(fmt.Sprintf("invalid arity for %q: %d, %d", name, minarity, maxarity))
}
Expand All @@ -71,11 +74,11 @@ func withFunction(name string, minarity, maxarity int, iter bool, f func(any, []
}
c.customFuncs[name] = function{
argcount | fn.argcount, iter,
func(x any, xs []any) any {
func(ctx context.Context, x any, xs []any) any {
if argcount&(1<<len(xs)) != 0 {
return f(x, xs)
return f(ctx, x, xs)
}
return fn.callback(x, xs)
return fn.callback(ctx, x, xs)
},
}
} else {
Expand Down
3 changes: 2 additions & 1 deletion option_function_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gojq_test

import (
"context"
"encoding/json"
"fmt"
"log"
Expand Down Expand Up @@ -31,7 +32,7 @@ func ExampleWithFunction() {
}
code, err := gojq.Compile(
query,
gojq.WithFunction("f", 0, 1, func(x any, xs []any) any {
gojq.WithFunction("f", 0, 1, func(_ context.Context, x any, xs []any) any {
if x, ok := toFloat(x); ok {
if len(xs) == 1 {
if y, ok := toFloat(xs[0]); ok {
Expand Down
3 changes: 2 additions & 1 deletion option_iter_function_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gojq_test

import (
"context"
"fmt"
"log"

Expand Down Expand Up @@ -28,7 +29,7 @@ func ExampleWithIterFunction() {
}
code, err := gojq.Compile(
query,
gojq.WithIterFunction("f", 2, 2, func(_ any, xs []any) gojq.Iter {
gojq.WithIterFunction("f", 2, 2, func(_ context.Context, _ any, xs []any) gojq.Iter {
if x, ok := xs[0].(int); ok {
if y, ok := xs[1].(int); ok {
return &rangeIter{x, y}
Expand Down
Loading