diff --git a/quartz/ast.qz b/quartz/ast.qz index 78052ff5..cdc51419 100644 --- a/quartz/ast.qz +++ b/quartz/ast.qz @@ -244,6 +244,11 @@ enum Expression { }?, expansion: LExpression?, }, + t_closure_call: struct { + callee_type: Type, + callee: LExpression, + args: vec[LExpression], + }, t_project: struct { expr: LExpression, field: Ident, @@ -324,6 +329,13 @@ enum Expression { expr: LExpression, expr_type: Type, }, + t_closure: struct { + func: Function, + captures: vec[struct { + name: string, + type_: Type, + }], + }, } module Expression { @@ -382,6 +394,10 @@ enum Type { params: vec[Type], result: Type, }, + t_closure: struct { + params: vec[Type], + result: Type, + }, t_variadic_func: struct { params: vec[Type], result: Type, @@ -497,6 +513,20 @@ module Type { return builder.to_string(); } else if self.t_ptr != nil { return "ptr[{}]".format(self.t_ptr!.to_string()); + } else if self.t_closure != nil { + let builder = stringbuilder::new(); + builder.append("closure ("); + for i in 0..self.t_closure!.params.length { + let param = self.t_closure!.params.at(i); + builder.append(param.to_string()); + + if i != self.t_closure!.params.length - 1 { + builder.append(", "); + } + } + builder.append(format(") -> {}", self.t_closure!.result.to_string())); + + return builder.to_string(); } panic("Unknown type: {}", derive::to_string(self)); diff --git a/quartz/compiler.qz b/quartz/compiler.qz index b98c4760..8bc2509e 100644 --- a/quartz/compiler.qz +++ b/quartz/compiler.qz @@ -10,6 +10,7 @@ import quartz::value; import quartz::location; import quartz::preprocessor; import quartz::ir_path::let_call; +import quartz::ir_path::escaping; struct LoadedModule { path: Path, @@ -193,6 +194,7 @@ module Compiler { let term = irgen.run(module_).context("ircodegen").try; term = transform_let_call(term); + term = transform_escaping(term); let envs = environs(); if envs.has("GENERATE_IR") { diff --git a/quartz/generator.qz b/quartz/generator.qz index c6f800b3..c9d8a9f5 100644 --- a/quartz/generator.qz +++ b/quartz/generator.qz @@ -625,6 +625,7 @@ module Generator { }]](), result_type: self.main_signature_result!, body: start_body, + escaping: make[vec[string]](), ..nil, }); diff --git a/quartz/ir.qz b/quartz/ir.qz index 9d249c62..6e31d958 100644 --- a/quartz/ir.qz +++ b/quartz/ir.qz @@ -206,6 +206,7 @@ struct IrFunc { result_type: IrType, body: vec[IrTerm], ffi_export: string?, + escaping: vec[string], } struct IrCall { @@ -329,6 +330,19 @@ module IrType { }, }; } + if t.t_closure != nil { + let params = make[vec[IrType]](); + for param in t.t_closure!.params { + params.push(IrType::new(param)); + } + + return IrType { + t_func: struct { + params: params, + result_type: IrType::new(t.t_closure!.result), + }, + }; + } return panic("unknown type: {} (IrType::new)", t.to_string()); } diff --git a/quartz/ir_code_gen.qz b/quartz/ir_code_gen.qz index 74e8dea8..e0382837 100644 --- a/quartz/ir_code_gen.qz +++ b/quartz/ir_code_gen.qz @@ -205,6 +205,16 @@ module FunctionTable { return (self.functions.length - 1) as i32; } + + fun has(self, key: string): bool { + for i in 0..self.functions.length { + if self.functions.at(i).equal(key) { + return true; + } + } + + return false; + } } struct InterfaceTable { @@ -291,6 +301,15 @@ struct IrCodeGenerator { test_mode: bool, test_functions: vec[string], entrypoint: string, + closures: map[string, struct { + closure: IrTerm, + captures: vec[Type], + }], + captures: vec[struct { + name: string, + type_: Type, + }], + escaping: vec[string], } module IrCodeGenerator { @@ -314,6 +333,15 @@ module IrCodeGenerator { test_mode: test_mode, test_functions: make[vec[string]](), entrypoint: entrypoint, + closures: make[map[string, struct { + closure: IrTerm, + captures: vec[Type], + }]](), + captures: make[vec[struct { + name: string, + type_: Type, + }]](), + escaping: make[vec[string]](), }; } @@ -477,6 +505,7 @@ module IrCodeGenerator { }, }, }), + escaping: make[vec[string]](), ..nil, }, }); @@ -523,6 +552,7 @@ module IrCodeGenerator { t_address: true, }, body: body, + escaping: make[vec[string]](), ..nil, }, }); @@ -598,7 +628,10 @@ module IrCodeGenerator { } elements.push(IrTerm { - t_func: self.function(d.t_func!).try, + t_func: self.function(d.t_func!, make[vec[struct { + name: string, + type_: Type, + }]]()).try, }); } if d.t_let != nil { @@ -628,6 +661,16 @@ module IrCodeGenerator { if d.t_interface != nil { continue; } + if self.closures.list_keys().length > 0 { + for cname in self.closures.list_keys() { + elements.push(self.closures.(cname).closure); + } + + self.closures = make[map[string, struct { + closure: IrTerm, + captures: vec[Type], + }]](); + } } return elements; @@ -881,6 +924,7 @@ module IrCodeGenerator { t_nil: true, }, body: body, + escaping: make[vec[string]](), ..nil, }, }; @@ -997,6 +1041,7 @@ module IrCodeGenerator { result_type: IrType { t_nil: true, }, + escaping: make[vec[string]](), ..nil, }, }); @@ -1146,6 +1191,7 @@ module IrCodeGenerator { result_type: IrType { t_nil: true, }, + escaping: make[vec[string]](), ..nil, }, }); @@ -1164,9 +1210,16 @@ module IrCodeGenerator { }; } - fun function(self, func: Function): IrFunc or error { + fun function(self, func: Function, captures: vec[struct { + name: string, + type_: Type, + }]): IrFunc or error { self.current_function_name = func.name.data; self.function_no_allocation = func.no_allocation; + self.captures = captures; + + let prev_escaping = self.escaping; + self.escaping = make[vec[string]](); let body = make[vec[IrTerm]](); if !self.path_to(func.name.data).starts_with("quartz_core") { @@ -1204,12 +1257,16 @@ module IrCodeGenerator { }); } + let escaping_this = self.escaping; + self.escaping = prev_escaping; + return IrFunc { name: self.path_to(func.name.data), params: params, result_type: IrType::new(func.result_type), body: body, ffi_export: func.ffi_export, + escaping: escaping_this, }; } @@ -1881,6 +1938,25 @@ module IrCodeGenerator { t_ident: expr.t_ident!.resolved_path!.join("_"), }; } + for i in 0..self.captures.length { + let c = self.captures.(i); + if c.name == expr.t_ident!.name { + return IrTerm { + t_load: struct { + type_: IrType::new(c.type_), + address: IrTerm { + t_ident: "env", + }, + offset: IrCodeGenerator::wrap_mult_sizeof( + IrType::new(c.type_), + IrTerm { + t_i32: i, + }, + ), + }, + }; + } + } return IrTerm { t_ident: expr.t_ident!.name, @@ -3022,6 +3098,105 @@ module IrCodeGenerator { }, }; } + } else if expr.t_closure != nil { + expr.t_closure!.func.params.push(struct { + name: "env", + type_: Type { + t_any: true, + }, + }); + + let captures_types = make[vec[Type]](); + for c in expr.t_closure!.captures { + captures_types.push(c.type_); + self.escaping.push(c.name); + } + + self.closures.(expr.t_closure!.func.name.data) = (struct { + closure: IrTerm { + t_func: self.function(expr.t_closure!.func, expr.t_closure!.captures).try, + }, + captures: captures_types, + }); + + let elements = make[vec[IrTerm]](); + let captures_ir = make[vec[IrType]](); + let values = make[vec[struct { + label: string?, + type_: IrType, + term: IrTerm, + }]](); + for c in expr.t_closure!.captures { + captures_ir.push(IrType::new(c.type_)); + values.push(struct { + label: nil, + type_: IrType { + t_address: true, + }, + term: IrTerm { + t_load: struct { + type_: IrType::new(c.type_), + address: IrTerm { + t_ident: c.name + }, + offset: IrTerm { + t_i32: 0, + }, + } + }, + }); + } + + elements.push(IrTerm { + t_let: IrLet { + name: "env", + type_: IrType { + t_address: true, + }, + value: self.generate_array( + "struct", + captures_ir, + values, + ).try, + }, + }); + + let symbol = self.path_to(expr.t_closure!.func.name.data); + let function_id = self.functions.get_or_insert(symbol); + + elements.push(IrTerm { + t_i32: function_id, + }); + + return IrTerm { + t_seq: struct { + terms: elements, + }, + }; + } else if expr.t_closure_call != nil { + let callee_type = expr.t_closure_call!.callee_type; + callee_type.t_closure!.params.push( + Type { + t_any: true, + }, + ); + + let terms = make[vec[IrTerm]](); + for arg in expr.t_closure_call!.args { + terms.push(self.expression(arg).try); + } + + terms.push(IrTerm { + t_ident: "env", + }); + + return IrTerm { + t_dynamic_call: struct { + callee_type: IrType::new(callee_type), + callee_id: self.expression(expr.t_closure_call!.callee).try, + args: terms, + }, + }; } return panic("expression exhausted: {}", expr.to_string()); diff --git a/quartz/ir_path/escaping.qz b/quartz/ir_path/escaping.qz new file mode 100644 index 00000000..69658a6b --- /dev/null +++ b/quartz/ir_path/escaping.qz @@ -0,0 +1,122 @@ +struct EscapingTransformer { + function_body: vec[IrTerm], + escaping: vec[string], +} + +module EscapingTransformer { + fun new(): EscapingTransformer { + return EscapingTransformer{ + function_body: make[vec[IrTerm]](), + escaping: make[vec[string]](), + }; + } + + fun run(self, term: IrTerm): IrTerm { + let result = make[vec[IrTerm]](); + + if term.t_module != nil { + result.extend(self.decls(term.t_module!.elements)); + } + + return IrTerm { + t_module: struct { + elements: result, + }, + }; + } + + fun decls(self, decls: vec[IrTerm]): vec[IrTerm] { + let result = make[vec[IrTerm]](); + + for decl in decls { + if decl.t_func != nil { + result.push(self.function(decl.t_func!)); + } else if decl.t_module != nil { + result.push(IrTerm { + t_module: struct { + elements: self.decls(decl.t_module!.elements), + }, + }); + } else { + result.push(decl); + } + } + + return result; + } + + fun function(self, func: IrFunc): IrTerm { + self.escaping = func.escaping; + self.function_body = make[vec[IrTerm]](); + + self.expressions(func.body); + + return IrTerm { + t_func: IrFunc { + name: func.name, + params: func.params, + result_type: func.result_type, + body: self.function_body, + ffi_export: func.ffi_export, + escaping: make[vec[string]](), + }, + }; + } + + fun expressions(self, terms: vec[IrTerm]) { + for term in terms { + self.expression(term); + } + } + + fun expression(self, term: IrTerm) { + if term.t_let != nil { + let v_name = term.t_let!.name; + for name in self.escaping { + if v_name == name { + let value = term.t_let!.value; + term.t_let!.type_ = IrType { + t_address: true, + }; + term.t_let!.value = IrTerm { + t_call: IrCall { + callee: IrTerm { + t_ident: "quartz_core_alloc_with_rep", + }, + args: make[vec[IrTerm]]( + IrTerm { + t_nil: true, + }, + IrTerm { + t_i32: 8, + }, + ), + }, + }; + self.function_body.push(term); + self.function_body.push(IrTerm { + t_store: struct { + type_: term.t_let!.type_, + address: IrTerm { + t_ident: term.t_let!.name, + }, + offset: IrTerm { + t_i32: 0, + }, + value: value, + }, + }); + + return; + } + } + } + + self.function_body.push(term); + } +} + +fun transform_escaping(term: IrTerm): IrTerm { + let transformer = EscapingTransformer::new(); + return transformer.run(term); +} diff --git a/quartz/ir_path/let_call.qz b/quartz/ir_path/let_call.qz index 3d7a3856..f70b3032 100644 --- a/quartz/ir_path/let_call.qz +++ b/quartz/ir_path/let_call.qz @@ -57,6 +57,7 @@ module LetCallTransformer { result_type: func.result_type, body: self.function_body, ffi_export: func.ffi_export, + escaping: func.escaping, }, }; } diff --git a/quartz/parser.qz b/quartz/parser.qz index 4132314d..cb9043c4 100644 --- a/quartz/parser.qz +++ b/quartz/parser.qz @@ -465,6 +465,48 @@ module Parser { }; } + fun anonymous_function( + self, + ): struct { + data: Function, + location: Location, + } or error { + let start_token = self.expect("fun").try; + self.expect("lparen").try; + let params = self.parameters().try; + self.expect("rparen").try; + + let result_type = Type { + t_nil: true, + }; + if self.peek().try.lexeme.equal("colon") { + self.consume().try; + result_type = self.type_().try.data; + } + + let body = self.block().try; + + return struct { + data: Function { + name: Ident { + data: "", + location: Location::unknown(), + }, + result_type: result_type, + body: body, + params: params.params, + variadic: params.variadic, + no_allocation: false, + is_test: false, + ffi_export: nil, + }, + location: Location { + start: start_token.location.start, + end: body.location.end, + }, + }; + } + fun parameters( self, ): struct { @@ -2067,6 +2109,22 @@ module Parser { location: token.location, }; } + if self.peek().try.lexeme == "fun" { + let f = self.anonymous_function().try; + + return LExpression { + data: Expression { + t_closure: struct { + func: f.data, + captures: make[vec[struct { + name: string, + type_: Type, + }]](), + }, + }, + location: f.location, + }; + } return _ or error::new(format("Unexpected token {} (term)", self.peek().try.lexeme.to_string())); } diff --git a/quartz/typecheck.qz b/quartz/typecheck.qz index 67e26d3d..4da15ceb 100644 --- a/quartz/typecheck.qz +++ b/quartz/typecheck.qz @@ -36,6 +36,7 @@ struct Typechecker { }], }], current_function_name: string?, + captured: vec[string], } module Typechecker { @@ -660,6 +661,7 @@ module Typechecker { }], }]](), current_function_name: nil, + captured: make[vec[string]](), }; } @@ -817,7 +819,7 @@ module Typechecker { }, ); - self.function(d.t_func!).try; + self.function(d.t_func!, make[map[string, LType]]()).try; self.globals.insert( self.path_to(d.t_func!.name.data), @@ -932,9 +934,15 @@ module Typechecker { return _ or error::new(format("unimplemented: decl, {}", d.to_string())); } - fun function(self, f: Function): nil or error { + fun function(self, f: Function, locals_init: map[string, LType]): Type or error { let locals = self.locals; self.locals = make[map[string, LType]](); + for k in locals_init.list_keys() { + self.locals.insert(k, locals_init.(k)); + } + + self.captured = make[vec[string]](); + self.current_function_name = f.name.data?; for i in 0..f.params.length { @@ -962,10 +970,40 @@ module Typechecker { self.result_type = f.result_type?; self.block(f.body).try; + let captured_local = make[map[string, bool]](); + for c in self.captured { + captured_local.insert(c, true); + } + for p in f.params { + captured_local.insert(p.name, false); + } + for l in self.locals.list_keys() { + if !locals_init.has(l) { + captured_local.insert(l, false); + } + } + + self.captured = make[vec[string]](); + for c in captured_local.list_keys() { + if captured_local.(c) { + self.captured.push(c); + } + } + self.locals = locals; self.current_function_name = nil; - return nil; + let ps = make[vec[Type]](); + for param in f.params { + ps.push(param.type_); + } + + return Type { + t_func: struct { + params: ps, + result: f.result_type, + }, + }; } fun block(self, b: LBlock): nil or error { @@ -1303,6 +1341,7 @@ module Typechecker { return _ or error::new("unimplemented: binop, {}".format(expr.to_string())); } else if expr.t_ident != nil { if self.locals.has(expr.t_ident!.name) { + self.captured.push(expr.t_ident!.name); let t = self.locals.at(expr.t_ident!.name); self.set_search_node_type(t.data, lexpr.location); self.set_search_node_definition(self.current_path, t.location, lexpr.location); @@ -1398,6 +1437,38 @@ module Typechecker { index: func_type.params.length, }?; + return func_type.result; + } else if callee_type.t_closure != nil { + let func_type = callee_type.t_closure!; + if func_type.params.length != expr.t_call!.args.length { + return _ or error::new("wrong number of arguments, {} != {}".format( + derive::to_string(func_type.params), + derive::to_string(expr.t_call!.args), + )).context(ErrorSource { + path: self.current_path, + location: expr.t_call!.callee.location, + }); + } + + for i in 0..expr.t_call!.args.length { + let arg = expr.t_call!.args.at(i); + let param_type = func_type.params.at(i); + + let result = self.check_expression(arg, param_type).try; + if result.expr != nil { + expr.t_call!.args.(i) = result.expr!; + } + } + + // NOTE: `expr = ..` won't work, since expr is a local variable defined in this function + lexpr.data = Expression { + t_closure_call: struct { + callee_type: callee_type, + callee: expr.t_call!.callee, + args: expr.t_call!.args, + } + }; + return func_type.result; } else { return _ or error::new("not a function, {}".format(callee_type.to_string())); @@ -1993,6 +2064,35 @@ module Typechecker { return Type { t_ptr: expr_type, }; + } else if expr.t_closure != nil { + if self.current_function_name == nil { + return _ or error::new("closure outside of function, {}".format(expr.to_string())); + } + + expr.t_closure!.func.name.data = "{}__closure_{}".format(self.current_function_name!, lexpr.location.start!.to_string()); + + let prev = self.current_function_name; + let func_type = self.function(expr.t_closure!.func, self.locals).try; + self.current_function_name = prev; + + let captures = make[vec[struct { + name: string, + type_: Type, + }]](); + for c in self.captured { + captures.push(struct { + name: c, + type_: self.locals.(c).data, + }); + } + expr.t_closure!.captures = captures; + + return Type { + t_closure: struct { + params: func_type.t_func!.params, + result: func_type.t_func!.result, + }, + }; } else { return _ or error::new("unimplemented: expression, {}".format(expr.to_string())); } @@ -2107,7 +2207,7 @@ module Typechecker { return make[map[string, Type]](); } - return _ or error::new("unimplemented: resolve_record_type, {}".format(derive::to_string(type_))); + return _ or error::new("unimplemented: resolve_record_type, {} with args {}".format(derive::to_string(type_), derive::to_string(args))); } fun resolve_path( diff --git a/tests/cases/closures/test1.qz b/tests/cases/closures/test1.qz new file mode 100644 index 00000000..2616dbd0 --- /dev/null +++ b/tests/cases/closures/test1.qz @@ -0,0 +1,13 @@ +fun main() { + let f = fun (x: i32): i32 { + return x + 1; + }; + let g = fun (x: i32): i32 { + return x + 2; + }; + + assert_eq(f(100), 101); + assert_eq(g(100), 102); + + println("ok"); +} diff --git a/tests/cases/closures/test1.stdout b/tests/cases/closures/test1.stdout new file mode 100644 index 00000000..9766475a --- /dev/null +++ b/tests/cases/closures/test1.stdout @@ -0,0 +1 @@ +ok diff --git a/tests/cases/closures/test2.qz b/tests/cases/closures/test2.qz new file mode 100644 index 00000000..c9693144 --- /dev/null +++ b/tests/cases/closures/test2.qz @@ -0,0 +1,16 @@ +fun main() { + let a = 10; + let z = 1000; + let f = fun (x: i32): i32 { + let b = a; + return x + b; + }; + let g = fun (x: i32): i32 { + return x + a + z; + }; + + assert_eq(f(100), 110); + assert_eq(g(100), 1110); + + println("ok"); +} diff --git a/tests/cases/closures/test2.stdout b/tests/cases/closures/test2.stdout new file mode 100644 index 00000000..9766475a --- /dev/null +++ b/tests/cases/closures/test2.stdout @@ -0,0 +1 @@ +ok