diff options
| -rw-r--r-- | evaluator/evaluator.go | 30 | ||||
| -rw-r--r-- | evaluator/evaluator_test.go | 25 | ||||
| -rw-r--r-- | object/object.go | 17 |
3 files changed, 65 insertions, 7 deletions
diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 7aeaa24..3a9e54a 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -16,13 +16,13 @@ func Eval(node ast.Node) object.Object { switch node := node.(type) { // Statements case *ast.Program: - return evalStatements(node.Statements) + return evalProgram(node) case *ast.ExpressionStatement: return Eval(node.Expression) case *ast.BlockStatement: - return evalStatements(node.Statements) + return evalBlockStatement(node) // Expressions case *ast.IntegerLiteral: @@ -42,16 +42,38 @@ func Eval(node ast.Node) object.Object { case *ast.IfExpression: return evalIfExpression(node) + + case *ast.ReturnStatement: + val := Eval(node.ReturnValue) + return &object.ReturnValue{Value: val} } return nil } -func evalStatements(stmts []ast.Statement) object.Object { +func evalProgram(program *ast.Program) object.Object { var result object.Object - for _, statement := range stmts { + for _, statement := range program.Statements { result = Eval(statement) + + if returnValue, ok := result.(*object.ReturnValue); ok { + return returnValue.Value + } + } + + return result +} + +func evalBlockStatement(block *ast.BlockStatement) object.Object { + var result object.Object + + for _, statement := range block.Statements { + result = Eval(statement) + + if result != nil && result.Type() == object.RETURN_VALUE_OBJ { + return result + } } return result diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index b124144..ca7738f 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -112,6 +112,31 @@ func TestIfElseExpressions(t *testing.T) { } } +func TestReturnStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"return 10;", 10}, + {"return 10; 9;", 10}, + {"return 2 * 5; 9;", 10}, + {"9; return 2 * 5; 9;", 10}, + {` + if (10 > 1) { + if (10 > 1) { + return 10; + } + return 1; + } + `, 10}, + } + + for _, tt := range tests { + evaulated := testEval(tt.input) + testIntegerObject(t, evaulated, tt.expected) + } +} + func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) diff --git a/object/object.go b/object/object.go index cac36ac..fbecf88 100644 --- a/object/object.go +++ b/object/object.go @@ -5,9 +5,10 @@ import "fmt" type ObjectType string const ( - INTEGER_OBJ = "INTEGER" - BOOLEAN_OBJ = "BOOLEAN" - NULL_OBJ = "NULL" + INTEGER_OBJ = "INTEGER" + BOOLEAN_OBJ = "BOOLEAN" + NULL_OBJ = "NULL" + RETURN_VALUE_OBJ = "RETURN_VALUE" ) type Object interface { @@ -25,6 +26,10 @@ type Boolean struct { type Null struct{} +type ReturnValue struct { + Value Object +} + func (i *Integer) Type() ObjectType { return INTEGER_OBJ } @@ -49,4 +54,10 @@ func (n *Null) Inspect() string { return "null" } +func (rv *ReturnValue) Type() ObjectType { + return RETURN_VALUE_OBJ +} +func (rv *ReturnValue) Inspect() string { + return rv.Value.Inspect() +} |
