diff --git a/evaluator/builtins.go b/evaluator/builtins.go index 27a4fed..3fd7b73 100644 --- a/evaluator/builtins.go +++ b/evaluator/builtins.go @@ -4,217 +4,249 @@ import ( "github.com/JunNishimura/go-lisp/object" ) -var builtinFuncs = map[string]*object.Builtin{ - "+": { - Fn: func(args ...object.Object) object.Object { - var sum int64 - for _, arg := range args { - if arg.Type() != object.INTEGER_OBJ { - return newError("argument to `+` must be INTEGER, got %s", arg.Type()) - } - sum += arg.(*object.Integer).Value - } - return &object.Integer{Value: sum} - }, - }, - "-": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } - if len(args) == 1 { - if args[0].Type() != object.INTEGER_OBJ { - return newError("argument to `-` must be INTEGER, got %s", args[0].Type()) - } - return &object.Integer{Value: -args[0].(*object.Integer).Value} - } +func getBuiltinFunctions(funcName string) (*object.Builtin, bool) { + switch funcName { + case "+": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + var sum int64 + for _, arg := range args { + if arg.Type() != object.INTEGER_OBJ { + return newError("argument to `+` must be INTEGER, got %s", arg.Type()) + } + sum += arg.(*object.Integer).Value + } + return &object.Integer{Value: sum} + }, + }, true + case "-": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } + if len(args) == 1 { + if args[0].Type() != object.INTEGER_OBJ { + return newError("argument to `-` must be INTEGER, got %s", args[0].Type()) + } + return &object.Integer{Value: -args[0].(*object.Integer).Value} + } - var diff int64 - for i, arg := range args { - if arg.Type() != object.INTEGER_OBJ { - return newError("argument to `-` must be INTEGER, got %s", arg.Type()) - } - if i == 0 { - diff = arg.(*object.Integer).Value - } else { - diff -= arg.(*object.Integer).Value - } - } - return &object.Integer{Value: diff} - }, - }, - "*": { - Fn: func(args ...object.Object) object.Object { - var product int64 = 1 - for _, arg := range args { - if arg.Type() != object.INTEGER_OBJ { - return newError("argument to `*` must be INTEGER, got %s", arg.Type()) - } - product *= arg.(*object.Integer).Value - } - return &object.Integer{Value: product} - }, - }, - "/": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } - if len(args) == 1 { - if args[0].Type() != object.INTEGER_OBJ { - return newError("argument to `/` must be INTEGER, got %s", args[0].Type()) - } - return &object.Integer{Value: 1 / args[0].(*object.Integer).Value} - } + var diff int64 + for i, arg := range args { + if arg.Type() != object.INTEGER_OBJ { + return newError("argument to `-` must be INTEGER, got %s", arg.Type()) + } + if i == 0 { + diff = arg.(*object.Integer).Value + } else { + diff -= arg.(*object.Integer).Value + } + } + return &object.Integer{Value: diff} + }, + }, true + case "*": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + var product int64 = 1 + for _, arg := range args { + if arg.Type() != object.INTEGER_OBJ { + return newError("argument to `*` must be INTEGER, got %s", arg.Type()) + } + product *= arg.(*object.Integer).Value + } + return &object.Integer{Value: product} + }, + }, true + case "/": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } + if len(args) == 1 { + if args[0].Type() != object.INTEGER_OBJ { + return newError("argument to `/` must be INTEGER, got %s", args[0].Type()) + } + return &object.Integer{Value: 1 / args[0].(*object.Integer).Value} + } - var quotient int64 - for i, arg := range args { - if arg.Type() != object.INTEGER_OBJ { - return newError("argument to `/` must be INTEGER, got %s", arg.Type()) - } - if i == 0 { - quotient = arg.(*object.Integer).Value - } else { - quotient /= arg.(*object.Integer).Value - } - } - return &object.Integer{Value: quotient} - }, - }, - "=": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } + var quotient int64 + for i, arg := range args { + if arg.Type() != object.INTEGER_OBJ { + return newError("argument to `/` must be INTEGER, got %s", arg.Type()) + } + if i == 0 { + quotient = arg.(*object.Integer).Value + } else { + quotient /= arg.(*object.Integer).Value + } + } + return &object.Integer{Value: quotient} + }, + }, true + case "=": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } - compTo, ok := args[0].(*object.Integer) - if !ok { - return newError("argument to `=` must be INTEGER, got %s", args[0].Type()) - } - for _, arg := range args[1:] { - compFrom, ok := arg.(*object.Integer) + compTo, ok := args[0].(*object.Integer) if !ok { - return newError("argument to `=` must be INTEGER, got %s", arg.Type()) - } - if compFrom.Value != compTo.Value { - return Nil - } - } - return True - }, - }, - "/=": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } + return newError("argument to `=` must be INTEGER, got %s", args[0].Type()) + } + for _, arg := range args[1:] { + compFrom, ok := arg.(*object.Integer) + if !ok { + return newError("argument to `=` must be INTEGER, got %s", arg.Type()) + } + if compFrom.Value != compTo.Value { + return Nil + } + } + return True + }, + }, true + case "/=": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } - compTo, ok := args[0].(*object.Integer) - if !ok { - return newError("argument to `/=` must be INTEGER, got %s", args[0].Type()) - } - for _, arg := range args[1:] { - compFrom, ok := arg.(*object.Integer) + compTo, ok := args[0].(*object.Integer) if !ok { - return newError("argument to `/=` must be INTEGER, got %s", arg.Type()) - } - if compFrom.Value == compTo.Value { - return Nil - } - } - return True - }, - }, - "<": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } + return newError("argument to `/=` must be INTEGER, got %s", args[0].Type()) + } + for _, arg := range args[1:] { + compFrom, ok := arg.(*object.Integer) + if !ok { + return newError("argument to `/=` must be INTEGER, got %s", arg.Type()) + } + if compFrom.Value == compTo.Value { + return Nil + } + } + return True + }, + }, true + case "<": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } - compTo, ok := args[0].(*object.Integer) - if !ok { - return newError("argument to `<` must be INTEGER, got %s", args[0].Type()) - } - for _, arg := range args[1:] { - compFrom, ok := arg.(*object.Integer) + compTo, ok := args[0].(*object.Integer) if !ok { - return newError("argument to `<` must be INTEGER, got %s", arg.Type()) - } - if compTo.Value >= compFrom.Value { - return Nil - } - compTo = compFrom - } - return True - }, - }, - "<=": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } + return newError("argument to `<` must be INTEGER, got %s", args[0].Type()) + } + for _, arg := range args[1:] { + compFrom, ok := arg.(*object.Integer) + if !ok { + return newError("argument to `<` must be INTEGER, got %s", arg.Type()) + } + if compTo.Value >= compFrom.Value { + return Nil + } + compTo = compFrom + } + return True + }, + }, true + case "<=": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } - compTo, ok := args[0].(*object.Integer) - if !ok { - return newError("argument to `<=` must be INTEGER, got %s", args[0].Type()) - } - for _, arg := range args[1:] { - compFrom, ok := arg.(*object.Integer) + compTo, ok := args[0].(*object.Integer) if !ok { - return newError("argument to `<=` must be INTEGER, got %s", arg.Type()) - } - if compTo.Value > compFrom.Value { - return Nil - } - compTo = compFrom - } - return True - }, - }, - ">": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } + return newError("argument to `<=` must be INTEGER, got %s", args[0].Type()) + } + for _, arg := range args[1:] { + compFrom, ok := arg.(*object.Integer) + if !ok { + return newError("argument to `<=` must be INTEGER, got %s", arg.Type()) + } + if compTo.Value > compFrom.Value { + return Nil + } + compTo = compFrom + } + return True + }, + }, true + case ">": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } - compTo, ok := args[0].(*object.Integer) - if !ok { - return newError("argument to `>` must be INTEGER, got %s", args[0].Type()) - } - for _, arg := range args[1:] { - compFrom, ok := arg.(*object.Integer) + compTo, ok := args[0].(*object.Integer) if !ok { - return newError("argument to `>` must be INTEGER, got %s", arg.Type()) - } - if compTo.Value <= compFrom.Value { - return Nil - } - compTo = compFrom - } - return True - }, - }, - ">=": { - Fn: func(args ...object.Object) object.Object { - if len(args) == 0 { - return newError("wrong number of arguments. got=%d, want=1", len(args)) - } + return newError("argument to `>` must be INTEGER, got %s", args[0].Type()) + } + for _, arg := range args[1:] { + compFrom, ok := arg.(*object.Integer) + if !ok { + return newError("argument to `>` must be INTEGER, got %s", arg.Type()) + } + if compTo.Value <= compFrom.Value { + return Nil + } + compTo = compFrom + } + return True + }, + }, true + case ">=": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) == 0 { + return newError("wrong number of arguments. got=%d, want=1", len(args)) + } - compTo, ok := args[0].(*object.Integer) - if !ok { - return newError("argument to `>=` must be INTEGER, got %s", args[0].Type()) - } - for _, arg := range args[1:] { - compFrom, ok := arg.(*object.Integer) + compTo, ok := args[0].(*object.Integer) if !ok { - return newError("argument to `>=` must be INTEGER, got %s", arg.Type()) + return newError("argument to `>=` must be INTEGER, got %s", args[0].Type()) + } + for _, arg := range args[1:] { + compFrom, ok := arg.(*object.Integer) + if !ok { + return newError("argument to `>=` must be INTEGER, got %s", arg.Type()) + } + if compTo.Value < compFrom.Value { + return Nil + } + compTo = compFrom + } + return True + }, + }, true + case "apply": + return &object.Builtin{ + Fn: func(env *object.Environment, args ...object.Object) object.Object { + if len(args) != 2 { + return newError("wrong number of arguments. got=%d, want=2", len(args)) } - if compTo.Value < compFrom.Value { - return Nil + + var evaluatedArgs []object.Object + quote, ok := args[1].(*object.Quote) + if ok { + evaluatedArgs = evalArgs(quote.SExpression, env) + } else if _, isNil := args[1].(*object.Nil); !isNil { + return newError("second argument to `apply` must be QUOTE, got %s", args[1].Type()) } - compTo = compFrom - } - return True - }, - }, + + return applyFunction(args[0], evaluatedArgs, env) + }, + }, true + default: + return nil, false + } } diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 9548afd..0a019ae 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -97,7 +97,8 @@ func evalSymbol(symbol *ast.Symbol, env *object.Environment) object.Object { return val } - if builtin, ok := builtinFuncs[symbol.Value]; ok { + builtin, ok := getBuiltinFunctions(symbol.Value) + if ok { return builtin } @@ -141,7 +142,7 @@ func evalNormalForm(consCell *ast.ConsCell, env *object.Environment) object.Obje if len(args) == 1 && isError(args[0]) { return args[0] } - return applyFunction(car, args) + return applyFunction(car, args, env) } func evalArgs(sexp ast.SExpression, env *object.Environment) []object.Object { @@ -185,7 +186,7 @@ func evalValueList(consCell *ast.ConsCell, env *object.Environment) []object.Obj } } -func applyFunction(fn object.Object, args []object.Object) object.Object { +func applyFunction(fn object.Object, args []object.Object, env *object.Environment) object.Object { switch fn := fn.(type) { case *object.Function: extendedEnv, err := extendFunctionEnv(fn, args) @@ -201,7 +202,7 @@ func applyFunction(fn object.Object, args []object.Object) object.Object { symbolFunc := fn.Function return Eval(symbolFunc.Body, extendedEnv) case *object.Builtin: - return fn.Fn(args...) + return fn.Fn(env, args...) default: return newError("not a function: %s", fn.Type()) } diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index b4daf47..fda8792 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -233,3 +233,19 @@ func TestSetqExpression(t *testing.T) { testIntegerObject(t, evaluated, tt.expected) } } + +func TestApplyExpression(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"(setq f (lambda () (+ 1 1))) (apply f ())", 2}, + {"(setq f (lambda (x) (+ x x))) (apply f '(1))", 2}, + {"(setq f (lambda (x y) (+ x y))) (apply f '(1 1))", 2}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + testIntegerObject(t, evaluated, tt.expected) + } +} diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 5e46ad5..df96a4d 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -480,6 +480,23 @@ func TestList(t *testing.T) { {Type: token.EOF, Literal: ""}, }, }, + { + name: "apply function", + input: "(apply f '(1 2 3))", + expected: []token.Token{ + {Type: token.LPAREN, Literal: "("}, + {Type: token.SYMBOL, Literal: "apply"}, + {Type: token.SYMBOL, Literal: "f"}, + {Type: token.QUOTE, Literal: "'"}, + {Type: token.LPAREN, Literal: "("}, + {Type: token.INT, Literal: "1"}, + {Type: token.INT, Literal: "2"}, + {Type: token.INT, Literal: "3"}, + {Type: token.RPAREN, Literal: ")"}, + {Type: token.RPAREN, Literal: ")"}, + {Type: token.EOF, Literal: ""}, + }, + }, } for _, tt := range tests { diff --git a/object/object.go b/object/object.go index b7be1c5..3efa1ee 100644 --- a/object/object.go +++ b/object/object.go @@ -21,7 +21,7 @@ const ( LIST_OBJ = "LIST" ) -type BuiltInFunction func(args ...Object) Object +type BuiltInFunction func(env *Environment, args ...Object) Object type ObjectType string diff --git a/parser/parser_test.go b/parser/parser_test.go index 70ceaf1..172bb57 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1298,6 +1298,115 @@ func TestSetq(t *testing.T) { } } +func TestApplyFunction(t *testing.T) { + tests := []struct { + name string + input string + expected *ast.ConsCell + }{ + { + name: "apply function with no parameter", + input: "(apply f ())", + expected: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "apply"}, Value: "apply"}, + CdrField: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "f"}, Value: "f"}, + CdrField: &ast.ConsCell{ + CarField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + }, + }, + { + name: "apply function with no parameter defined by nil", + input: "(apply f nil)", + expected: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "apply"}, Value: "apply"}, + CdrField: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "f"}, Value: "f"}, + CdrField: &ast.ConsCell{ + CarField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + }, + }, + { + name: "apply function with one parameter", + input: "(apply f '(1))", + expected: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "apply"}, Value: "apply"}, + CdrField: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "f"}, Value: "f"}, + CdrField: &ast.ConsCell{ + CarField: &ast.ConsCell{ + CarField: &ast.SpecialForm{Token: token.Token{Type: token.QUOTE, Literal: "'"}, Value: "quote"}, + CdrField: &ast.ConsCell{ + CarField: &ast.ConsCell{ + CarField: &ast.IntegerLiteral{Token: token.Token{Type: token.INT, Literal: "1"}, Value: 1}, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + }, + }, + { + name: "apply function with multiple parameters", + input: "(apply f '(1 2))", + expected: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "apply"}, Value: "apply"}, + CdrField: &ast.ConsCell{ + CarField: &ast.Symbol{Token: token.Token{Type: token.SYMBOL, Literal: "f"}, Value: "f"}, + CdrField: &ast.ConsCell{ + CarField: &ast.ConsCell{ + CarField: &ast.SpecialForm{Token: token.Token{Type: token.QUOTE, Literal: "'"}, Value: "quote"}, + CdrField: &ast.ConsCell{ + CarField: &ast.ConsCell{ + CarField: &ast.IntegerLiteral{Token: token.Token{Type: token.INT, Literal: "1"}, Value: 1}, + CdrField: &ast.ConsCell{ + CarField: &ast.IntegerLiteral{Token: token.Token{Type: token.INT, Literal: "2"}, Value: 2}, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + CdrField: &ast.Nil{Token: token.Token{Type: token.NIL, Literal: "nil"}}, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Expressions) != 1 { + t.Fatalf("program.Expressions does not contain 1 expressions. got=%d", len(program.Expressions)) + } + + cc, ok := program.Expressions[0].(*ast.ConsCell) + if !ok { + t.Fatalf("exp not *ast.ConsCell. got=%T", program.Expressions[0]) + } + + if cc.String() != tt.expected.String() { + t.Fatalf("cc.String() not %s. got=%s", tt.expected.String(), cc.String()) + } + }) + + } +} + func TestProgram(t *testing.T) { tests := []struct { name string