diff options
| author | Bobby <[email protected]> | 2024-03-08 21:38:19 +0000 |
|---|---|---|
| committer | Bobby <[email protected]> | 2024-03-08 21:38:19 +0000 |
| commit | 87b188bae2c8a2a9f81e872805d072be7ec910b2 (patch) | |
| tree | b3c48ba591966e424b1bb55dba1f652733368a34 | |
| parent | b52f4e9b4140f482ad966aa354b39cd305a212ec (diff) | |
| download | mana-87b188bae2c8a2a9f81e872805d072be7ec910b2.tar.xz mana-87b188bae2c8a2a9f81e872805d072be7ec910b2.zip | |
ast: fn init
| -rw-r--r-- | ast/ast.go | 33 | ||||
| -rw-r--r-- | evaluator/evaluator.go | 134 | ||||
| -rw-r--r-- | evaluator/evaluator_test.go | 176 | ||||
| -rw-r--r-- | object/environment.go | 14 | ||||
| -rw-r--r-- | object/object.go | 43 | ||||
| -rw-r--r-- | parser/parser.go | 40 | ||||
| -rw-r--r-- | parser/parser_test.go | 14 | ||||
| -rw-r--r-- | parser/parser_tracing.go | 2 | ||||
| -rw-r--r-- | repl/repl.go | 12 |
9 files changed, 297 insertions, 171 deletions
@@ -91,14 +91,14 @@ func (i *Identifier) String() string { // BlockStatement represents a block statement. type BlockStatement struct { - Token tokens.Token // the token.LBRACE token - Statements []Statement + Token tokens.Token // the token.LBRACE token + Statements []Statement } -func (bs *BlockStatement) statementNode() {} +func (bs *BlockStatement) statementNode() {} func (bs *BlockStatement) TokenLiteral() string { return bs.Token.Literal } func (bs *BlockStatement) String() string { - var out bytes.Buffer + var out bytes.Buffer for _, s := range bs.Statements { out.WriteString(s.String()) @@ -186,7 +186,6 @@ func (ie *InfixExpression) TokenLiteral() string { func (ie *InfixExpression) String() string { var out bytes.Buffer - out.WriteString("(") out.WriteString(ie.Left.String()) out.WriteString(" " + ie.Operator + " ") @@ -196,7 +195,6 @@ func (ie *InfixExpression) String() string { return out.String() } - // IntegerLiteral represents an integer literal. type IntegerLiteral struct { Token tokens.Token // the token.INT token @@ -225,13 +223,13 @@ func (b *Boolean) String() string { } type IfExpression struct { - Token tokens.Token // the 'if' token - Condition Expression + Token tokens.Token // the 'if' token + Condition Expression Consequence *BlockStatement Alternative *BlockStatement } -func (ie *IfExpression) expressionNode() {} +func (ie *IfExpression) expressionNode() {} func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *IfExpression) String() string { var out bytes.Buffer @@ -250,12 +248,12 @@ func (ie *IfExpression) String() string { } type FunctionLiteral struct { - Token tokens.Token // the 'fn' token - Parameters []*Identifier - Body *BlockStatement + Token tokens.Token // the 'fn' token + Parameters []*Identifier + Body *BlockStatement } -func (fl *FunctionLiteral) expressionNode() {} +func (fl *FunctionLiteral) expressionNode() {} func (fl *FunctionLiteral) TokenLiteral() string { return fl.Token.Literal } func (fl *FunctionLiteral) String() string { var out bytes.Buffer @@ -276,12 +274,12 @@ func (fl *FunctionLiteral) String() string { } type CallExpression struct { - Token tokens.Token // the '(' token - Function Expression // Identifier or FunctionLiteral - Arguments []Expression + Token tokens.Token // the '(' token + Function Expression // Identifier or FunctionLiteral + Arguments []Expression } -func (ce *CallExpression) expressionNode() {} +func (ce *CallExpression) expressionNode() {} func (ce *CallExpression) TokenLiteral() string { return ce.Token.Literal } func (ce *CallExpression) String() string { var out bytes.Buffer @@ -299,4 +297,3 @@ func (ce *CallExpression) String() string { return out.String() } - diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 9cead89..f765d02 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -1,9 +1,9 @@ package evaluator import ( + "fmt" "mana/ast" "mana/object" - "fmt" ) var ( @@ -25,15 +25,15 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.BlockStatement: return evalBlockStatement(node, env) - case *ast.LetStatement: - val := Eval(node.Value, env) - if isError(val) { - return val - } - env.Set(node.Name.Value, val) + case *ast.LetStatement: + val := Eval(node.Value, env) + if isError(val) { + return val + } + env.Set(node.Name.Value, val) - case *ast.Identifier: - return evalIdentifier(node, env) + case *ast.Identifier: + return evalIdentifier(node, env) // Expressions case *ast.IntegerLiteral: @@ -44,30 +44,48 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.PrefixExpression: right := Eval(node.Right, env) - if isError(right) { - return right - } + if isError(right) { + return right + } return evalPrefixExpression(node.Operator, right) case *ast.InfixExpression: left := Eval(node.Left, env) - if isError(left) { - return left - } + if isError(left) { + return left + } right := Eval(node.Right, env) - if isError(right) { - return right - } + if isError(right) { + return right + } return evalInfixExpression(node.Operator, left, right) case *ast.IfExpression: return evalIfExpression(node, env) + case *ast.CallExpression: + function := Eval(node.Function, env) + + if isError(function) { + return function + } + + args := evalExpressions(node.Arguments, env) + if len(args) == 1 && isError(args[0]) { + return args[0] + } + + case *ast.FunctionLiteral: + params := node.Parameters + body := node.Body + + return &object.Function{Parameters: params, Body: body, Env: env} + case *ast.ReturnStatement: val := Eval(node.ReturnValue, env) - if isError(val) { - return val - } + if isError(val) { + return val + } return &object.ReturnValue{Value: val} } @@ -80,30 +98,44 @@ func evalProgram(program *ast.Program, env *object.Environment) object.Object { for _, statement := range program.Statements { result = Eval(statement, env) - switch result := result.(type) { - case *object.ReturnValue: - return result.Value - case *object.Error: - return result - } + switch result := result.(type) { + case *object.ReturnValue: + return result.Value + case *object.Error: + return result + } } return result } - func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object { var result object.Object for _, statement := range block.Statements { result = Eval(statement, env) - if result != nil { - rt := result.Type() - if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ { - return result - } - } + if result != nil { + rt := result.Type() + if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ { + return result + } + } + } + + return result +} + +func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object { + var result []object.Object + + for _, e := range exps { + evaluated := Eval(e, env) + if isError(evaluated) { + return []object.Object{evaluated} + } + + result = append(result, evaluated) } return result @@ -123,7 +155,7 @@ func evalPrefixExpression(operator string, right object.Object) object.Object { case "-": return evalMinusPrefixOperatorExpression(right) default: - return newError("unknown operator: %s%s", operator, right.Type()) + return newError("unknown operator: %s%s", operator, right.Type()) } } @@ -135,10 +167,10 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje return nativeBoolToBooleanObject(left == right) case operator == "!=": return nativeBoolToBooleanObject(left != right) - case left.Type() != right.Type(): - return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type()) + case left.Type() != right.Type(): + return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type()) default: - return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) + return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) } } @@ -195,9 +227,9 @@ func evalIntegerInfixExpression(operator string, left, right object.Object) obje func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object { condition := Eval(ie.Condition, env) - if isError(condition) { - return condition - } + if isError(condition) { + return condition + } if isTruthy(condition) { return Eval(ie.Consequence, env) } else if ie.Alternative != nil { @@ -208,11 +240,11 @@ func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Obje } func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object { - val, ok := env.Get(node.Value) - if !ok { - return newError("identifier not found: " + node.Value) - } - return val + val, ok := env.Get(node.Value) + if !ok { + return newError("identifier not found: " + node.Value) + } + return val } func isTruthy(obj object.Object) bool { @@ -229,12 +261,12 @@ func isTruthy(obj object.Object) bool { } func newError(format string, a ...interface{}) *object.Error { - return &object.Error{Message: fmt.Sprintf(format, a...)} + return &object.Error{Message: fmt.Sprintf(format, a...)} } func isError(obj object.Object) bool { - if obj != nil { - return obj.Type() == object.ERROR_OBJ - } - return false + if obj != nil { + return obj.Type() == object.ERROR_OBJ + } + return false } diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 38fed91..f83cf7f 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -5,7 +5,7 @@ import ( "mana/object" "mana/parser" - "testing" + "testing" ) func TestEvalIntegerExpression(t *testing.T) { @@ -138,36 +138,36 @@ func TestReturnStatements(t *testing.T) { } func TestErrorhandling(t *testing.T) { - tests := [] struct { - input string - expectedMessage string - } { - { - "5 + true;", - "type mismatch: INTEGER + BOOLEAN", - }, - { - "5 + true; 5;", - "type mismatch: INTEGER + BOOLEAN", - }, - { - "-true", - "unknown operator: -BOOLEAN", - }, - { - "true + false;", - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - "5; true + false; 5", - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - "if (10 > 1) { true + false; }", - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - ` + tests := []struct { + input string + expectedMessage string + }{ + { + "5 + true;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "5 + true; 5;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "-true", + "unknown operator: -BOOLEAN", + }, + { + "true + false;", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "5; true + false; 5", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "if (10 > 1) { true + false; }", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + ` if (10 > 1) { if (10 > 1) { return true + false; @@ -175,50 +175,94 @@ func TestErrorhandling(t *testing.T) { return 1; } `, - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - "foobar", - "identifier not found: foobar", - }, - } - - for _, tt := range tests { - evaluated := testEval(tt.input) - - errObj, ok := evaluated.(*object.Error) - if !ok { - t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated) - continue - } - - if errObj.Message != tt.expectedMessage { - t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message) - } - } + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "foobar", + "identifier not found: foobar", + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + errObj, ok := evaluated.(*object.Error) + if !ok { + t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated) + continue + } + + if errObj.Message != tt.expectedMessage { + t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message) + } + } } func TestLetStatements(t *testing.T) { - tests := []struct { - input string - expected int64 - } { - {"let a = 5; a;", 5}, - {"let a = 5 * 5; a;", 25}, - {"let a = 5; let b = a; b;", 5}, - {"let a = 5; let b = a; let c = a + b + 5; c;", 15}, - } - - for _, tt := range tests { - testIntegerObject(t, testEval(tt.input), tt.expected) - } + tests := []struct { + input string + expected int64 + }{ + {"let a = 5; a;", 5}, + {"let a = 5 * 5; a;", 25}, + {"let a = 5; let b = a; b;", 5}, + {"let a = 5; let b = a; let c = a + b + 5; c;", 15}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + +func TestFunctionObject(t *testing.T) { + input := "fn(x) { x + 2; };" + + evaluated := testEval(input) + + fn, ok := evaluated.(*object.Function) + + if !ok { + t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated) + } + + if len(fn.Parameters) != 1 { + t.Fatalf("function has wrong parameters. Parameters=%+v", fn.Parameters) + } + + if fn.Parameters[0].String() != "x" { + t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0]) + } + + expectedBody := "(x + 2)" + + if fn.Body.String() != expectedBody { + t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String()) + } +} + +func TestFunctionApplication(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let identity = fn(x) { x; }; identity(5);", 5}, + {"let identity = fn(x) { return x; }; identity(5);", 5}, + {"let double = fn(x) { x * 2; }; double(5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5, 5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20}, + {"fn(x) { x; }(5)", 5}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } } func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) program := p.ParseProgram() - env := object.NewEnvironment() + env := object.NewEnvironment() return Eval(program, env) } diff --git a/object/environment.go b/object/environment.go index 4b37749..bd89056 100644 --- a/object/environment.go +++ b/object/environment.go @@ -1,20 +1,20 @@ package object func NewEnvironment() *Environment { - s := make(map[string]Object) - return &Environment{store: s} + s := make(map[string]Object) + return &Environment{store: s} } type Environment struct { - store map[string]Object + store map[string]Object } func (e *Environment) Get(name string) (Object, bool) { - obj, ok := e.store[name] - return obj, ok + obj, ok := e.store[name] + return obj, ok } func (e *Environment) Set(name string, val Object) Object { - e.store[name] = val - return val + e.store[name] = val + return val } diff --git a/object/object.go b/object/object.go index 20deb0c..793ff6b 100644 --- a/object/object.go +++ b/object/object.go @@ -1,6 +1,11 @@ package object -import "fmt" +import ( + "bytes" + "fmt" + "mana/ast" + "strings" +) type ObjectType string @@ -9,7 +14,8 @@ const ( BOOLEAN_OBJ = "BOOLEAN" NULL_OBJ = "NULL" RETURN_VALUE_OBJ = "RETURN_VALUE" - ERROR_OBJ = "ERROR" + FUNCTION_OBJ = "FUNCTION" + ERROR_OBJ = "ERROR" ) type Object interface { @@ -32,7 +38,13 @@ type ReturnValue struct { } type Error struct { - Message string + Message string +} + +type Function struct { + Parameters []*ast.Identifier + Body *ast.BlockStatement + Env *Environment } func (i *Integer) Type() ObjectType { @@ -67,5 +79,28 @@ func (rv *ReturnValue) Inspect() string { return rv.Value.Inspect() } +func (f *Function) Type() ObjectType { + return FUNCTION_OBJ +} + +func (f *Function) Inspect() string { + var out bytes.Buffer + + params := []string{} + + for _, p := range f.Parameters { + params = append(params, p.String()) + } + + out.WriteString("fn") + out.WriteString("(") + out.WriteString(strings.Join(params, ", ")) + out.WriteString(") {\n") + out.WriteString(f.Body.String()) + out.WriteString("\n}") + + return out.String() +} + func (e *Error) Type() ObjectType { return ERROR_OBJ } -func (e *Error) Inspect() string { return "ERROR:" + e.Message } +func (e *Error) Inspect() string { return "ERROR:" + e.Message } diff --git a/parser/parser.go b/parser/parser.go index c7d6234..3fe3d9f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -252,7 +252,7 @@ func (p *Parser) parseExpression(precedence int) ast.Expression { p.nextToken() leftExp = infix(leftExp) - + } return leftExp @@ -294,21 +294,29 @@ func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression { func (p *Parser) parseIfExpression() ast.Expression { expression := &ast.IfExpression{Token: p.curToken} - if !p.expectPeek(tokens.LPAREN) { return nil } + if !p.expectPeek(tokens.LPAREN) { + return nil + } p.nextToken() expression.Condition = p.parseExpression(LOWEST) - if !p.expectPeek(tokens.RPAREN) { return nil } + if !p.expectPeek(tokens.RPAREN) { + return nil + } - if !p.expectPeek(tokens.LBRACE) { return nil } + if !p.expectPeek(tokens.LBRACE) { + return nil + } expression.Consequence = p.parseBlockStatement() if p.peekTokenIs(tokens.ELSE) { p.nextToken() - if !p.expectPeek(tokens.LBRACE) { return nil } + if !p.expectPeek(tokens.LBRACE) { + return nil + } expression.Alternative = p.parseBlockStatement() } @@ -338,11 +346,15 @@ func (p *Parser) parseBlockStatement() *ast.BlockStatement { func (p *Parser) parseFunctionLiteral() ast.Expression { lit := &ast.FunctionLiteral{Token: p.curToken} - if !p.expectPeek(tokens.LPAREN) { return nil } + if !p.expectPeek(tokens.LPAREN) { + return nil + } lit.Parameters = p.parseFunctionParameters() - if !p.expectPeek(tokens.LBRACE) { return nil } + if !p.expectPeek(tokens.LBRACE) { + return nil + } lit.Body = p.parseBlockStatement() @@ -371,7 +383,9 @@ func (p *Parser) parseFunctionParameters() []*ast.Identifier { identifiers = append(identifiers, ident) } - if !p.expectPeek(tokens.RPAREN) { return nil } + if !p.expectPeek(tokens.RPAREN) { + return nil + } return identifiers } @@ -382,7 +396,9 @@ func (p *Parser) parseGroupedExpression() ast.Expression { exp := p.parseExpression(LOWEST) - if !p.expectPeek(tokens.RPAREN) { return nil } + if !p.expectPeek(tokens.RPAREN) { + return nil + } return exp } @@ -413,7 +429,9 @@ func (p *Parser) parseCallArguments() []ast.Expression { args = append(args, p.parseExpression(LOWEST)) } - if !p.expectPeek(tokens.RPAREN) { return nil } + if !p.expectPeek(tokens.RPAREN) { + return nil + } return args } @@ -453,7 +471,7 @@ func (p *Parser) peekError(t tokens.TokenType) { p.errors = append(p.errors, msg) } -// peek and cur precedences +// peek and cur precedences func (p *Parser) peekPrecedence() int { if p, ok := precedences[p.peekToken.Type]; ok { diff --git a/parser/parser_test.go b/parser/parser_test.go index a9754f5..08539b7 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -9,10 +9,10 @@ import ( func TestLetStatements(t *testing.T) { tests := []struct { - input string + input string expectedIdentifier string - expectedValue interface{} - } { + expectedValue interface{} + }{ {"let x = 5;", "x", 5}, {"let y = true;", "y", true}, {"let foobar = y;", "foobar", "y"}, @@ -303,9 +303,9 @@ func TestBooleanExpression(t *testing.T) { // Prefix expression tests. func TestParsingPrefixExpressions(t *testing.T) { var prefixTests = []struct { - input string - operator string - value interface{} + input string + operator string + value interface{} }{ {"!5;", "!", 5}, {"-15;", "-", 15}, @@ -722,7 +722,7 @@ func TestFunctionParameterParsing(t *testing.T) { var program *ast.Program = p.ParseProgram() checkParserErrors(t, p) - + stmt := program.Statements[0].(*ast.ExpressionStatement) function := stmt.Expression.(*ast.FunctionLiteral) diff --git a/parser/parser_tracing.go b/parser/parser_tracing.go index a0819da..2f3fcdb 100644 --- a/parser/parser_tracing.go +++ b/parser/parser_tracing.go @@ -34,4 +34,4 @@ func trace(msg string) string { func untrace(msg string) { tracePrint("END " + msg) decIdent() -}
\ No newline at end of file +} diff --git a/repl/repl.go b/repl/repl.go index 59c56d6..afee75e 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -4,10 +4,10 @@ import ( "bufio" "fmt" "io" + "mana/evaluator" "mana/lexer" + "mana/object" "mana/parser" - "mana/evaluator" - "mana/object" ) // PROMPT is the prompt for the REPL. @@ -19,13 +19,13 @@ const MANA_START = ` ██║╚██╔╝██║██╔══██║██║╚████║██╔══██║ ██║░╚═╝░██║██║░░██║██║░╚███║██║░░██║ ╚═╝░░░░░╚═╝╚═╝░░╚═╝╚═╝░░╚══╝╚═╝░░╚═╝ -` +` func Start(in io.Reader, out io.Writer) { var scanner *bufio.Scanner = bufio.NewScanner(in) - env := object.NewEnvironment() + env := object.NewEnvironment() - io.WriteString(out, MANA_START + "\n") + io.WriteString(out, MANA_START+"\n") for { fmt.Fprint(out, PROMPT) @@ -57,6 +57,6 @@ func Start(in io.Reader, out io.Writer) { func printParserErrors(out io.Writer, errors []string) { io.WriteString(out, "ParseError:\n") for _, msg := range errors { - io.WriteString(out, "\t" + msg + "\n") + io.WriteString(out, "\t"+msg+"\n") } } |
