Skip to content

Commit

Permalink
fix bytecode encoding/decoding of builtin modules (#154)
Browse files Browse the repository at this point in the history
* fix bytecode encoding/decoding of builtin modules

* Bytecode.Decode() to take map[string]objects.Importable

* add objects.ModuleMap

* update docs

* stdlib.GetModuleMap()
  • Loading branch information
d5 authored Mar 20, 2019
1 parent e785e38 commit 3c30109
Show file tree
Hide file tree
Showing 22 changed files with 308 additions and 140 deletions.
26 changes: 15 additions & 11 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Options struct {
Version string

// Import modules
Modules map[string]objects.Importable
Modules *objects.ModuleMap
}

// Run CLI
Expand All @@ -56,7 +56,7 @@ func Run(options *Options) {

if options.InputFile == "" {
// REPL
runREPL(options.Modules, os.Stdin, os.Stdout)
RunREPL(options.Modules, os.Stdin, os.Stdout)
return
}

Expand All @@ -67,17 +67,17 @@ func Run(options *Options) {
}

if options.CompileOutput != "" {
if err := compileOnly(options.Modules, inputData, options.InputFile, options.CompileOutput); err != nil {
if err := CompileOnly(options.Modules, inputData, options.InputFile, options.CompileOutput); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
} else if filepath.Ext(options.InputFile) == sourceFileExt {
if err := compileAndRun(options.Modules, inputData, options.InputFile); err != nil {
if err := CompileAndRun(options.Modules, inputData, options.InputFile); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
} else {
if err := runCompiled(inputData); err != nil {
if err := RunCompiled(options.Modules, inputData); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
Expand Down Expand Up @@ -116,7 +116,8 @@ func doHelp() {
fmt.Println()
}

func compileOnly(modules map[string]objects.Importable, data []byte, inputFile, outputFile string) (err error) {
// CompileOnly compiles the source code and writes the compiled binary into outputFile.
func CompileOnly(modules *objects.ModuleMap, data []byte, inputFile, outputFile string) (err error) {
bytecode, err := compileSrc(modules, data, filepath.Base(inputFile))
if err != nil {
return
Expand Down Expand Up @@ -148,7 +149,8 @@ func compileOnly(modules map[string]objects.Importable, data []byte, inputFile,
return
}

func compileAndRun(modules map[string]objects.Importable, data []byte, inputFile string) (err error) {
// CompileAndRun compiles the source code and executes it.
func CompileAndRun(modules *objects.ModuleMap, data []byte, inputFile string) (err error) {
bytecode, err := compileSrc(modules, data, filepath.Base(inputFile))
if err != nil {
return
Expand All @@ -164,9 +166,10 @@ func compileAndRun(modules map[string]objects.Importable, data []byte, inputFile
return
}

func runCompiled(data []byte) (err error) {
// RunCompiled reads the compiled binary from file and executes it.
func RunCompiled(modules *objects.ModuleMap, data []byte) (err error) {
bytecode := &compiler.Bytecode{}
err = bytecode.Decode(bytes.NewReader(data))
err = bytecode.Decode(bytes.NewReader(data), modules)
if err != nil {
return
}
Expand All @@ -181,7 +184,8 @@ func runCompiled(data []byte) (err error) {
return
}

func runREPL(modules map[string]objects.Importable, in io.Reader, out io.Writer) {
// RunREPL starts REPL.
func RunREPL(modules *objects.ModuleMap, in io.Reader, out io.Writer) {
stdin := bufio.NewScanner(in)

fileSet := source.NewFileSet()
Expand Down Expand Up @@ -254,7 +258,7 @@ func runREPL(modules map[string]objects.Importable, in io.Reader, out io.Writer)
}
}

func compileSrc(modules map[string]objects.Importable, src []byte, filename string) (*compiler.Bytecode, error) {
func compileSrc(modules *objects.ModuleMap, src []byte, filename string) (*compiler.Bytecode, error) {
fileSet := source.NewFileSet()
srcFile := fileSet.AddFile(filename, -1, len(src))

Expand Down
64 changes: 64 additions & 0 deletions cli/cli_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package cli_test

import (
"io/ioutil"
"os"
"path/filepath"
"regexp"
"testing"

"github.com/d5/tengo/assert"
"github.com/d5/tengo/cli"
"github.com/d5/tengo/stdlib"
)

func TestCLICompileAndRun(t *testing.T) {
tempDir := filepath.Join(os.TempDir(), "tengo_tests")
_ = os.MkdirAll(tempDir, os.ModePerm)
binFile := filepath.Join(tempDir, "cli_bin")
outFile := filepath.Join(tempDir, "cli_out")
defer func() {
_ = os.RemoveAll(tempDir)
}()

src := []byte(`
os := import("os")
rand := import("rand")
times := import("times")
rand.seed(times.time_nanosecond(times.now()))
rand_num := func() {
return rand.intn(100)
}
file := os.create("` + outFile + `")
file.write_string("random number is " + rand_num())
file.close()
`)

mods := stdlib.GetModuleMap(stdlib.AllModuleNames()...)

err := cli.CompileOnly(mods, src, "src", binFile)
if !assert.NoError(t, err) {
return
}

compiledBin, err := ioutil.ReadFile(binFile)
if !assert.NoError(t, err) {
return
}

err = cli.RunCompiled(mods, compiledBin)
if !assert.NoError(t, err) {
return
}

read, err := ioutil.ReadFile(outFile)
if !assert.NoError(t, err) {
return
}
ok, err := regexp.Match(`^random number is \d+$`, read)
assert.NoError(t, err)
assert.True(t, ok, string(read))
}
2 changes: 1 addition & 1 deletion cmd/tengo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func main() {
ShowVersion: showVersion,
Version: version,
CompileOutput: compileOutput,
Modules: stdlib.GetModules(stdlib.AllModuleNames()...),
Modules: stdlib.GetModuleMap(stdlib.AllModuleNames()...),
InputFile: flag.Arg(0),
})
}
57 changes: 46 additions & 11 deletions compiler/bytecode_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package compiler

import (
"encoding/gob"
"fmt"
"io"

"github.com/d5/tengo/objects"
)

// Decode reads Bytecode data from the reader.
func (b *Bytecode) Decode(r io.Reader) error {
func (b *Bytecode) Decode(r io.Reader, modules *objects.ModuleMap) error {
if modules == nil {
modules = objects.NewModuleMap()
}

dec := gob.NewDecoder(r)

if err := dec.Decode(&b.FileSet); err != nil {
Expand All @@ -25,38 +30,68 @@ func (b *Bytecode) Decode(r io.Reader) error {
return err
}
for i, v := range b.Constants {
b.Constants[i] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return err
}
b.Constants[i] = fv
}

return nil
}

func fixDecoded(o objects.Object) objects.Object {
func fixDecoded(o objects.Object, modules *objects.ModuleMap) (objects.Object, error) {
switch o := o.(type) {
case *objects.Bool:
if o.IsFalsy() {
return objects.FalseValue
return objects.FalseValue, nil
}
return objects.TrueValue
return objects.TrueValue, nil
case *objects.Undefined:
return objects.UndefinedValue
return objects.UndefinedValue, nil
case *objects.Array:
for i, v := range o.Value {
o.Value[i] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[i] = fv
}
case *objects.ImmutableArray:
for i, v := range o.Value {
o.Value[i] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[i] = fv
}
case *objects.Map:
for k, v := range o.Value {
o.Value[k] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[k] = fv
}
case *objects.ImmutableMap:
modName := moduleName(o)
if mod := modules.GetBuiltinModule(modName); mod != nil {
return mod.AsImmutableMap(modName), nil
}

for k, v := range o.Value {
o.Value[k] = fixDecoded(v)
// encoding of user function not supported
if _, isUserFunction := v.(*objects.UserFunction); isUserFunction {
return nil, fmt.Errorf("user function not decodable")
}

fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[k] = fv
}
}

return o
return o, nil
}
2 changes: 1 addition & 1 deletion compiler/bytecode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func testBytecodeSerialization(t *testing.T, b *compiler.Bytecode) {
assert.NoError(t, err)

r := &compiler.Bytecode{}
err = r.Decode(bytes.NewReader(buf.Bytes()))
err = r.Decode(bytes.NewReader(buf.Bytes()), nil)
assert.NoError(t, err)

assert.Equal(t, b.FileSet, r.FileSet)
Expand Down
14 changes: 7 additions & 7 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Compiler struct {
symbolTable *SymbolTable
scopes []CompilationScope
scopeIndex int
importModules map[string]objects.Importable
modules *objects.ModuleMap
compiledModules map[string]*objects.CompiledFunction
allowFileImport bool
loops []*Loop
Expand All @@ -34,7 +34,7 @@ type Compiler struct {
}

// NewCompiler creates a Compiler.
func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, importModules map[string]objects.Importable, trace io.Writer) *Compiler {
func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, modules *objects.ModuleMap, trace io.Writer) *Compiler {
mainScope := CompilationScope{
symbolInit: make(map[string]bool),
sourceMap: make(map[int]source.Pos),
Expand All @@ -51,8 +51,8 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object
}

// builtin modules
if importModules == nil {
importModules = make(map[string]objects.Importable)
if modules == nil {
modules = objects.NewModuleMap()
}

return &Compiler{
Expand All @@ -63,7 +63,7 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object
scopeIndex: 0,
loopIndex: -1,
trace: trace,
importModules: importModules,
modules: modules,
compiledModules: make(map[string]*objects.CompiledFunction),
}
}
Expand Down Expand Up @@ -513,7 +513,7 @@ func (c *Compiler) Compile(node ast.Node) error {
return c.errorf(node, "empty module name")
}

if mod, ok := c.importModules[node.ModuleName]; ok {
if mod := c.modules.Get(node.ModuleName); mod != nil {
v, err := mod.Import(node.ModuleName)
if err != nil {
return err
Expand Down Expand Up @@ -644,7 +644,7 @@ func (c *Compiler) EnableFileImport(enable bool) {
}

func (c *Compiler) fork(file *source.File, modulePath string, symbolTable *SymbolTable) *Compiler {
child := NewCompiler(file, symbolTable, nil, c.importModules, c.trace)
child := NewCompiler(file, symbolTable, nil, c.modules, c.trace)
child.modulePath = modulePath // module file path
child.parent = c // parent to set to current compiler

Expand Down
12 changes: 4 additions & 8 deletions docs/interoperability.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,13 @@ Users can add and use a custom user type in Tengo code by implementing [Object](

To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions.

#### Script.SetImports(modules map[string]objects.Importable)
#### Script.SetImports(modules *objects.ModuleMap)

SetImports sets the import modules with corresponding names. Script **does not** include any modules by default. You can use this function to include the [Standard Library](https://github.com/d5/tengo/blob/master/docs/stdlib.md).

```golang
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))

s.SetImports(map[string]objects.Importable{
"math": stdlib.BuiltinModules["math"],
})
// or
s.SetImports(stdlib.GetModules("math"))
// or, to include all stdlib at once
s.SetImports(stdlib.GetModules(stdlib.AllModuleNames()...))
Expand All @@ -140,9 +136,9 @@ You can also include Tengo's written module using `objects.SourceModule` (which
```golang
s := script.New([]byte(`double := import("double"); a := double(20)`))

s.SetImports(map[string]objects.Importable{
"double": &objects.SourceModule{Src: []byte(`export func(x) { return x * 2 }`)},
})
mods := objects.NewModuleMap()
mods.AddSourceModule("double", []byte(`export func(x) { return x * 2 }`))
s.SetImports(mods)
```


Expand Down
13 changes: 10 additions & 3 deletions objects/builtin_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ type BuiltinModule struct {
}

// Import returns an immutable map for the module.
func (m *BuiltinModule) Import(name string) (interface{}, error) {
func (m *BuiltinModule) Import(moduleName string) (interface{}, error) {
return m.AsImmutableMap(moduleName), nil
}

// AsImmutableMap converts builtin module into an immutable map.
func (m *BuiltinModule) AsImmutableMap(moduleName string) *ImmutableMap {
attrs := make(map[string]Object, len(m.Attrs))
for k, v := range m.Attrs {
attrs[k] = v.Copy()
}
attrs["__module_name__"] = &String{Value: name}
return &ImmutableMap{Value: attrs}, nil

attrs["__module_name__"] = &String{Value: moduleName}

return &ImmutableMap{Value: attrs}
}
2 changes: 1 addition & 1 deletion objects/importable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package objects
// Importable interface represents importable module instance.
type Importable interface {
// Import should return either an Object or module source code ([]byte).
Import(name string) (interface{}, error)
Import(moduleName string) (interface{}, error)
}
Loading

0 comments on commit 3c30109

Please sign in to comment.