aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ast/ast.go33
-rw-r--r--evaluator/evaluator.go134
-rw-r--r--evaluator/evaluator_test.go176
-rw-r--r--object/environment.go14
-rw-r--r--object/object.go43
-rw-r--r--parser/parser.go40
-rw-r--r--parser/parser_test.go14
-rw-r--r--parser/parser_tracing.go2
-rw-r--r--repl/repl.go12
9 files changed, 297 insertions, 171 deletions
diff --git a/ast/ast.go b/ast/ast.go
index 88c07e8..346d96d 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -91,14 +91,14 @@ func (i *Identifier) String() string {
// BlockStatement represents a block statement.
type BlockStatement struct {
- Token tokens.Token // the token.LBRACE token
- Statements []Statement
+ Token tokens.Token // the token.LBRACE token
+ Statements []Statement
}
-func (bs *BlockStatement) statementNode() {}
+func (bs *BlockStatement) statementNode() {}
func (bs *BlockStatement) TokenLiteral() string { return bs.Token.Literal }
func (bs *BlockStatement) String() string {
- var out bytes.Buffer
+ var out bytes.Buffer
for _, s := range bs.Statements {
out.WriteString(s.String())
@@ -186,7 +186,6 @@ func (ie *InfixExpression) TokenLiteral() string {
func (ie *InfixExpression) String() string {
var out bytes.Buffer
-
out.WriteString("(")
out.WriteString(ie.Left.String())
out.WriteString(" " + ie.Operator + " ")
@@ -196,7 +195,6 @@ func (ie *InfixExpression) String() string {
return out.String()
}
-
// IntegerLiteral represents an integer literal.
type IntegerLiteral struct {
Token tokens.Token // the token.INT token
@@ -225,13 +223,13 @@ func (b *Boolean) String() string {
}
type IfExpression struct {
- Token tokens.Token // the 'if' token
- Condition Expression
+ Token tokens.Token // the 'if' token
+ Condition Expression
Consequence *BlockStatement
Alternative *BlockStatement
}
-func (ie *IfExpression) expressionNode() {}
+func (ie *IfExpression) expressionNode() {}
func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal }
func (ie *IfExpression) String() string {
var out bytes.Buffer
@@ -250,12 +248,12 @@ func (ie *IfExpression) String() string {
}
type FunctionLiteral struct {
- Token tokens.Token // the 'fn' token
- Parameters []*Identifier
- Body *BlockStatement
+ Token tokens.Token // the 'fn' token
+ Parameters []*Identifier
+ Body *BlockStatement
}
-func (fl *FunctionLiteral) expressionNode() {}
+func (fl *FunctionLiteral) expressionNode() {}
func (fl *FunctionLiteral) TokenLiteral() string { return fl.Token.Literal }
func (fl *FunctionLiteral) String() string {
var out bytes.Buffer
@@ -276,12 +274,12 @@ func (fl *FunctionLiteral) String() string {
}
type CallExpression struct {
- Token tokens.Token // the '(' token
- Function Expression // Identifier or FunctionLiteral
- Arguments []Expression
+ Token tokens.Token // the '(' token
+ Function Expression // Identifier or FunctionLiteral
+ Arguments []Expression
}
-func (ce *CallExpression) expressionNode() {}
+func (ce *CallExpression) expressionNode() {}
func (ce *CallExpression) TokenLiteral() string { return ce.Token.Literal }
func (ce *CallExpression) String() string {
var out bytes.Buffer
@@ -299,4 +297,3 @@ func (ce *CallExpression) String() string {
return out.String()
}
-
diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go
index 9cead89..f765d02 100644
--- a/evaluator/evaluator.go
+++ b/evaluator/evaluator.go
@@ -1,9 +1,9 @@
package evaluator
import (
+ "fmt"
"mana/ast"
"mana/object"
- "fmt"
)
var (
@@ -25,15 +25,15 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.BlockStatement:
return evalBlockStatement(node, env)
- case *ast.LetStatement:
- val := Eval(node.Value, env)
- if isError(val) {
- return val
- }
- env.Set(node.Name.Value, val)
+ case *ast.LetStatement:
+ val := Eval(node.Value, env)
+ if isError(val) {
+ return val
+ }
+ env.Set(node.Name.Value, val)
- case *ast.Identifier:
- return evalIdentifier(node, env)
+ case *ast.Identifier:
+ return evalIdentifier(node, env)
// Expressions
case *ast.IntegerLiteral:
@@ -44,30 +44,48 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.PrefixExpression:
right := Eval(node.Right, env)
- if isError(right) {
- return right
- }
+ if isError(right) {
+ return right
+ }
return evalPrefixExpression(node.Operator, right)
case *ast.InfixExpression:
left := Eval(node.Left, env)
- if isError(left) {
- return left
- }
+ if isError(left) {
+ return left
+ }
right := Eval(node.Right, env)
- if isError(right) {
- return right
- }
+ if isError(right) {
+ return right
+ }
return evalInfixExpression(node.Operator, left, right)
case *ast.IfExpression:
return evalIfExpression(node, env)
+ case *ast.CallExpression:
+ function := Eval(node.Function, env)
+
+ if isError(function) {
+ return function
+ }
+
+ args := evalExpressions(node.Arguments, env)
+ if len(args) == 1 && isError(args[0]) {
+ return args[0]
+ }
+
+ case *ast.FunctionLiteral:
+ params := node.Parameters
+ body := node.Body
+
+ return &object.Function{Parameters: params, Body: body, Env: env}
+
case *ast.ReturnStatement:
val := Eval(node.ReturnValue, env)
- if isError(val) {
- return val
- }
+ if isError(val) {
+ return val
+ }
return &object.ReturnValue{Value: val}
}
@@ -80,30 +98,44 @@ func evalProgram(program *ast.Program, env *object.Environment) object.Object {
for _, statement := range program.Statements {
result = Eval(statement, env)
- switch result := result.(type) {
- case *object.ReturnValue:
- return result.Value
- case *object.Error:
- return result
- }
+ switch result := result.(type) {
+ case *object.ReturnValue:
+ return result.Value
+ case *object.Error:
+ return result
+ }
}
return result
}
-
func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object {
var result object.Object
for _, statement := range block.Statements {
result = Eval(statement, env)
- if result != nil {
- rt := result.Type()
- if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ {
- return result
- }
- }
+ if result != nil {
+ rt := result.Type()
+ if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ {
+ return result
+ }
+ }
+ }
+
+ return result
+}
+
+func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object {
+ var result []object.Object
+
+ for _, e := range exps {
+ evaluated := Eval(e, env)
+ if isError(evaluated) {
+ return []object.Object{evaluated}
+ }
+
+ result = append(result, evaluated)
}
return result
@@ -123,7 +155,7 @@ func evalPrefixExpression(operator string, right object.Object) object.Object {
case "-":
return evalMinusPrefixOperatorExpression(right)
default:
- return newError("unknown operator: %s%s", operator, right.Type())
+ return newError("unknown operator: %s%s", operator, right.Type())
}
}
@@ -135,10 +167,10 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje
return nativeBoolToBooleanObject(left == right)
case operator == "!=":
return nativeBoolToBooleanObject(left != right)
- case left.Type() != right.Type():
- return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type())
+ case left.Type() != right.Type():
+ return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type())
default:
- return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
+ return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
}
}
@@ -195,9 +227,9 @@ func evalIntegerInfixExpression(operator string, left, right object.Object) obje
func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object {
condition := Eval(ie.Condition, env)
- if isError(condition) {
- return condition
- }
+ if isError(condition) {
+ return condition
+ }
if isTruthy(condition) {
return Eval(ie.Consequence, env)
} else if ie.Alternative != nil {
@@ -208,11 +240,11 @@ func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Obje
}
func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object {
- val, ok := env.Get(node.Value)
- if !ok {
- return newError("identifier not found: " + node.Value)
- }
- return val
+ val, ok := env.Get(node.Value)
+ if !ok {
+ return newError("identifier not found: " + node.Value)
+ }
+ return val
}
func isTruthy(obj object.Object) bool {
@@ -229,12 +261,12 @@ func isTruthy(obj object.Object) bool {
}
func newError(format string, a ...interface{}) *object.Error {
- return &object.Error{Message: fmt.Sprintf(format, a...)}
+ return &object.Error{Message: fmt.Sprintf(format, a...)}
}
func isError(obj object.Object) bool {
- if obj != nil {
- return obj.Type() == object.ERROR_OBJ
- }
- return false
+ if obj != nil {
+ return obj.Type() == object.ERROR_OBJ
+ }
+ return false
}
diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go
index 38fed91..f83cf7f 100644
--- a/evaluator/evaluator_test.go
+++ b/evaluator/evaluator_test.go
@@ -5,7 +5,7 @@ import (
"mana/object"
"mana/parser"
- "testing"
+ "testing"
)
func TestEvalIntegerExpression(t *testing.T) {
@@ -138,36 +138,36 @@ func TestReturnStatements(t *testing.T) {
}
func TestErrorhandling(t *testing.T) {
- tests := [] struct {
- input string
- expectedMessage string
- } {
- {
- "5 + true;",
- "type mismatch: INTEGER + BOOLEAN",
- },
- {
- "5 + true; 5;",
- "type mismatch: INTEGER + BOOLEAN",
- },
- {
- "-true",
- "unknown operator: -BOOLEAN",
- },
- {
- "true + false;",
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- "5; true + false; 5",
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- "if (10 > 1) { true + false; }",
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- `
+ tests := []struct {
+ input string
+ expectedMessage string
+ }{
+ {
+ "5 + true;",
+ "type mismatch: INTEGER + BOOLEAN",
+ },
+ {
+ "5 + true; 5;",
+ "type mismatch: INTEGER + BOOLEAN",
+ },
+ {
+ "-true",
+ "unknown operator: -BOOLEAN",
+ },
+ {
+ "true + false;",
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ "5; true + false; 5",
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ "if (10 > 1) { true + false; }",
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ `
if (10 > 1) {
if (10 > 1) {
return true + false;
@@ -175,50 +175,94 @@ func TestErrorhandling(t *testing.T) {
return 1;
}
`,
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- "foobar",
- "identifier not found: foobar",
- },
- }
-
- for _, tt := range tests {
- evaluated := testEval(tt.input)
-
- errObj, ok := evaluated.(*object.Error)
- if !ok {
- t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated)
- continue
- }
-
- if errObj.Message != tt.expectedMessage {
- t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message)
- }
- }
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ "foobar",
+ "identifier not found: foobar",
+ },
+ }
+
+ for _, tt := range tests {
+ evaluated := testEval(tt.input)
+
+ errObj, ok := evaluated.(*object.Error)
+ if !ok {
+ t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated)
+ continue
+ }
+
+ if errObj.Message != tt.expectedMessage {
+ t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message)
+ }
+ }
}
func TestLetStatements(t *testing.T) {
- tests := []struct {
- input string
- expected int64
- } {
- {"let a = 5; a;", 5},
- {"let a = 5 * 5; a;", 25},
- {"let a = 5; let b = a; b;", 5},
- {"let a = 5; let b = a; let c = a + b + 5; c;", 15},
- }
-
- for _, tt := range tests {
- testIntegerObject(t, testEval(tt.input), tt.expected)
- }
+ tests := []struct {
+ input string
+ expected int64
+ }{
+ {"let a = 5; a;", 5},
+ {"let a = 5 * 5; a;", 25},
+ {"let a = 5; let b = a; b;", 5},
+ {"let a = 5; let b = a; let c = a + b + 5; c;", 15},
+ }
+
+ for _, tt := range tests {
+ testIntegerObject(t, testEval(tt.input), tt.expected)
+ }
+}
+
+func TestFunctionObject(t *testing.T) {
+ input := "fn(x) { x + 2; };"
+
+ evaluated := testEval(input)
+
+ fn, ok := evaluated.(*object.Function)
+
+ if !ok {
+ t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated)
+ }
+
+ if len(fn.Parameters) != 1 {
+ t.Fatalf("function has wrong parameters. Parameters=%+v", fn.Parameters)
+ }
+
+ if fn.Parameters[0].String() != "x" {
+ t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0])
+ }
+
+ expectedBody := "(x + 2)"
+
+ if fn.Body.String() != expectedBody {
+ t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String())
+ }
+}
+
+func TestFunctionApplication(t *testing.T) {
+ tests := []struct {
+ input string
+ expected int64
+ }{
+ {"let identity = fn(x) { x; }; identity(5);", 5},
+ {"let identity = fn(x) { return x; }; identity(5);", 5},
+ {"let double = fn(x) { x * 2; }; double(5);", 10},
+ {"let add = fn(x, y) { x + y; }; add(5, 5);", 10},
+ {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20},
+ {"fn(x) { x; }(5)", 5},
+ }
+
+ for _, tt := range tests {
+ testIntegerObject(t, testEval(tt.input), tt.expected)
+ }
}
func testEval(input string) object.Object {
l := lexer.New(input)
p := parser.New(l)
program := p.ParseProgram()
- env := object.NewEnvironment()
+ env := object.NewEnvironment()
return Eval(program, env)
}
diff --git a/object/environment.go b/object/environment.go
index 4b37749..bd89056 100644
--- a/object/environment.go
+++ b/object/environment.go
@@ -1,20 +1,20 @@
package object
func NewEnvironment() *Environment {
- s := make(map[string]Object)
- return &Environment{store: s}
+ s := make(map[string]Object)
+ return &Environment{store: s}
}
type Environment struct {
- store map[string]Object
+ store map[string]Object
}
func (e *Environment) Get(name string) (Object, bool) {
- obj, ok := e.store[name]
- return obj, ok
+ obj, ok := e.store[name]
+ return obj, ok
}
func (e *Environment) Set(name string, val Object) Object {
- e.store[name] = val
- return val
+ e.store[name] = val
+ return val
}
diff --git a/object/object.go b/object/object.go
index 20deb0c..793ff6b 100644
--- a/object/object.go
+++ b/object/object.go
@@ -1,6 +1,11 @@
package object
-import "fmt"
+import (
+ "bytes"
+ "fmt"
+ "mana/ast"
+ "strings"
+)
type ObjectType string
@@ -9,7 +14,8 @@ const (
BOOLEAN_OBJ = "BOOLEAN"
NULL_OBJ = "NULL"
RETURN_VALUE_OBJ = "RETURN_VALUE"
- ERROR_OBJ = "ERROR"
+ FUNCTION_OBJ = "FUNCTION"
+ ERROR_OBJ = "ERROR"
)
type Object interface {
@@ -32,7 +38,13 @@ type ReturnValue struct {
}
type Error struct {
- Message string
+ Message string
+}
+
+type Function struct {
+ Parameters []*ast.Identifier
+ Body *ast.BlockStatement
+ Env *Environment
}
func (i *Integer) Type() ObjectType {
@@ -67,5 +79,28 @@ func (rv *ReturnValue) Inspect() string {
return rv.Value.Inspect()
}
+func (f *Function) Type() ObjectType {
+ return FUNCTION_OBJ
+}
+
+func (f *Function) Inspect() string {
+ var out bytes.Buffer
+
+ params := []string{}
+
+ for _, p := range f.Parameters {
+ params = append(params, p.String())
+ }
+
+ out.WriteString("fn")
+ out.WriteString("(")
+ out.WriteString(strings.Join(params, ", "))
+ out.WriteString(") {\n")
+ out.WriteString(f.Body.String())
+ out.WriteString("\n}")
+
+ return out.String()
+}
+
func (e *Error) Type() ObjectType { return ERROR_OBJ }
-func (e *Error) Inspect() string { return "ERROR:" + e.Message }
+func (e *Error) Inspect() string { return "ERROR:" + e.Message }
diff --git a/parser/parser.go b/parser/parser.go
index c7d6234..3fe3d9f 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -252,7 +252,7 @@ func (p *Parser) parseExpression(precedence int) ast.Expression {
p.nextToken()
leftExp = infix(leftExp)
-
+
}
return leftExp
@@ -294,21 +294,29 @@ func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression {
func (p *Parser) parseIfExpression() ast.Expression {
expression := &ast.IfExpression{Token: p.curToken}
- if !p.expectPeek(tokens.LPAREN) { return nil }
+ if !p.expectPeek(tokens.LPAREN) {
+ return nil
+ }
p.nextToken()
expression.Condition = p.parseExpression(LOWEST)
- if !p.expectPeek(tokens.RPAREN) { return nil }
+ if !p.expectPeek(tokens.RPAREN) {
+ return nil
+ }
- if !p.expectPeek(tokens.LBRACE) { return nil }
+ if !p.expectPeek(tokens.LBRACE) {
+ return nil
+ }
expression.Consequence = p.parseBlockStatement()
if p.peekTokenIs(tokens.ELSE) {
p.nextToken()
- if !p.expectPeek(tokens.LBRACE) { return nil }
+ if !p.expectPeek(tokens.LBRACE) {
+ return nil
+ }
expression.Alternative = p.parseBlockStatement()
}
@@ -338,11 +346,15 @@ func (p *Parser) parseBlockStatement() *ast.BlockStatement {
func (p *Parser) parseFunctionLiteral() ast.Expression {
lit := &ast.FunctionLiteral{Token: p.curToken}
- if !p.expectPeek(tokens.LPAREN) { return nil }
+ if !p.expectPeek(tokens.LPAREN) {
+ return nil
+ }
lit.Parameters = p.parseFunctionParameters()
- if !p.expectPeek(tokens.LBRACE) { return nil }
+ if !p.expectPeek(tokens.LBRACE) {
+ return nil
+ }
lit.Body = p.parseBlockStatement()
@@ -371,7 +383,9 @@ func (p *Parser) parseFunctionParameters() []*ast.Identifier {
identifiers = append(identifiers, ident)
}
- if !p.expectPeek(tokens.RPAREN) { return nil }
+ if !p.expectPeek(tokens.RPAREN) {
+ return nil
+ }
return identifiers
}
@@ -382,7 +396,9 @@ func (p *Parser) parseGroupedExpression() ast.Expression {
exp := p.parseExpression(LOWEST)
- if !p.expectPeek(tokens.RPAREN) { return nil }
+ if !p.expectPeek(tokens.RPAREN) {
+ return nil
+ }
return exp
}
@@ -413,7 +429,9 @@ func (p *Parser) parseCallArguments() []ast.Expression {
args = append(args, p.parseExpression(LOWEST))
}
- if !p.expectPeek(tokens.RPAREN) { return nil }
+ if !p.expectPeek(tokens.RPAREN) {
+ return nil
+ }
return args
}
@@ -453,7 +471,7 @@ func (p *Parser) peekError(t tokens.TokenType) {
p.errors = append(p.errors, msg)
}
-// peek and cur precedences
+// peek and cur precedences
func (p *Parser) peekPrecedence() int {
if p, ok := precedences[p.peekToken.Type]; ok {
diff --git a/parser/parser_test.go b/parser/parser_test.go
index a9754f5..08539b7 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -9,10 +9,10 @@ import (
func TestLetStatements(t *testing.T) {
tests := []struct {
- input string
+ input string
expectedIdentifier string
- expectedValue interface{}
- } {
+ expectedValue interface{}
+ }{
{"let x = 5;", "x", 5},
{"let y = true;", "y", true},
{"let foobar = y;", "foobar", "y"},
@@ -303,9 +303,9 @@ func TestBooleanExpression(t *testing.T) {
// Prefix expression tests.
func TestParsingPrefixExpressions(t *testing.T) {
var prefixTests = []struct {
- input string
- operator string
- value interface{}
+ input string
+ operator string
+ value interface{}
}{
{"!5;", "!", 5},
{"-15;", "-", 15},
@@ -722,7 +722,7 @@ func TestFunctionParameterParsing(t *testing.T) {
var program *ast.Program = p.ParseProgram()
checkParserErrors(t, p)
-
+
stmt := program.Statements[0].(*ast.ExpressionStatement)
function := stmt.Expression.(*ast.FunctionLiteral)
diff --git a/parser/parser_tracing.go b/parser/parser_tracing.go
index a0819da..2f3fcdb 100644
--- a/parser/parser_tracing.go
+++ b/parser/parser_tracing.go
@@ -34,4 +34,4 @@ func trace(msg string) string {
func untrace(msg string) {
tracePrint("END " + msg)
decIdent()
-} \ No newline at end of file
+}
diff --git a/repl/repl.go b/repl/repl.go
index 59c56d6..afee75e 100644
--- a/repl/repl.go
+++ b/repl/repl.go
@@ -4,10 +4,10 @@ import (
"bufio"
"fmt"
"io"
+ "mana/evaluator"
"mana/lexer"
+ "mana/object"
"mana/parser"
- "mana/evaluator"
- "mana/object"
)
// PROMPT is the prompt for the REPL.
@@ -19,13 +19,13 @@ const MANA_START = `
██║╚██╔╝██║██╔══██║██║╚████║██╔══██║
██║░╚═╝░██║██║░░██║██║░╚███║██║░░██║
╚═╝░░░░░╚═╝╚═╝░░╚═╝╚═╝░░╚══╝╚═╝░░╚═╝
-`
+`
func Start(in io.Reader, out io.Writer) {
var scanner *bufio.Scanner = bufio.NewScanner(in)
- env := object.NewEnvironment()
+ env := object.NewEnvironment()
- io.WriteString(out, MANA_START + "\n")
+ io.WriteString(out, MANA_START+"\n")
for {
fmt.Fprint(out, PROMPT)
@@ -57,6 +57,6 @@ func Start(in io.Reader, out io.Writer) {
func printParserErrors(out io.Writer, errors []string) {
io.WriteString(out, "ParseError:\n")
for _, msg := range errors {
- io.WriteString(out, "\t" + msg + "\n")
+ io.WriteString(out, "\t"+msg+"\n")
}
}