aboutsummaryrefslogtreecommitdiff
path: root/evaluator
diff options
context:
space:
mode:
authorBobby <[email protected]>2024-03-08 21:38:19 +0000
committerBobby <[email protected]>2024-03-08 21:38:19 +0000
commit87b188bae2c8a2a9f81e872805d072be7ec910b2 (patch)
treeb3c48ba591966e424b1bb55dba1f652733368a34 /evaluator
parentb52f4e9b4140f482ad966aa354b39cd305a212ec (diff)
downloadmana-87b188bae2c8a2a9f81e872805d072be7ec910b2.tar.xz
mana-87b188bae2c8a2a9f81e872805d072be7ec910b2.zip
ast: fn init
Diffstat (limited to 'evaluator')
-rw-r--r--evaluator/evaluator.go134
-rw-r--r--evaluator/evaluator_test.go176
2 files changed, 193 insertions, 117 deletions
diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go
index 9cead89..f765d02 100644
--- a/evaluator/evaluator.go
+++ b/evaluator/evaluator.go
@@ -1,9 +1,9 @@
package evaluator
import (
+ "fmt"
"mana/ast"
"mana/object"
- "fmt"
)
var (
@@ -25,15 +25,15 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.BlockStatement:
return evalBlockStatement(node, env)
- case *ast.LetStatement:
- val := Eval(node.Value, env)
- if isError(val) {
- return val
- }
- env.Set(node.Name.Value, val)
+ case *ast.LetStatement:
+ val := Eval(node.Value, env)
+ if isError(val) {
+ return val
+ }
+ env.Set(node.Name.Value, val)
- case *ast.Identifier:
- return evalIdentifier(node, env)
+ case *ast.Identifier:
+ return evalIdentifier(node, env)
// Expressions
case *ast.IntegerLiteral:
@@ -44,30 +44,48 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.PrefixExpression:
right := Eval(node.Right, env)
- if isError(right) {
- return right
- }
+ if isError(right) {
+ return right
+ }
return evalPrefixExpression(node.Operator, right)
case *ast.InfixExpression:
left := Eval(node.Left, env)
- if isError(left) {
- return left
- }
+ if isError(left) {
+ return left
+ }
right := Eval(node.Right, env)
- if isError(right) {
- return right
- }
+ if isError(right) {
+ return right
+ }
return evalInfixExpression(node.Operator, left, right)
case *ast.IfExpression:
return evalIfExpression(node, env)
+ case *ast.CallExpression:
+ function := Eval(node.Function, env)
+
+ if isError(function) {
+ return function
+ }
+
+ args := evalExpressions(node.Arguments, env)
+ if len(args) == 1 && isError(args[0]) {
+ return args[0]
+ }
+
+ case *ast.FunctionLiteral:
+ params := node.Parameters
+ body := node.Body
+
+ return &object.Function{Parameters: params, Body: body, Env: env}
+
case *ast.ReturnStatement:
val := Eval(node.ReturnValue, env)
- if isError(val) {
- return val
- }
+ if isError(val) {
+ return val
+ }
return &object.ReturnValue{Value: val}
}
@@ -80,30 +98,44 @@ func evalProgram(program *ast.Program, env *object.Environment) object.Object {
for _, statement := range program.Statements {
result = Eval(statement, env)
- switch result := result.(type) {
- case *object.ReturnValue:
- return result.Value
- case *object.Error:
- return result
- }
+ switch result := result.(type) {
+ case *object.ReturnValue:
+ return result.Value
+ case *object.Error:
+ return result
+ }
}
return result
}
-
func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object {
var result object.Object
for _, statement := range block.Statements {
result = Eval(statement, env)
- if result != nil {
- rt := result.Type()
- if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ {
- return result
- }
- }
+ if result != nil {
+ rt := result.Type()
+ if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ {
+ return result
+ }
+ }
+ }
+
+ return result
+}
+
+func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object {
+ var result []object.Object
+
+ for _, e := range exps {
+ evaluated := Eval(e, env)
+ if isError(evaluated) {
+ return []object.Object{evaluated}
+ }
+
+ result = append(result, evaluated)
}
return result
@@ -123,7 +155,7 @@ func evalPrefixExpression(operator string, right object.Object) object.Object {
case "-":
return evalMinusPrefixOperatorExpression(right)
default:
- return newError("unknown operator: %s%s", operator, right.Type())
+ return newError("unknown operator: %s%s", operator, right.Type())
}
}
@@ -135,10 +167,10 @@ func evalInfixExpression(operator string, left, right object.Object) object.Obje
return nativeBoolToBooleanObject(left == right)
case operator == "!=":
return nativeBoolToBooleanObject(left != right)
- case left.Type() != right.Type():
- return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type())
+ case left.Type() != right.Type():
+ return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type())
default:
- return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
+ return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
}
}
@@ -195,9 +227,9 @@ func evalIntegerInfixExpression(operator string, left, right object.Object) obje
func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object {
condition := Eval(ie.Condition, env)
- if isError(condition) {
- return condition
- }
+ if isError(condition) {
+ return condition
+ }
if isTruthy(condition) {
return Eval(ie.Consequence, env)
} else if ie.Alternative != nil {
@@ -208,11 +240,11 @@ func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Obje
}
func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object {
- val, ok := env.Get(node.Value)
- if !ok {
- return newError("identifier not found: " + node.Value)
- }
- return val
+ val, ok := env.Get(node.Value)
+ if !ok {
+ return newError("identifier not found: " + node.Value)
+ }
+ return val
}
func isTruthy(obj object.Object) bool {
@@ -229,12 +261,12 @@ func isTruthy(obj object.Object) bool {
}
func newError(format string, a ...interface{}) *object.Error {
- return &object.Error{Message: fmt.Sprintf(format, a...)}
+ return &object.Error{Message: fmt.Sprintf(format, a...)}
}
func isError(obj object.Object) bool {
- if obj != nil {
- return obj.Type() == object.ERROR_OBJ
- }
- return false
+ if obj != nil {
+ return obj.Type() == object.ERROR_OBJ
+ }
+ return false
}
diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go
index 38fed91..f83cf7f 100644
--- a/evaluator/evaluator_test.go
+++ b/evaluator/evaluator_test.go
@@ -5,7 +5,7 @@ import (
"mana/object"
"mana/parser"
- "testing"
+ "testing"
)
func TestEvalIntegerExpression(t *testing.T) {
@@ -138,36 +138,36 @@ func TestReturnStatements(t *testing.T) {
}
func TestErrorhandling(t *testing.T) {
- tests := [] struct {
- input string
- expectedMessage string
- } {
- {
- "5 + true;",
- "type mismatch: INTEGER + BOOLEAN",
- },
- {
- "5 + true; 5;",
- "type mismatch: INTEGER + BOOLEAN",
- },
- {
- "-true",
- "unknown operator: -BOOLEAN",
- },
- {
- "true + false;",
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- "5; true + false; 5",
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- "if (10 > 1) { true + false; }",
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- `
+ tests := []struct {
+ input string
+ expectedMessage string
+ }{
+ {
+ "5 + true;",
+ "type mismatch: INTEGER + BOOLEAN",
+ },
+ {
+ "5 + true; 5;",
+ "type mismatch: INTEGER + BOOLEAN",
+ },
+ {
+ "-true",
+ "unknown operator: -BOOLEAN",
+ },
+ {
+ "true + false;",
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ "5; true + false; 5",
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ "if (10 > 1) { true + false; }",
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ `
if (10 > 1) {
if (10 > 1) {
return true + false;
@@ -175,50 +175,94 @@ func TestErrorhandling(t *testing.T) {
return 1;
}
`,
- "unknown operator: BOOLEAN + BOOLEAN",
- },
- {
- "foobar",
- "identifier not found: foobar",
- },
- }
-
- for _, tt := range tests {
- evaluated := testEval(tt.input)
-
- errObj, ok := evaluated.(*object.Error)
- if !ok {
- t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated)
- continue
- }
-
- if errObj.Message != tt.expectedMessage {
- t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message)
- }
- }
+ "unknown operator: BOOLEAN + BOOLEAN",
+ },
+ {
+ "foobar",
+ "identifier not found: foobar",
+ },
+ }
+
+ for _, tt := range tests {
+ evaluated := testEval(tt.input)
+
+ errObj, ok := evaluated.(*object.Error)
+ if !ok {
+ t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated)
+ continue
+ }
+
+ if errObj.Message != tt.expectedMessage {
+ t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMessage, errObj.Message)
+ }
+ }
}
func TestLetStatements(t *testing.T) {
- tests := []struct {
- input string
- expected int64
- } {
- {"let a = 5; a;", 5},
- {"let a = 5 * 5; a;", 25},
- {"let a = 5; let b = a; b;", 5},
- {"let a = 5; let b = a; let c = a + b + 5; c;", 15},
- }
-
- for _, tt := range tests {
- testIntegerObject(t, testEval(tt.input), tt.expected)
- }
+ tests := []struct {
+ input string
+ expected int64
+ }{
+ {"let a = 5; a;", 5},
+ {"let a = 5 * 5; a;", 25},
+ {"let a = 5; let b = a; b;", 5},
+ {"let a = 5; let b = a; let c = a + b + 5; c;", 15},
+ }
+
+ for _, tt := range tests {
+ testIntegerObject(t, testEval(tt.input), tt.expected)
+ }
+}
+
+func TestFunctionObject(t *testing.T) {
+ input := "fn(x) { x + 2; };"
+
+ evaluated := testEval(input)
+
+ fn, ok := evaluated.(*object.Function)
+
+ if !ok {
+ t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated)
+ }
+
+ if len(fn.Parameters) != 1 {
+ t.Fatalf("function has wrong parameters. Parameters=%+v", fn.Parameters)
+ }
+
+ if fn.Parameters[0].String() != "x" {
+ t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0])
+ }
+
+ expectedBody := "(x + 2)"
+
+ if fn.Body.String() != expectedBody {
+ t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String())
+ }
+}
+
+func TestFunctionApplication(t *testing.T) {
+ tests := []struct {
+ input string
+ expected int64
+ }{
+ {"let identity = fn(x) { x; }; identity(5);", 5},
+ {"let identity = fn(x) { return x; }; identity(5);", 5},
+ {"let double = fn(x) { x * 2; }; double(5);", 10},
+ {"let add = fn(x, y) { x + y; }; add(5, 5);", 10},
+ {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20},
+ {"fn(x) { x; }(5)", 5},
+ }
+
+ for _, tt := range tests {
+ testIntegerObject(t, testEval(tt.input), tt.expected)
+ }
}
func testEval(input string) object.Object {
l := lexer.New(input)
p := parser.New(l)
program := p.ParseProgram()
- env := object.NewEnvironment()
+ env := object.NewEnvironment()
return Eval(program, env)
}