From c7142bbde4a55e7e690fdf28a8f95e88727fb469 Mon Sep 17 00:00:00 2001 From: jiamo Date: Sun, 13 Mar 2022 16:38:30 +0800 Subject: [PATCH] add allocal r8 --- compiler.py | 136 ++++++++++++++++++++++++++++---------- interp_x86/convert_x86.py | 7 +- interp_x86/eval_x86.py | 32 ++++++++- tests/tuple/add1.py | 2 +- type_check_Ctup.py | 2 +- x86_ast.py | 2 +- 6 files changed, 139 insertions(+), 42 deletions(-) diff --git a/compiler.py b/compiler.py index bce6b7a..490dd86 100644 --- a/compiler.py +++ b/compiler.py @@ -18,6 +18,8 @@ from typing import Tuple as Tupling import type_check_Ltup from interp_x86.eval_x86 import interp_x86 +import type_check_Ctup + Binding = Tupling[Name, expr] Temporaries = List[Binding] @@ -37,16 +39,19 @@ def calculate_tag(size, ty): - tag = bitarray.bitarray(64) + tag = bitarray.bitarray(64, endian="little") tag.setall(0) p_mask = 7 + tag[0] = 1 for i, type in enumerate(ty.types): - if type == TupleType: + # breakpoint() + if isinstance(type, TupleType): tag[p_mask + i] = 1 else: tag[p_mask + i] = 0 - tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6) - return bitarray.util.ba2hex(tag) + tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6, endian='little') + print("tags", bitarray.util.ba2base(2, tag)) + return bitarray.util.ba2int(tag) class Compiler: @@ -158,13 +163,13 @@ def expose_allocation_exp(self, exp) -> Tupling[expr, List[stmt]]: stmts.append(new_stmt) tmp_exprs.append(var) # breakpoint() - n = 8 + 8 * len(exprs) + n = len(exprs) stmts.append( - If(Compare(BinOp(GlobalValue("free_pr"), Add(), Constant(n)), [Lt()], [GlobalValue("fromspace_end")]), + If(Compare(BinOp(GlobalValue("free_ptr"), Add(), Constant(8 * (n+1))), [Lt()], [GlobalValue("fromspace_end")]), [Expr(Constant(0))], - [Collect(n)]) + [Collect(8 * (n+1))]) ) - tmp = generate_name(tmp) + tmp = generate_name("alloc") var = Name(tmp) stmts.append(Assign([var], Allocate(n, exp.has_type))) # may exp.has_type.types for i in range(len(exprs)): @@ -532,14 +537,16 @@ def select_arg(self, e: expr) -> arg: match e: case Name(name): return Variable(name) + case GlobalValue(name): + return x86_ast.Global(name) case Constant(True): return Immediate(1) case Constant(False): return Immediate(0) case Constant(value): return Immediate(value) - case x if isinstance(x, int): - return Immediate(x) + # case x if isinstance(x, int): + # return Immediate(x) case _: raise Exception('error in select_arg, unexpected ' + repr(e)) @@ -645,10 +652,10 @@ def select_stmt(self, s: stmt) -> List[instr]: case Assign([lhs], Allocate(size, ty)): lhs = self.select_arg(lhs) - size = self.select_arg(size) - tag = calculate_tag(size.value, ty) + # size = self.select_arg(size) + tag = calculate_tag(size, ty) result.append(Instr("movq", [x86_ast.Global("free_ptr"), Reg('r11')])) - result.append(Instr("movq", [8 * (size.value + 1) , Reg('r11')])) + result.append(Instr("addq", [Immediate(8 * (size + 1)), x86_ast.Global("free_ptr")])) result.append(Instr("movq", [Immediate(tag), Deref('r11', 0)])) result.append(Instr('movq', [Reg('r11'), lhs])) case Assign([lhs], value): @@ -664,10 +671,10 @@ def select_stmt(self, s: stmt) -> List[instr]: if_ = self.select_compare(expr, then_label, else_label) result.extend(if_) case Collect(size): - size = self.select_arg(size) - result.append(Instr('movq', [Reg('r15'), Reg('%rdi')])) - result.append(Instr('movq', [size, Reg('rsi')])) - result.append(Callq("collect")) + # size = self.select_arg(size) + result.append(Instr('movq', [Reg('r15'), Reg('rdi')])) + result.append(Instr('movq', [Immediate(size), Reg('rsi')])) + result.append(Callq(label_name("collect"), 2)) case _: raise Exception('error in select_stmt, unexpected ' + repr(s)) return result @@ -675,8 +682,8 @@ def select_stmt(self, s: stmt) -> List[instr]: def select_instructions(self, p: Module) -> X86Program: # YOUR CODE HERE - - + type_check_Ctup.TypeCheckCtup().type_check(p) + # breakpoint() blocks = {} match p: case CProgram(basic_blocks): @@ -690,6 +697,7 @@ def select_instructions(self, p: Module) -> X86Program: x86 = X86Program(blocks) + x86.var_types = p.var_types # breakpoint() # print("......") # interp_x86(x86) @@ -768,10 +776,21 @@ def read_var(self, i: instr) -> Set[location]: case _: return set() + def free_var(self, t): + match(t): + case Variable(i): + return t + case Reg(r): + return t + case Deref(r, offset): + return Reg(r) + case _: + return set() + def write_var(self, i) -> Set[location]: match (i): case Instr("movq", [s, t]): - return {i.args[1]} + return set([self.free_var(t)]) case Callq(func, num_args): return set(callee_saved_regs) case _: @@ -844,7 +863,7 @@ def transfer(self,label, live_after_block): before_instr_set = tmp.union(live_after_block) case _: before_instr_set = (self.live_after[s] - self.write_var(s)).union(self.read_var(s)) - # print("s" , s, pre_instr_set) + # print("s" , s, before_instr_set) self.live_before[s] = before_instr_set live_before_block = live_before_block.union(before_instr_set) pre_instr = s @@ -908,6 +927,7 @@ def build_interference(self, blocks) -> UndirectedAdjList: match (s): case Instr("movq", [si, d]): # si = s.args[0] + d = self.free_var(d) for v in self.live_after[s]: if v != d and v != si: interference_graph.add_edge(d, v) @@ -929,7 +949,11 @@ def color_graph(self, blocks, k=100) -> Dict[location, int]: # first make it k big enough valid_colors = list(range(0, k)) # number of colar # Rdi 的保存问题 - color_map = {Reg('rax'): -1, Reg('rsp'): -2, Reg('rdi'): -3, ByteReg('bl'): -4} + color_map = { + Reg('rax'): -1, Reg('rsp'): -2, Reg('rdi'): -3, ByteReg('bl'): -4, Reg('r11'): -5, + Reg('r15'): -6, Reg('rsi'): -7 # rsi 其实可以用来做其他事情。 但如果分配 rsi 9 rsi 的 color + # 算法 color 9 和 可以分配出去reg 的color 0 1 3 矛盾 + } # color_map = {} saturated = {} @@ -951,12 +975,15 @@ def less(u, v): for v in vsets: saturated[v] = set() for v in vsets: + queue.push(v) + while not queue.empty(): u = queue.pop() - print("handing", u) + # print("handing", u) + adj_colors = {color_map[v] for v in interference_graph.adjacent(u) if v in color_map} print(u, adj_colors) if left_color := set(valid_colors) - adj_colors: @@ -973,12 +1000,14 @@ def less(u, v): def allocate_registers(self, p: X86Program) -> X86Program: # YOUR CODE HERE + # breakpoint() # ? RDI - self.color_regs = [Reg("rbx"), Reg("rcx"), Reg("rdx"), Reg("rsi"), Reg("rdi"), Reg("r8"), Reg("r9"), Reg("r10")] + self.color_regs = [Reg("rbx"), Reg("rcx"), Reg("rdx"), Reg("rsi"), + Reg("rdi"), Reg("r8"), Reg("r9"), Reg("r10")] self.color_regs = [Reg("rbx"), Reg("rcx")] self.color_regs = [Reg("rbx")] # rcx as tmp - self.color_regs = [Reg("rbx"), Reg("rcx")] + self.color_regs = [Reg("rbx"), Reg("rcx"), Reg("r8")] self.alloc_callee_saved_regs = list(set(self.color_regs).intersection(callee_saved_regs)) self.C = len(self.alloc_callee_saved_regs) @@ -988,6 +1017,9 @@ def allocate_registers(self, p: X86Program) -> X86Program: color_regs_map[-2] = Reg('rsp') color_regs_map[-3] = Reg('rdi') color_regs_map[-4] = ByteReg("bl") + color_regs_map[-5] = Reg('r11') + color_regs_map[-6] = Reg('r15') + color_regs_map[-7] = Reg('rsi') self.real_color_map = {} match(p): @@ -996,19 +1028,37 @@ def allocate_registers(self, p: X86Program) -> X86Program: # breakpoint() new_blocks = {} color_map = self.color_graph(blocks) - self.S = len(set(color_map.values())) - len(self.color_regs) - - print("color_map", color_map) - for color in sorted(set(color_map.values())): + print("color_map", color_map) + so_far_rbp = 0 + so_far_r15 = 0 + self.rbp_spill = set() + self.r15_spill = set() + for var, color in sorted(color_map.items(), key=lambda i: i[1]): + # 相同的 color 但 type 不一样 if color in self.real_color_map: continue if color in color_regs_map: self.real_color_map[color] = color_regs_map[color] else: # Yes - self.real_color_map[color] = Deref("rbp", -8*(color-len(self.color_regs) + self.C + 1)) + # breakpoint() + if isinstance(p.var_types.get(str(var)), TupleType): + # breakpoint() + # r15 is up r15 was saveid in heap + self.real_color_map[color] = Deref("r15", 8*(so_far_r15)) + so_far_r15 += 1 + self.r15_spill.add(color) + else: + self.real_color_map[color] = Deref("rbp", -8*(so_far_rbp+ self.C + 1)) + so_far_rbp += 1 + self.rbp_spill.add(color) + + self.S = len(self.rbp_spill) + if self.r15_spill.intersection(self.rbp_spill): + print("r15 and rbp have somecolor", ) + sys.exit(-1) print("real_color_map", self.real_color_map) @@ -1049,10 +1099,15 @@ def patch_instr(self, i: instr) -> List[instr]: match(i): case Instr(instr, [x, y]) if x == y: return [] - case Instr(instr, [Deref("rbp", x), Deref("rbp", y)]): + case Instr(instr, [Deref(label_x, x), Deref(label_y, y)]): return [ - Instr("movq", [Deref("rbp", x), Reg("rax")]), - Instr("movq", [Reg("rax"), Deref("rbp", y)]) + Instr("movq", [Deref(label_x, x), Reg("rax")]), + Instr("movq", [Reg("rax"), Deref(label_y, y)]) + ] + case Instr(instr, [x86_ast.Global(x), Deref(label_y, y)]): + return [ + Instr("movq", [x86_ast.Global(x), Reg("rax")]), + Instr("movq", [Reg("rax"), Deref(label_y, y)]) ] case Instr('cmpq', [x, Immediate(v)]): return [ @@ -1114,12 +1169,25 @@ def prelude_and_conclusion(self, p: X86Program) -> X86Program: for reg in extra_saved_regs: main.append(Instr("pushq", [reg])) main.extend([ Instr("subq", [Immediate(self.rsp_sub), Reg("rsp")])]) + main.extend([ + Instr("movq", [Immediate(65536), Reg("rdi")]), + Instr("movq", [Immediate(65536), Reg("rsi")]), + Callq(label_name("initialize"), 2), + Instr("movq", [x86_ast.Global("rootstack_begin"), Reg("r15")]), + ]) + len_spill_r15 = len(self.r15_spill) + for i in range(len_spill_r15): + main.append( Instr("movq", [Immediate(0), Deref("r15", 8 * i)])) + main.append(Instr('addq', [Immediate(8 * len_spill_r15), Reg('r15')])) main.append(Jump(label_name("start"))) blocks[label_name("main")] = main # for label , body in blocks.items(): # pass - conclusion.extend([ Instr("addq", [Immediate(self.rsp_sub), Reg("rsp")]),]) + conclusion.extend([ + Instr("subq", [Immediate(8 * len_spill_r15), Reg("r15")]), + Instr("addq", [Immediate(self.rsp_sub), Reg("rsp")]), + ]) for reg in extra_saved_regs[::-1]: conclusion.append(Instr("popq", [reg])) conclusion.append(Instr("popq", [Reg('rbp')])) # seem no need pop diff --git a/interp_x86/convert_x86.py b/interp_x86/convert_x86.py index 8de5809..fde9565 100644 --- a/interp_x86/convert_x86.py +++ b/interp_x86/convert_x86.py @@ -4,13 +4,13 @@ from lark import Tree from ast import Name, Constant from x86_ast import * -from utils import label_name, GlobalValue +from utils import label_name, GlobalValue, trace def convert_int(value): if value >= 0: return Tree('int_a', [Tree('int_a', [value])]) else: - return Tree('neg_a',[Tree('int_a', [- value])]) + return Tree('neg_a', [Tree('int_a', [- value])]) def convert_arg(arg): match arg: @@ -24,7 +24,7 @@ def convert_arg(arg): return Tree('mem_a', [convert_int(offset), reg]) case ByteReg(id): return Tree('reg_a', [id]) - case GlobalValue(id): + case Global(id): return Tree('global_val_a', [id, 'rip']) case _: raise Exception('convert_arg: unhandled ' + repr(arg)) @@ -32,6 +32,7 @@ def convert_arg(arg): def convert_instr(instr): match instr: case Instr(instr, args): + # trace("convert.... {} {}".format( instr, args)) return Tree(instr, [convert_arg(arg) for arg in args]) case Callq(func, args): return Tree('callq', [func]) diff --git a/interp_x86/eval_x86.py b/interp_x86/eval_x86.py index 59c53e4..6db2c96 100644 --- a/interp_x86/eval_x86.py +++ b/interp_x86/eval_x86.py @@ -4,6 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from utils import * +from lark import Tree from .parser_x86 import x86_parser, x86_parser_instrs from .convert_x86 import convert_program @@ -148,7 +149,12 @@ def print_mem(self, mem): def eval_imm(self, e): if e.data == 'int_a': + # breakpoint() + # trace(e.children[0]) + if isinstance(e.children[0], Tree): + return self.eval_imm(e.children[0]) return int(e.children[0]) + # return self.eval_imm(e.children[0]) elif e.data == 'neg_a': return -self.eval_imm(e.children[0]) else: @@ -170,7 +176,28 @@ def eval_arg(self, a): elif a.data == 'global_val_a': loc, reg = a.children assert str(reg) == 'rip', a - return self.global_vals[str(loc)] + if str(loc) in self.global_vals: + return self.global_vals[str(loc)] # select instr 这里还没有 init + else: + rootstack_size = 65535 + heap_size = 65535 + + rs_begin = 2000 + rs_end = rs_begin + rootstack_size + + fromspace_begin = 100000 + fromspace_end = fromspace_begin + heap_size + + self.global_vals = {**self.global_vals, + 'rootstack_begin': rs_begin, + 'rootstack_end': rs_end, + 'free_ptr': fromspace_begin, + 'fromspace_begin': fromspace_begin, + 'fromspace_end': fromspace_end + } + self.registers['r15'] = self.global_vals['rootstack_begin'] + return self.global_vals[str(loc)] + else: raise RuntimeError(f'Unknown arg in eval_arg: {a}') @@ -213,6 +240,7 @@ def eval_instrs(self, instrs, blocks, output): elif instr.data == 'movq': a1, a2 = instr.children + # trace("trace {} {} {}".format(instr, a1, a2)) v = self.eval_arg(a1) self.store_arg(a2, v) @@ -333,7 +361,7 @@ def eval_instrs(self, instrs, blocks, output): print(self.print_state()) - elif target == 'collect': + elif target == label_name('collect'): self.log(f'CALL TO collect: need {self.registers["rsi"]} bytes') needed = self.registers["rsi"] diff --git a/tests/tuple/add1.py b/tests/tuple/add1.py index 61c221e..939a956 100644 --- a/tests/tuple/add1.py +++ b/tests/tuple/add1.py @@ -1 +1 @@ -print( ((42, 2, 3), 2)[0][0]) \ No newline at end of file +print( ((42, 42, 2), )[0][1]) \ No newline at end of file diff --git a/type_check_Ctup.py b/type_check_Ctup.py index 36acf69..8a8daa1 100644 --- a/type_check_Ctup.py +++ b/type_check_Ctup.py @@ -51,7 +51,7 @@ def type_check_stmt(self, s, env): match s: case Collect(size): pass - case Assign([Subscript(tup, Constant(index), Store())], value): + case Assign([Subscript(tup, Constant(index), x)], value): # TODO Store to x tup_t = self.type_check_atm(tup, env) value_t = self.type_check_atm(value, env) match tup_t: diff --git a/x86_ast.py b/x86_ast.py index c44c647..2ee38cb 100644 --- a/x86_ast.py +++ b/x86_ast.py @@ -207,7 +207,7 @@ class Global(arg): def __init__(self, name): self.name = name def __str__(self): - return str(self.name) + "(%rip)" + return '_' + str(self.name) + "(%rip)" def __repr__(self): return 'Global(' + repr(self.name) + ')'