diff --git a/compiler_dyn.py b/compiler_dyn.py index beb4b48..d2df38f 100644 --- a/compiler_dyn.py +++ b/compiler_dyn.py @@ -79,7 +79,7 @@ def calculate_tag(size, ty, arith=None): tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6, endian='little') if arith: tag[57:62] = bitarray.util.int2ba(arith, length=5, endian='little') - print("tags", bitarray.util.ba2base(2, tag)) + # print("tags", bitarray.util.ba2base(2, tag)) return bitarray.util.ba2int(tag) class Compiler: @@ -569,7 +569,6 @@ def cast_insert_exp(self, e): right = Project(right, IntType()) return Inject(Compare(left, [cmp], [right]), BoolType()) case IfExp(expr_test, expr_body, expr_orelse): - # 所有的这种表达式可以用 children 来做 t = self.cast_insert_exp(expr_test) t = Project(t, BoolType()) b = self.cast_insert_exp(expr_body) @@ -656,8 +655,6 @@ def cast_insert(self, p): def reveal_casts_exp(self, e): match e: - # case Call(Name('input_int'), []): - # return Call(Project(e.func, FunctionType([], AnyType())), []) case Project(e, ftype): match ftype: case BoolType() | IntType(): @@ -731,7 +728,6 @@ def reveal_casts_exp(self, e): right = self.reveal_casts_exp(right) return Compare(left, [cmp], [right]) case IfExp(expr_test, expr_body, expr_orelse): - # 所有的这种表达式可以用 children 来做 t = self.reveal_casts_exp(expr_test) b = self.reveal_casts_exp(expr_body) e = self.reveal_casts_exp(expr_orelse) @@ -748,12 +744,12 @@ def reveal_casts_exp(self, e): params = [(x, AnyType()) for x in params] # breakpoint() body = self.reveal_casts_exp(body) + # breakpoint() return AnnLambda(params, AnyType(), body) case _: raise Exception('interp: reveal_casts_exp ' + repr(e)) def reveal_casts_stmt(self, stmt): - # TODO 每次都要展开 stmt 能不能不展开,直接处理 子children match stmt: case Expr(Call(Name('print'), [arg])): new_arg = self.reveal_casts_exp(arg) @@ -809,6 +805,7 @@ def reveal_casts(self, p): raise Exception('interp: unexpected ' + repr(p)) trace(result) + # type_check_Lany.TypeCheckLany().type_check(p) return result def convert_assignments_exp(self, e): @@ -823,7 +820,6 @@ def convert_assignments_exp(self, e): case TagOf(value): value = self.convert_assignments_exp(value) return TagOf(value) - pass case ValueOf(value, typ): value = self.convert_assignments_exp(value) return ValueOf(value, typ) @@ -832,7 +828,6 @@ def convert_assignments_exp(self, e): if id not in self.box_dict: return e else: - # breakpoint() return Subscript(Name(self.box_dict[id]), Constant(0), Load()) case BinOp(left, op, right): left = self.convert_assignments_exp( left) @@ -1024,8 +1019,6 @@ def convert_to_closures_exp(self, e): # if isinstance(c.my_extra_type, AnyType): # breakpoint() if isinstance(c.my_extra_type, AnyType): - - # breakpoint() ne.my_extra_type = AnyType() else: @@ -1046,12 +1039,13 @@ def convert_to_closures_exp(self, e): case ValueOf(value, typ): value = self.convert_to_closures_exp(value) if isinstance(value, Name) and value.id in self.func_val_real_types: - typ = self.func_val_real_types[value.id] - + if isinstance(self.func_val_real_types[value.id], TupleType): + typ = self.func_val_real_types[value.id] - print("ValueOf ", value, value.my_extra_type) - ne = ValueOf(value, value.my_extra_type) - ne.my_extra_type = value.my_extra_type + # bool cam't change + # print("ValueOf ", value, value.my_extra_type) + ne = ValueOf(value, typ) + ne.my_extra_type = typ return ne case Constant(v): @@ -1067,14 +1061,20 @@ def convert_to_closures_exp(self, e): # return TagOf(closureTy) return e - case Name(id): # breakpoint() e.my_extra_type = self.func_val_real_types[id] - print("id is .... {}".format(id)) if id not in self.box_dict: - # breakpoint() + print("id is .... {} {} {}".format(id, self.top_funs, self.func_val_real_types)) + # if isinstance(id, ) + if id in self.top_funs: + # pass + funref = self.func_map[id] + + ne = Closure(funref.arity, [funref]) + ne.my_extra_type = self.func_val_real_types[id] + return ne return e else: return Subscript(Name(self.box_dict[id]), Constant(0), Load()) @@ -1091,12 +1091,16 @@ def convert_to_closures_exp(self, e): case Compare(left, [cmp], [right]): left = self.convert_to_closures_exp(left) + right = self.convert_to_closures_exp(right) if isinstance(left, TagOf) and left.value.id in self.func_val_real_types: # breakpoint() - right = Constant(tagof(self.func_val_real_types[left.value.id])) - else: - right = self.convert_to_closures_exp(right) - return Compare(left, [cmp], [right]) + if isinstance(self.func_val_real_types[left.value.id], TupleType): + right = Constant(tagof(self.func_val_real_types[left.value.id])) + + # right = self.convert_to_closures_exp(right) + ne = Compare(left, [cmp], [right]) + ne.my_extra_type = BoolType() + return ne case IfExp(expr_test, expr_body, expr_orelse): t = self.convert_to_closures_exp(expr_test) @@ -1198,6 +1202,7 @@ def convert_to_closures_exp(self, e): return c case Uninitialized(ty): + e.my_extra_type = ty return e case _: raise Exception('convert_to_closures_exp: unexpected ' + repr(e)) @@ -1215,7 +1220,7 @@ def convert_to_closures_stmt(self, stmt): # l = self.convert_to_closures_exp(l) v_expr = self.convert_to_closures_exp(value) - # if isinstance(v_expr, Tuple): + # if isinstance(v_expr, Closure): # breakpoint() # Subscript single type don't need to be check if not isinstance(l, Subscript): @@ -1762,7 +1767,7 @@ def rco_exp(self, e: expr, need_atomic: bool) -> Tupling[expr, Temporaries]: # 如果这里不这个name if body_tmps: - body = Begin([ Assign([name], expr)for name,expr in body_tmps], body) + body = Begin([Assign([name], expr) for name,expr in body_tmps], body) if orelse_tmp: orelse_expr = Begin([Assign([name], expr) for name, expr in orelse_tmp], orelse_expr) return_expr = IfExp(test_expr, body, orelse_expr) @@ -1952,7 +1957,7 @@ def explicate_effect(self, e: expr, cont: List[stmt], new_body = self.explicate_stmt(s, new_body, basic_blocks) return new_body case _: - print("......", e) + # print("......", e) return [] + cont def explicate_pred(self, cnd: expr, thn: List[stmt], els: List[stmt], @@ -2108,7 +2113,7 @@ def select_arg(self, e: expr) -> arg: return Immediate(0) case Constant(value): return Immediate(value) - case Uninitialized(ty) if isinstance(ty, IntType) : + case Uninitialized(ty): return Immediate(0) # case FunRef(name, arith): # # breakpoint() @@ -2128,11 +2133,22 @@ def select_compare(self, expr, then_label, else_label) -> List[instr]: y = self.select_arg(y) return [ Instr('cmpq', [y, x]), - # Instr('j{}'.format(op_dict[str(op)]), [then_label]), JumpIf(op_dict[str(op)], then_label), Jump(else_label) # Instr('jmp', [else_label]) ] + case Name(id): + x = self.select_arg(expr) + y = Immediate(1) + # the result == 1 ? + return [ + Instr('cmpq', [y, x]), + JumpIf(op_dict["=="], then_label), + Jump(else_label) + ] + case _: + breakpoint() + raise Exception("no match {} ".format(expr)) def select_stmt(self, s: stmt) -> List[instr]: # YOUR CODE HERE @@ -2192,6 +2208,7 @@ def select_stmt(self, s: stmt) -> List[instr]: case Assign([lhs], ValueOf(e, ty)): left = self.select_arg(lhs) value = self.select_arg(e) + # how should I now this is BoolType if isinstance(ty, IntType) or isinstance(ty, BoolType): result.append(Instr('movq', [value, left])) result.append(Instr('sarq', [Immediate(3), left])) @@ -2236,6 +2253,8 @@ def select_stmt(self, s: stmt) -> List[instr]: else: result.append(Instr('movq', [left_arg, lhs])) result.append(Instr('addq', [right_arg, lhs])) + + case Assign([lhs], BinOp(left, Sub(), right)): # breakpoint() left_arg = self.select_arg(left) @@ -2305,7 +2324,6 @@ def select_stmt(self, s: stmt) -> List[instr]: result.append(Instr('cmpq', [l, r])) result.append(Instr('set{}'.format(op_dict[str(op)]), [ByteReg('bl')])) result.append(Instr('movzbq', [ByteReg('bl'), lhs])) - pass case Assign([lhs], Call(Name('len'), [arg])): arg = self.select_arg(arg) result.append(Instr('movq', [arg, Reg('r11')])) @@ -2348,9 +2366,12 @@ def select_stmt(self, s: stmt) -> List[instr]: # result.append(Instr('retq', [])) case Goto(label): result.append(Jump(label)) + case If(expr, [Goto(then_label)], [Goto(else_label)]): + # expr can be valueof if_ = self.select_compare(expr, then_label, else_label) result.extend(if_) + case Collect(size): # size = self.select_arg(size) result.append(Instr('movq', [Reg('r15'), Reg('rdi')])) @@ -2613,12 +2634,12 @@ def analyze_dataflow(self, G, transfer, bottom, join): worklist = deque(G.vertices()) debug = {} while worklist: - print(worklist) + # print(worklist) node = worklist.pop() inputs = [mapping[v] for v in trans_G.adjacent(node)] input = reduce(join, inputs, bottom) output = transfer(node, input) - print("node", node, "input", input, "output", output) + # print("node", node, "input", input, "output", output) if output != mapping[node]: worklist.extend(G.adjacent(node)) mapping[node] = output @@ -2653,7 +2674,7 @@ def build_interference(self, blocks) -> UndirectedAdjList: # live_before_block[label] = tmp[ss[0]] # live_after.update(tmp) - print("live_after ", self.live_after) + # print("live_after ", self.live_after) for label, ss in blocks.items(): for s in ss: match (s): @@ -2719,7 +2740,7 @@ def less(u, v): # print("handing", u) adj_colors = {color_map[v] for v in interference_graph.adjacent(u) if v in color_map} - print(u, adj_colors) + # print(u, adj_colors) if left_color := set(valid_colors) - adj_colors: color = min(left_color) if u not in color_map: @@ -2774,8 +2795,7 @@ def allocate_registers(self, p: X86Program) -> X86Program: new_blocks = {} color_map = self.color_graph(blocks) - - print("color_map", color_map) + # print("color_map", color_map) so_far_rbp = 0 so_far_r15 = 0 cdef.rbp_spill = set() @@ -2802,10 +2822,10 @@ def allocate_registers(self, p: X86Program) -> X86Program: cdef.S = len(cdef.rbp_spill) if cdef.r15_spill.intersection(cdef.rbp_spill): - print("r15 and rbp have somecolor", ) + # print("r15 and rbp have somecolor", ) sys.exit(-1) - print("real_color_map", cdef.real_color_map) + # print("real_color_map", cdef.real_color_map) for label, ss in blocks.items(): ss = blocks[label] diff --git a/compiler_lambda.py b/compiler_lambda.py index c95b0a2..9255b73 100644 --- a/compiler_lambda.py +++ b/compiler_lambda.py @@ -56,7 +56,7 @@ def calculate_tag(size, ty, arith=None): tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6, endian='little') if arith: tag[58:63] = bitarray.util.int2ba(arith, length=5, endian='little') - print("tags", bitarray.util.ba2base(2, tag)) + # print("tags", bitarray.util.ba2base(2, tag)) return bitarray.util.ba2int(tag) class Compiler: @@ -644,8 +644,7 @@ def convert_to_closures_exp(self, e): # pass funref = self.func_map[id] return Closure(funref.arity, [funref]) - # ... - #breakpoint() + return e else: # breakpoint() @@ -1424,7 +1423,7 @@ def explicate_effect(self, e: expr, cont: List[stmt], new_body = self.explicate_stmt(s, new_body, basic_blocks) return new_body case _: - print("......", e) + return [] + cont def explicate_pred(self, cnd: expr, thn: List[stmt], els: List[stmt], @@ -1603,6 +1602,8 @@ def select_compare(self, expr, then_label, else_label) -> List[instr]: Jump(else_label) # Instr('jmp', [else_label]) ] + case _: + raise Exception("no match {} ".format(expr)) def select_stmt(self, s: stmt) -> List[instr]: # YOUR CODE HERE @@ -2006,12 +2007,12 @@ def analyze_dataflow(self, G, transfer, bottom, join): worklist = deque(G.vertices()) debug = {} while worklist: - print(worklist) + # print(worklist) node = worklist.pop() inputs = [mapping[v] for v in trans_G.adjacent(node)] input = reduce(join, inputs, bottom) output = transfer(node, input) - print("node", node, "input", input, "output", output) + # print("node", node, "input", input, "output", output) if output != mapping[node]: worklist.extend(G.adjacent(node)) mapping[node] = output @@ -2046,7 +2047,7 @@ def build_interference(self, blocks) -> UndirectedAdjList: # live_before_block[label] = tmp[ss[0]] # live_after.update(tmp) - print("live_after ", self.live_after) + # print("live_after ", self.live_after) for label, ss in blocks.items(): for s in ss: match (s): @@ -2112,7 +2113,7 @@ def less(u, v): # print("handing", u) adj_colors = {color_map[v] for v in interference_graph.adjacent(u) if v in color_map} - print(u, adj_colors) + # print(u, adj_colors) if left_color := set(valid_colors) - adj_colors: color = min(left_color) if u not in color_map: @@ -2167,8 +2168,7 @@ def allocate_registers(self, p: X86Program) -> X86Program: new_blocks = {} color_map = self.color_graph(blocks) - - print("color_map", color_map) + # print("color_map", color_map) so_far_rbp = 0 so_far_r15 = 0 cdef.rbp_spill = set() @@ -2198,7 +2198,7 @@ def allocate_registers(self, p: X86Program) -> X86Program: print("r15 and rbp have somecolor", ) sys.exit(-1) - print("real_color_map", cdef.real_color_map) + # print("real_color_map", cdef.real_color_map) for label, ss in blocks.items(): ss = blocks[label] diff --git a/compiler_tuple.py b/compiler_tuple.py index 490dd86..e41765e 100644 --- a/compiler_tuple.py +++ b/compiler_tuple.py @@ -426,7 +426,7 @@ def explicate_effect(self, e: expr, cont: List[stmt], new_body = self.explicate_stmt(s, new_body, basic_blocks) return new_body case _: - print("......", e) + # print("......", e) return [] + cont def explicate_pred(self, cnd: expr, thn: List[stmt], els: List[stmt], diff --git a/interp_x86/eval_x86.py b/interp_x86/eval_x86.py index 732bb69..92cba34 100644 --- a/interp_x86/eval_x86.py +++ b/interp_x86/eval_x86.py @@ -238,8 +238,8 @@ def store_arg(self, a, v): def eval_instrs(self, instrs, blocks, output): for instr in instrs: - self.log(f'Evaluating instruction: {instr.pretty()}') - trace(f'Evaluating instruction: {instr.pretty()}') + # self.log(f'Evaluating instruction: {instr.pretty()}') + # trace(f'Evaluating instruction: {instr.pretty()}') if instr.data == 'pushq': a = instr.children[0] self.registers['rsp'] = self.registers['rsp'] - 8 @@ -254,9 +254,9 @@ def eval_instrs(self, instrs, blocks, output): elif instr.data == 'movq': a1, a2 = instr.children - trace("trace {} {} {}".format(instr, a1, a2)) + # trace("trace {} {} {}".format(instr, a1, a2)) v = self.eval_arg(a1) - trace(" movq {} {} {} {}".format(instr, a1, a2, v)) + # trace(" movq {} {} {} {}".format(instr, a1, a2, v)) self.store_arg(a2, v) elif instr.data == 'movzbq': @@ -293,7 +293,7 @@ def eval_instrs(self, instrs, blocks, output): v2 = self.eval_arg(a2) # return val>>n if val >= 0 else (val+1<<64)>>n # while TODO python tag just using it is > 9 - trace(f"...... {v2=} {v1=} {a1=} {a2=}") + # trace(f"...... {v2=} {v1=} {a1=} {a2=}") self.store_arg(a2, v2 >> v1) elif instr.data == 'salq': a1, a2 = instr.children @@ -301,7 +301,7 @@ def eval_instrs(self, instrs, blocks, output): v2 = self.eval_arg(a2) # return val>>n if val >= 0 else (val+1<<64)>>n # while TODO python tag just using it is > 9 - trace(f"...... {v2=} {v1=} {a1=} {a2=}") + # trace(f"...... {v2=} {v1=} {a1=} {a2=}") self.store_arg(a2, v2 << v1) elif instr.data == 'andq': a1, a2 = instr.children @@ -436,7 +436,7 @@ def eval_instrs(self, instrs, blocks, output): a1, a2 = instr.children v1 = self.eval_arg(a1) v2 = self.eval_arg(a2) - trace(f"cmq {v1=} {v2=} {a1=} {a2=}") + # trace(f"cmq {v1=} {v2=} {a1=} {a2=}") if v1 == v2: self.registers['EFLAGS'] = 'e' elif v2 < v1: diff --git a/run-all.sh b/run-all.sh index 061d93e..1ca5c72 100644 --- a/run-all.sh +++ b/run-all.sh @@ -1,5 +1,12 @@ #!/usr/bin/env bash +python run-tests-dyn.py tests/dyn/add.py + +python run-tests-dyn.py tests/dyn/add1.py +python run-tests-dyn.py tests/dyn/add2.py +python run-tests-dyn.py tests/dyn/add4.py + python run-tests-lambda.py tests/lambda/add.py python run-tests-lambda.py tests/lambda/add1.py -python run-tests-lambda.py tests/lambda/add2.py \ No newline at end of file +python run-tests-lambda.py tests/lambda/add2.py +python run-tests-lambda.py tests/lambda/add4.py \ No newline at end of file diff --git a/tests/dyn/add1.py b/tests/dyn/add1.py index 3bd9a88..4cf6026 100644 --- a/tests/dyn/add1.py +++ b/tests/dyn/add1.py @@ -1,7 +1,6 @@ -def f(x:int): - y = 4 - f = lambda z,a,b,c,d,e,f,g : x + y + z + a + b + c + d + e + f + g +def t(x): + y = 10 + f = lambda z: x + y + z return f -g = f(5) -h = f(3) -print(g(11, 0,0,0,0,0,0,0) + h(15, 0,0,0,0,0,0,0)) \ No newline at end of file +g = t(10) +print(g(22)) \ No newline at end of file diff --git a/tests/dyn/add2.py b/tests/dyn/add2.py index e2b7fec..1d4ec1c 100644 --- a/tests/dyn/add2.py +++ b/tests/dyn/add2.py @@ -1,7 +1,7 @@ -def add1(x:int) -> int: +def add1(x): return x + 1 y = input_int() -g : Callable[[int], int] = lambda x : x - y +g = lambda x: x - 10 f = add1 if input_int() == 0 else g # add1 must be translate to closure print(f(41)) \ No newline at end of file diff --git a/tests/lambda/add1.py b/tests/lambda/add1.py index ade56d2..f8db0e2 100644 --- a/tests/lambda/add1.py +++ b/tests/lambda/add1.py @@ -1,7 +1,7 @@ -def f(x:int)-> Callable[[int,int,int,int,int,int,int,int],int] : +def t(x:int)-> Callable[[int,int,int,int,int,int,int,int],int] : y = 4 f: Callable[[int,int,int,int,int,int,int,int], int] = lambda z,a,b,c,d,e,f,g : x + y + z + a + b + c + d + e + f + g return f -g = f(5) -h = f(3) +g = t(5) +h = t(3) print(g(11, 0,0,0,0,0,0,0,0) + h(15, 0,0,0,0,0,0,0,0)) \ No newline at end of file diff --git a/type_check_Cfun.py b/type_check_Cfun.py index fa73978..ef572a6 100644 --- a/type_check_Cfun.py +++ b/type_check_Cfun.py @@ -82,7 +82,7 @@ def type_check_def(self, d, env): for l in reversed(sort_cfg): ss = blocks[l] # should handing by top_logic....... - trace("handing....... {}".format(l)) + # trace("handing....... {}".format(l)) self.type_check_stmts(ss, new_env) # trace('type_check_Cfun iterating ' + repr(new_env)) if new_env == old_env: diff --git a/type_check_Lany.py b/type_check_Lany.py index 63a508a..de87ffa 100644 --- a/type_check_Lany.py +++ b/type_check_Lany.py @@ -6,6 +6,29 @@ class TypeCheckLany(TypeCheckLlambda): + def check_type_equal(self, t1, t2, e): + if t1 == Bottom() or t2 == Bottom(): + return + match t1: + case AnyType(): + return + case FunctionType(ps1, rt1): + match t2: + case FunctionType(ps2, rt2): + for (p1, p2) in zip(ps1, ps2): + self.check_type_equal(p1, p2, e) + self.check_type_equal(rt1, rt2, e) + case _: + if isinstance(t2, AnyType): + # t2 was AnyType we can match any + return + raise Exception('error: ' + repr(t1) + ' != ' + repr(t2) \ + + ' in ' + repr(e)) + case _: + if isinstance(t2, AnyType): + return + super().check_type_equal(t1, t2, e) + def type_check_exp(self, e, env): match e: case Inject(value, typ): @@ -64,6 +87,7 @@ def type_check_stmts(self, ss, env): self.check_type_equal(env[v.id], t, value) else: env[v.id] = t + v.has_type = env[v.id] return self.type_check_stmts(ss[1:], env) case Expr(Call(Name('print'), [arg])): @@ -72,19 +96,31 @@ def type_check_stmts(self, ss, env): return self.type_check_stmts(ss[1:], env) case _: return super().type_check_stmts(ss, env) - # def check_exp(self, e, ty, env): - # match e: - # case Call(Name('make_any'), [value, tag]): - # pass - # case Inject(value, typ): - # pass - # case Project(value, typ): - # pass - # case Call(Name('any_tuple_load'), [tup, index]): - # pass - # case _: - # super().check_exp(e, ty, env) - # return + + def check_exp(self, e, ty, env): + match e: + case Lambda(params, body): + e.has_type = ty + if isinstance(params, ast.arguments): + new_params = [a.arg for a in params.args] + e.args = new_params + else: + new_params = params + match ty: + case FunctionType(params_t, return_t): + new_env = {x: t for (x, t) in env.items()} + for (p, t) in zip(new_params, params_t): + new_env[p] = t + self.check_exp(body, return_t, new_env) + case Bottom(): + pass + case _: + raise Exception('lambda does not have type ' + str(ty)) + case _: + + t = self.type_check_exp(e, env) + trace("^^^^ {} {} {}".format(e, ty, t)) + self.check_type_equal(t, ty, e) # t = self.type_check_exp(e, env) # self.check_type_equal(t, ty, e) @@ -108,3 +144,75 @@ def type_check(self, p): self.check_stmts(body, AnyType(), env) case _: raise Exception('type_check: unexpected ' + repr(p)) + + + def check_stmts(self, ss, return_ty, env): + if len(ss) == 0: + return + #trace('*** check_stmts ' + repr(ss[0]) + '\n') + match ss[0]: + case FunctionDef(name, params, body, dl, returns, comment): + #trace('*** tc_check ' + name) + new_env = {x: t for (x,t) in env.items()} + if isinstance(params, ast.arguments): + new_params = [(p.arg, self.parse_type_annot(p.annotation)) for p in params.args] + ss[0].args = new_params + new_returns = self.parse_type_annot(returns) + ss[0].returns = new_returns + else: + new_params = params + new_returns = returns + for (x,t) in new_params: + new_env[x] = t + rt = self.check_stmts(body, new_returns, new_env) + self.check_stmts(ss[1:], return_ty, env) + case Return(value): + #trace('** tc_check return ' + repr(value)) + self.check_exp(value, return_ty, env) + case Assign([v], value) if isinstance(v, Name): + if v.id in env: + self.check_exp(value, env[v.id], env) + else: + # + t = self.type_check_exp(value, env) + # breakpoint() + trace("ggggg {} {} {} ".format(v.id, value, type(value), t)) + # ggggg g.2 inject((lambda x.3: inject((project(x.3, int) - project(y.1, int)), int)), Callable[[any], any]) any + # so g.2 was AnyType + # but checkfunc has type FunctionType + # here we need real typ? + env[v.id] = t + # breakpoint() + if isinstance(value, Inject) and isinstance(value.typ, FunctionType): + print("ggggg value.typ ", value.typ) + env[v.id] = value.typ # + elif isinstance(value, Call): + if isinstance(value.args[0], AnnLambda): + env[v.id] = value.args[0].convert_to_typ()# pass + v.has_type = env[v.id] + trace("xxxxx {} {}".format(return_ty, env)) + trace(env) + self.check_stmts(ss[1:], return_ty, env) + case Assign([Subscript(tup, Constant(index), Store())], value): + tup_t = self.type_check_exp(tup, env) + match tup_t: + case TupleType(ts): + self.check_exp(value, ts[index], env) + case Bottom(): + pass + case _: + raise Exception('check_stmts: expected a tuple, not ' \ + + repr(tup_t)) + self.check_stmts(ss[1:], return_ty, env) + case AnnAssign(v, ty, value, simple) if isinstance(v, Name): + ty_annot = self.parse_type_annot(ty) + ss[0].annotation = ty_annot + if v.id in env: + self.check_type_equal(env[v.id], ty_annot) + else: + env[v.id] = ty_annot + v.has_type = env[v.id] + self.check_exp(value, ty_annot, env) + self.check_stmts(ss[1:], return_ty, env) + case _: + self.type_check_stmts(ss, env) diff --git a/type_check_Lfun.py b/type_check_Lfun.py index 66e3985..809a71d 100644 --- a/type_check_Lfun.py +++ b/type_check_Lfun.py @@ -10,6 +10,8 @@ def check_type_equal(self, t1, t2, e): if t1 == Bottom() or t2 == Bottom(): return match t1: + case AnyType(): + return case FunctionType(ps1, rt1): match t2: case FunctionType(ps2, rt2): @@ -17,6 +19,9 @@ def check_type_equal(self, t1, t2, e): self.check_type_equal(p1, p2, e) self.check_type_equal(rt1, rt2, e) case _: + if isinstance(t2, AnyType): + # t2 was AnyType we can match any + return raise Exception('error: ' + repr(t1) + ' != ' + repr(t2) \ + ' in ' + repr(e)) case _: diff --git a/type_check_Llambda.py b/type_check_Llambda.py index b595fff..36c25ec 100644 --- a/type_check_Llambda.py +++ b/type_check_Llambda.py @@ -10,7 +10,7 @@ class TypeCheckLlambda(TypeCheckLfun): def type_check_exp(self, e, env): - trace("^^^^ {} {}".format(e, repr(e))) + match e: case Name(id): e.has_type = env[id] @@ -75,7 +75,9 @@ def check_exp(self, e, ty, env): case _: raise Exception('lambda does not have type ' + str(ty)) case _: + t = self.type_check_exp(e, env) + trace("^^^^ {} {} {}".format(e, ty, t)) self.check_type_equal(t, ty, e) def check_stmts(self, ss, return_ty, env): @@ -105,9 +107,17 @@ def check_stmts(self, ss, return_ty, env): if v.id in env: self.check_exp(value, env[v.id], env) else: - env[v.id] = self.type_check_exp(value, env) + t = self.type_check_exp(value, env) + # breakpoint() + trace("ggggg {} {} {} ".format(v.id, value, type(value), t)) + # ggggg g.2 inject((lambda x.3: inject((project(x.3, int) - project(y.1, int)), int)), Callable[[any], any]) any + # so g.2 was AnyType + # but checkfunc has type FunctionType + + env[v.id] = t v.has_type = env[v.id] trace("xxxxx {}".format(return_ty)) + trace(env) self.check_stmts(ss[1:], return_ty, env) case Assign([Subscript(tup, Constant(index), Store())], value): tup_t = self.type_check_exp(tup, env) diff --git a/utils.py b/utils.py index 0678ed4..0d5cc7b 100644 --- a/utils.py +++ b/utils.py @@ -404,6 +404,9 @@ def __str__(self): ', '.join([x + ':' + str(t) for (x,t) in self.params]) + '] -> ' \ + str(self.returns) + ': ' + str(self.body) + def convert_to_typ(self): + return FunctionType([i[1] for i in self.params], self.returns) + # An uninitialized value of a given type. # Needed for boxing local variables. @dataclass