Skip to content

Commit

Permalink
add allocal r8
Browse files Browse the repository at this point in the history
  • Loading branch information
jiamo committed Mar 13, 2022
1 parent 91b7883 commit c7142bb
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 42 deletions.
136 changes: 102 additions & 34 deletions compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Tuple as Tupling
import type_check_Ltup
from interp_x86.eval_x86 import interp_x86
import type_check_Ctup


Binding = Tupling[Name, expr]
Temporaries = List[Binding]
Expand All @@ -37,16 +39,19 @@


def calculate_tag(size, ty):
tag = bitarray.bitarray(64)
tag = bitarray.bitarray(64, endian="little")
tag.setall(0)
p_mask = 7
tag[0] = 1
for i, type in enumerate(ty.types):
if type == TupleType:
# breakpoint()
if isinstance(type, TupleType):
tag[p_mask + i] = 1
else:
tag[p_mask + i] = 0
tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6)
return bitarray.util.ba2hex(tag)
tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6, endian='little')
print("tags", bitarray.util.ba2base(2, tag))
return bitarray.util.ba2int(tag)

class Compiler:

Expand Down Expand Up @@ -158,13 +163,13 @@ def expose_allocation_exp(self, exp) -> Tupling[expr, List[stmt]]:
stmts.append(new_stmt)
tmp_exprs.append(var)
# breakpoint()
n = 8 + 8 * len(exprs)
n = len(exprs)
stmts.append(
If(Compare(BinOp(GlobalValue("free_pr"), Add(), Constant(n)), [Lt()], [GlobalValue("fromspace_end")]),
If(Compare(BinOp(GlobalValue("free_ptr"), Add(), Constant(8 * (n+1))), [Lt()], [GlobalValue("fromspace_end")]),
[Expr(Constant(0))],
[Collect(n)])
[Collect(8 * (n+1))])
)
tmp = generate_name(tmp)
tmp = generate_name("alloc")
var = Name(tmp)
stmts.append(Assign([var], Allocate(n, exp.has_type))) # may exp.has_type.types
for i in range(len(exprs)):
Expand Down Expand Up @@ -532,14 +537,16 @@ def select_arg(self, e: expr) -> arg:
match e:
case Name(name):
return Variable(name)
case GlobalValue(name):
return x86_ast.Global(name)
case Constant(True):
return Immediate(1)
case Constant(False):
return Immediate(0)
case Constant(value):
return Immediate(value)
case x if isinstance(x, int):
return Immediate(x)
# case x if isinstance(x, int):
# return Immediate(x)
case _:
raise Exception('error in select_arg, unexpected ' + repr(e))

Expand Down Expand Up @@ -645,10 +652,10 @@ def select_stmt(self, s: stmt) -> List[instr]:
case Assign([lhs], Allocate(size, ty)):
lhs = self.select_arg(lhs)

size = self.select_arg(size)
tag = calculate_tag(size.value, ty)
# size = self.select_arg(size)
tag = calculate_tag(size, ty)
result.append(Instr("movq", [x86_ast.Global("free_ptr"), Reg('r11')]))
result.append(Instr("movq", [8 * (size.value + 1) , Reg('r11')]))
result.append(Instr("addq", [Immediate(8 * (size + 1)), x86_ast.Global("free_ptr")]))
result.append(Instr("movq", [Immediate(tag), Deref('r11', 0)]))
result.append(Instr('movq', [Reg('r11'), lhs]))
case Assign([lhs], value):
Expand All @@ -664,19 +671,19 @@ def select_stmt(self, s: stmt) -> List[instr]:
if_ = self.select_compare(expr, then_label, else_label)
result.extend(if_)
case Collect(size):
size = self.select_arg(size)
result.append(Instr('movq', [Reg('r15'), Reg('%rdi')]))
result.append(Instr('movq', [size, Reg('rsi')]))
result.append(Callq("collect"))
# size = self.select_arg(size)
result.append(Instr('movq', [Reg('r15'), Reg('rdi')]))
result.append(Instr('movq', [Immediate(size), Reg('rsi')]))
result.append(Callq(label_name("collect"), 2))
case _:
raise Exception('error in select_stmt, unexpected ' + repr(s))
return result
pass

def select_instructions(self, p: Module) -> X86Program:
# YOUR CODE HERE


type_check_Ctup.TypeCheckCtup().type_check(p)
# breakpoint()
blocks = {}
match p:
case CProgram(basic_blocks):
Expand All @@ -690,6 +697,7 @@ def select_instructions(self, p: Module) -> X86Program:


x86 = X86Program(blocks)
x86.var_types = p.var_types
# breakpoint()
# print("......")
# interp_x86(x86)
Expand Down Expand Up @@ -768,10 +776,21 @@ def read_var(self, i: instr) -> Set[location]:
case _:
return set()

def free_var(self, t):
match(t):
case Variable(i):
return t
case Reg(r):
return t
case Deref(r, offset):
return Reg(r)
case _:
return set()

def write_var(self, i) -> Set[location]:
match (i):
case Instr("movq", [s, t]):
return {i.args[1]}
return set([self.free_var(t)])
case Callq(func, num_args):
return set(callee_saved_regs)
case _:
Expand Down Expand Up @@ -844,7 +863,7 @@ def transfer(self,label, live_after_block):
before_instr_set = tmp.union(live_after_block)
case _:
before_instr_set = (self.live_after[s] - self.write_var(s)).union(self.read_var(s))
# print("s" , s, pre_instr_set)
# print("s" , s, before_instr_set)
self.live_before[s] = before_instr_set
live_before_block = live_before_block.union(before_instr_set)
pre_instr = s
Expand Down Expand Up @@ -908,6 +927,7 @@ def build_interference(self, blocks) -> UndirectedAdjList:
match (s):
case Instr("movq", [si, d]):
# si = s.args[0]
d = self.free_var(d)
for v in self.live_after[s]:
if v != d and v != si:
interference_graph.add_edge(d, v)
Expand All @@ -929,7 +949,11 @@ def color_graph(self, blocks, k=100) -> Dict[location, int]:
# first make it k big enough
valid_colors = list(range(0, k)) # number of colar
# Rdi 的保存问题
color_map = {Reg('rax'): -1, Reg('rsp'): -2, Reg('rdi'): -3, ByteReg('bl'): -4}
color_map = {
Reg('rax'): -1, Reg('rsp'): -2, Reg('rdi'): -3, ByteReg('bl'): -4, Reg('r11'): -5,
Reg('r15'): -6, Reg('rsi'): -7 # rsi 其实可以用来做其他事情。 但如果分配 rsi 9 rsi 的 color
# 算法 color 9 和 可以分配出去reg 的color 0 1 3 矛盾
}
# color_map = {}
saturated = {}

Expand All @@ -951,12 +975,15 @@ def less(u, v):
for v in vsets:
saturated[v] = set()
for v in vsets:

queue.push(v)


while not queue.empty():

u = queue.pop()
print("handing", u)
# print("handing", u)

adj_colors = {color_map[v] for v in interference_graph.adjacent(u) if v in color_map}
print(u, adj_colors)
if left_color := set(valid_colors) - adj_colors:
Expand All @@ -973,12 +1000,14 @@ def less(u, v):
def allocate_registers(self, p: X86Program) -> X86Program:
# YOUR CODE HERE

# breakpoint()
# ? RDI
self.color_regs = [Reg("rbx"), Reg("rcx"), Reg("rdx"), Reg("rsi"), Reg("rdi"), Reg("r8"), Reg("r9"), Reg("r10")]
self.color_regs = [Reg("rbx"), Reg("rcx"), Reg("rdx"), Reg("rsi"),
Reg("rdi"), Reg("r8"), Reg("r9"), Reg("r10")]
self.color_regs = [Reg("rbx"), Reg("rcx")]
self.color_regs = [Reg("rbx")]
# rcx as tmp
self.color_regs = [Reg("rbx"), Reg("rcx")]
self.color_regs = [Reg("rbx"), Reg("rcx"), Reg("r8")]

self.alloc_callee_saved_regs = list(set(self.color_regs).intersection(callee_saved_regs))
self.C = len(self.alloc_callee_saved_regs)
Expand All @@ -988,6 +1017,9 @@ def allocate_registers(self, p: X86Program) -> X86Program:
color_regs_map[-2] = Reg('rsp')
color_regs_map[-3] = Reg('rdi')
color_regs_map[-4] = ByteReg("bl")
color_regs_map[-5] = Reg('r11')
color_regs_map[-6] = Reg('r15')
color_regs_map[-7] = Reg('rsi')
self.real_color_map = {}

match(p):
Expand All @@ -996,19 +1028,37 @@ def allocate_registers(self, p: X86Program) -> X86Program:
# breakpoint()
new_blocks = {}
color_map = self.color_graph(blocks)
self.S = len(set(color_map.values())) - len(self.color_regs)

print("color_map", color_map)

for color in sorted(set(color_map.values())):

print("color_map", color_map)
so_far_rbp = 0
so_far_r15 = 0
self.rbp_spill = set()
self.r15_spill = set()
for var, color in sorted(color_map.items(), key=lambda i: i[1]):
# 相同的 color 但 type 不一样
if color in self.real_color_map:
continue
if color in color_regs_map:
self.real_color_map[color] = color_regs_map[color]
else:
# Yes
self.real_color_map[color] = Deref("rbp", -8*(color-len(self.color_regs) + self.C + 1))
# breakpoint()
if isinstance(p.var_types.get(str(var)), TupleType):
# breakpoint()
# r15 is up r15 was saveid in heap
self.real_color_map[color] = Deref("r15", 8*(so_far_r15))
so_far_r15 += 1
self.r15_spill.add(color)
else:
self.real_color_map[color] = Deref("rbp", -8*(so_far_rbp+ self.C + 1))
so_far_rbp += 1
self.rbp_spill.add(color)

self.S = len(self.rbp_spill)
if self.r15_spill.intersection(self.rbp_spill):
print("r15 and rbp have somecolor", )
sys.exit(-1)

print("real_color_map", self.real_color_map)

Expand Down Expand Up @@ -1049,10 +1099,15 @@ def patch_instr(self, i: instr) -> List[instr]:
match(i):
case Instr(instr, [x, y]) if x == y:
return []
case Instr(instr, [Deref("rbp", x), Deref("rbp", y)]):
case Instr(instr, [Deref(label_x, x), Deref(label_y, y)]):
return [
Instr("movq", [Deref("rbp", x), Reg("rax")]),
Instr("movq", [Reg("rax"), Deref("rbp", y)])
Instr("movq", [Deref(label_x, x), Reg("rax")]),
Instr("movq", [Reg("rax"), Deref(label_y, y)])
]
case Instr(instr, [x86_ast.Global(x), Deref(label_y, y)]):
return [
Instr("movq", [x86_ast.Global(x), Reg("rax")]),
Instr("movq", [Reg("rax"), Deref(label_y, y)])
]
case Instr('cmpq', [x, Immediate(v)]):
return [
Expand Down Expand Up @@ -1114,12 +1169,25 @@ def prelude_and_conclusion(self, p: X86Program) -> X86Program:
for reg in extra_saved_regs:
main.append(Instr("pushq", [reg]))
main.extend([ Instr("subq", [Immediate(self.rsp_sub), Reg("rsp")])])
main.extend([
Instr("movq", [Immediate(65536), Reg("rdi")]),
Instr("movq", [Immediate(65536), Reg("rsi")]),
Callq(label_name("initialize"), 2),
Instr("movq", [x86_ast.Global("rootstack_begin"), Reg("r15")]),
])
len_spill_r15 = len(self.r15_spill)
for i in range(len_spill_r15):
main.append( Instr("movq", [Immediate(0), Deref("r15", 8 * i)]))
main.append(Instr('addq', [Immediate(8 * len_spill_r15), Reg('r15')]))
main.append(Jump(label_name("start")))
blocks[label_name("main")] = main
# for label , body in blocks.items():
# pass

conclusion.extend([ Instr("addq", [Immediate(self.rsp_sub), Reg("rsp")]),])
conclusion.extend([
Instr("subq", [Immediate(8 * len_spill_r15), Reg("r15")]),
Instr("addq", [Immediate(self.rsp_sub), Reg("rsp")]),
])
for reg in extra_saved_regs[::-1]:
conclusion.append(Instr("popq", [reg]))
conclusion.append(Instr("popq", [Reg('rbp')])) # seem no need pop
Expand Down
7 changes: 4 additions & 3 deletions interp_x86/convert_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from lark import Tree
from ast import Name, Constant
from x86_ast import *
from utils import label_name, GlobalValue
from utils import label_name, GlobalValue, trace

def convert_int(value):
if value >= 0:
return Tree('int_a', [Tree('int_a', [value])])
else:
return Tree('neg_a',[Tree('int_a', [- value])])
return Tree('neg_a', [Tree('int_a', [- value])])

def convert_arg(arg):
match arg:
Expand All @@ -24,14 +24,15 @@ def convert_arg(arg):
return Tree('mem_a', [convert_int(offset), reg])
case ByteReg(id):
return Tree('reg_a', [id])
case GlobalValue(id):
case Global(id):
return Tree('global_val_a', [id, 'rip'])
case _:
raise Exception('convert_arg: unhandled ' + repr(arg))

def convert_instr(instr):
match instr:
case Instr(instr, args):
# trace("convert.... {} {}".format( instr, args))
return Tree(instr, [convert_arg(arg) for arg in args])
case Callq(func, args):
return Tree('callq', [func])
Expand Down
Loading

0 comments on commit c7142bb

Please sign in to comment.