diff options
| author | Bobby <[email protected]> | 2024-03-08 21:38:19 +0000 |
|---|---|---|
| committer | Bobby <[email protected]> | 2024-03-08 21:38:19 +0000 |
| commit | 87b188bae2c8a2a9f81e872805d072be7ec910b2 (patch) | |
| tree | b3c48ba591966e424b1bb55dba1f652733368a34 /evaluator | |
| parent | b52f4e9b4140f482ad966aa354b39cd305a212ec (diff) | |
| download | mana-87b188bae2c8a2a9f81e872805d072be7ec910b2.tar.xz mana-87b188bae2c8a2a9f81e872805d072be7ec910b2.zip | |
ast: fn init
Diffstat (limited to 'evaluator')
| -rw-r--r-- | evaluator/evaluator.go | 134 | ||||
| -rw-r--r-- | evaluator/evaluator_test.go | 176 |
2 files changed, 193 insertions, 117 deletions
diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 9cead89..f765d02 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -1,9 +1,9 @@ package evaluator import ( + "fmt" "mana/ast" "mana/object" - "fmt" ) var ( @@ -25,15 +25,15 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.BlockStatement: return evalBlockStatement(node, env) - case *ast.LetStatement: - val := Eval(node.Value, env) - if isError(val) { - return val - } - env.Set(node.Name.Value, val) + case *ast.LetStatement: + val := Eval(node.Value, env) + if isError(val) { + return val + } + env.Set(node.Name.Value, val) - case *ast.Identifier: - return evalIdentifier(node, env) + case *ast.Identifier: + return evalIdentifier(node, env) // Expressions case *ast.IntegerLiteral: @@ -44,30 +44,48 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.PrefixExpression: right := Eval(node.Right, env) - if isError(right) { - return right - } + if isError(right) { + return right + } return evalPrefixExpression(node.Operator, right) case *ast.InfixExpression: left := Eval(node.Left, env) - if isError(left) { - return left - } + if isError(left) { + return left + } right := Eval(node.Right, env) - if isError(right) { - return right - } + if isError(right) { + return right + } return evalInfixExpression(node.Operator, left, right) case *ast.IfExpression: return evalIfExpression(node, env) + case *ast.CallExpression: + function := Eval(node.Function, env) + + if isError(function) { + return function + } + + args := evalExpressions(node.Arguments, env) + if len(args) == 1 && isError(args[0]) { + return args[0] + } + + case *ast.FunctionLiteral: + params := node.Parameters + body := node.Body + + return &object.Function{Parameters: params, Body: body, Env: env} + case *ast.ReturnStatement: val := Eval(node.ReturnValue, env) - if isError(val) { - return val - } + if isError(val) { + return val + } return &object.ReturnValue{Value: val} } @@ -80,30 +98,44 @@ func evalProgram(program *ast.Program, env *object.Environment) object.Object { for _, statement := range program.Statements { result = Eval(statement, env) - switch result := result.(type) { - case *object.ReturnValue: - return result.Value - case *object.Error: - return result - } + switch result := result.(type) { + case *object.ReturnValue: + return result.Value + case *object.Error: + return result + } } return result } - func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object { var result object.Object for _, statement := range block.Statements { result = Eval(statement, env) - if result != nil { - rt := result.Type() - if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ { - return result - } - } + if result != nil { + rt := result.Type() + if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ { + return result + } + } + } + + return result +} + +func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object { + var result []object.Object + + for _, e := range exps { + evaluated := Eval(e, env) + if isError(evaluated) { + return []object.Object{evaluated} + } + + result = append(result, evaluated) } return result @@ -123,7 +155,7 @@ func evalPrefixExpression(operator string, right object.Object) object.Object { case "-": return evalMinusPrefixOperatorExpression(right) default: - return newError("unknown operator: %s%s", operator, right.Type()) + return newError("unknown operator: %s%s", operator, right.Type()) } } @@ -135,10 +167,10 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje return nativeBoolToBooleanObject(left == right) case operator == "!=": return nativeBoolToBooleanObject(left != right) - case left.Type() != right.Type(): - return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type()) + case left.Type() != right.Type(): + return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type()) default: - return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) + return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) } } @@ -195,9 +227,9 @@ func evalIntegerInfixExpression(operator string, left, right object.Object) obje func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object { condition := Eval(ie.Condition, env) - if isError(condition) { - return condition - } + if isError(condition) { + return condition + } if isTruthy(condition) { return Eval(ie.Consequence, env) } else if ie.Alternative != nil { @@ -208,11 +240,11 @@ func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Obje } func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object { - val, ok := env.Get(node.Value) - if !ok { - return newError("identifier not found: " + node.Value) - } - return val + val, ok := env.Get(node.Value) + if !ok { + return newError("identifier not found: " + node.Value) + } + return val } func isTruthy(obj object.Object) bool { @@ -229,12 +261,12 @@ func isTruthy(obj object.Object) bool { } func newError(format string, a ...interface{}) *object.Error { - return &object.Error{Message: fmt.Sprintf(format, a...)} + return &object.Error{Message: fmt.Sprintf(format, a...)} } func isError(obj object.Object) bool { - if obj != nil { - return obj.Type() == object.ERROR_OBJ - } - return false + if obj != nil { + return obj.Type() == object.ERROR_OBJ + } + return false } diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 38fed91..f83cf7f 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -5,7 +5,7 @@ import ( "mana/object" "mana/parser" - "testing" + "testing" ) func TestEvalIntegerExpression(t *testing.T) { @@ -138,36 +138,36 @@ func TestReturnStatements(t *testing.T) { } func TestErrorhandling(t *testing.T) { - tests := [] struct { - input string - expectedMessage string - } { - { - "5 + true;", - "type mismatch: INTEGER + BOOLEAN", - }, - { - "5 + true; 5;", - "type mismatch: INTEGER + BOOLEAN", - }, - { - "-true", - "unknown operator: -BOOLEAN", - }, - { - "true + false;", - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - "5; true + false; 5", - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - "if (10 > 1) { true + false; }", - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - ` + tests := []struct { + input string + expectedMessage string + }{ + { + "5 + true;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "5 + true; 5;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "-true", + "unknown operator: -BOOLEAN", + }, + { + "true + false;", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "5; true + false; 5", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "if (10 > 1) { true + false; }", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + ` if (10 > 1) { if (10 > 1) { return true + false; @@ -175,50 +175,94 @@ func TestErrorhandling(t *testing.T) { return 1; } `, - "unknown operator: BOOLEAN + BOOLEAN", - }, - { - "foobar", - "identifier not found: foobar", - }, - } - - for _, tt := range tests { - evaluated := testEval(tt.input) - - errObj, ok := evaluated.(*object.Error) - if !ok { - t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated) - continue - } - - if errObj.Message != tt.expectedMessage { - t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message) - } - } + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "foobar", + "identifier not found: foobar", + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + errObj, ok := evaluated.(*object.Error) + if !ok { + t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated) + continue + } + + if errObj.Message != tt.expectedMessage { + t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message) + } + } } func TestLetStatements(t *testing.T) { - tests := []struct { - input string - expected int64 - } { - {"let a = 5; a;", 5}, - {"let a = 5 * 5; a;", 25}, - {"let a = 5; let b = a; b;", 5}, - {"let a = 5; let b = a; let c = a + b + 5; c;", 15}, - } - - for _, tt := range tests { - testIntegerObject(t, testEval(tt.input), tt.expected) - } + tests := []struct { + input string + expected int64 + }{ + {"let a = 5; a;", 5}, + {"let a = 5 * 5; a;", 25}, + {"let a = 5; let b = a; b;", 5}, + {"let a = 5; let b = a; let c = a + b + 5; c;", 15}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + +func TestFunctionObject(t *testing.T) { + input := "fn(x) { x + 2; };" + + evaluated := testEval(input) + + fn, ok := evaluated.(*object.Function) + + if !ok { + t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated) + } + + if len(fn.Parameters) != 1 { + t.Fatalf("function has wrong parameters. Parameters=%+v", fn.Parameters) + } + + if fn.Parameters[0].String() != "x" { + t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0]) + } + + expectedBody := "(x + 2)" + + if fn.Body.String() != expectedBody { + t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String()) + } +} + +func TestFunctionApplication(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let identity = fn(x) { x; }; identity(5);", 5}, + {"let identity = fn(x) { return x; }; identity(5);", 5}, + {"let double = fn(x) { x * 2; }; double(5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5, 5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20}, + {"fn(x) { x; }(5)", 5}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } } func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) program := p.ParseProgram() - env := object.NewEnvironment() + env := object.NewEnvironment() return Eval(program, env) } |
