Skip to content

Commit

Permalink
only functiontype need to fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jiamo committed Jun 19, 2022
1 parent 1105b65 commit a07b03d
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 83 deletions.
92 changes: 56 additions & 36 deletions compiler_dyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def calculate_tag(size, ty, arith=None):
tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6, endian='little')
if arith:
tag[57:62] = bitarray.util.int2ba(arith, length=5, endian='little')
print("tags", bitarray.util.ba2base(2, tag))
# print("tags", bitarray.util.ba2base(2, tag))
return bitarray.util.ba2int(tag)

class Compiler:
Expand Down Expand Up @@ -569,7 +569,6 @@ def cast_insert_exp(self, e):
right = Project(right, IntType())
return Inject(Compare(left, [cmp], [right]), BoolType())
case IfExp(expr_test, expr_body, expr_orelse):
# 所有的这种表达式可以用 children 来做
t = self.cast_insert_exp(expr_test)
t = Project(t, BoolType())
b = self.cast_insert_exp(expr_body)
Expand Down Expand Up @@ -656,8 +655,6 @@ def cast_insert(self, p):

def reveal_casts_exp(self, e):
match e:
# case Call(Name('input_int'), []):
# return Call(Project(e.func, FunctionType([], AnyType())), [])
case Project(e, ftype):
match ftype:
case BoolType() | IntType():
Expand Down Expand Up @@ -731,7 +728,6 @@ def reveal_casts_exp(self, e):
right = self.reveal_casts_exp(right)
return Compare(left, [cmp], [right])
case IfExp(expr_test, expr_body, expr_orelse):
# 所有的这种表达式可以用 children 来做
t = self.reveal_casts_exp(expr_test)
b = self.reveal_casts_exp(expr_body)
e = self.reveal_casts_exp(expr_orelse)
Expand All @@ -748,12 +744,12 @@ def reveal_casts_exp(self, e):
params = [(x, AnyType()) for x in params]
# breakpoint()
body = self.reveal_casts_exp(body)
# breakpoint()
return AnnLambda(params, AnyType(), body)
case _:
raise Exception('interp: reveal_casts_exp ' + repr(e))

def reveal_casts_stmt(self, stmt):
# TODO 每次都要展开 stmt 能不能不展开,直接处理 子children
match stmt:
case Expr(Call(Name('print'), [arg])):
new_arg = self.reveal_casts_exp(arg)
Expand Down Expand Up @@ -809,6 +805,7 @@ def reveal_casts(self, p):
raise Exception('interp: unexpected ' + repr(p))

trace(result)
# type_check_Lany.TypeCheckLany().type_check(p)
return result

def convert_assignments_exp(self, e):
Expand All @@ -823,7 +820,6 @@ def convert_assignments_exp(self, e):
case TagOf(value):
value = self.convert_assignments_exp(value)
return TagOf(value)
pass
case ValueOf(value, typ):
value = self.convert_assignments_exp(value)
return ValueOf(value, typ)
Expand All @@ -832,7 +828,6 @@ def convert_assignments_exp(self, e):
if id not in self.box_dict:
return e
else:
# breakpoint()
return Subscript(Name(self.box_dict[id]), Constant(0), Load())
case BinOp(left, op, right):
left = self.convert_assignments_exp( left)
Expand Down Expand Up @@ -1024,8 +1019,6 @@ def convert_to_closures_exp(self, e):
# if isinstance(c.my_extra_type, AnyType):
# breakpoint()
if isinstance(c.my_extra_type, AnyType):


# breakpoint()
ne.my_extra_type = AnyType()
else:
Expand All @@ -1046,12 +1039,13 @@ def convert_to_closures_exp(self, e):
case ValueOf(value, typ):
value = self.convert_to_closures_exp(value)
if isinstance(value, Name) and value.id in self.func_val_real_types:
typ = self.func_val_real_types[value.id]

if isinstance(self.func_val_real_types[value.id], TupleType):
typ = self.func_val_real_types[value.id]

print("ValueOf ", value, value.my_extra_type)
ne = ValueOf(value, value.my_extra_type)
ne.my_extra_type = value.my_extra_type
# bool cam't change
# print("ValueOf ", value, value.my_extra_type)
ne = ValueOf(value, typ)
ne.my_extra_type = typ
return ne

case Constant(v):
Expand All @@ -1067,14 +1061,20 @@ def convert_to_closures_exp(self, e):
# return TagOf(closureTy)
return e


case Name(id):
# breakpoint()
e.my_extra_type = self.func_val_real_types[id]
print("id is .... {}".format(id))

if id not in self.box_dict:
# breakpoint()
print("id is .... {} {} {}".format(id, self.top_funs, self.func_val_real_types))
# if isinstance(id, )
if id in self.top_funs:
# pass
funref = self.func_map[id]

ne = Closure(funref.arity, [funref])
ne.my_extra_type = self.func_val_real_types[id]
return ne
return e
else:
return Subscript(Name(self.box_dict[id]), Constant(0), Load())
Expand All @@ -1091,12 +1091,16 @@ def convert_to_closures_exp(self, e):

case Compare(left, [cmp], [right]):
left = self.convert_to_closures_exp(left)
right = self.convert_to_closures_exp(right)
if isinstance(left, TagOf) and left.value.id in self.func_val_real_types:
# breakpoint()
right = Constant(tagof(self.func_val_real_types[left.value.id]))
else:
right = self.convert_to_closures_exp(right)
return Compare(left, [cmp], [right])
if isinstance(self.func_val_real_types[left.value.id], TupleType):
right = Constant(tagof(self.func_val_real_types[left.value.id]))

# right = self.convert_to_closures_exp(right)
ne = Compare(left, [cmp], [right])
ne.my_extra_type = BoolType()
return ne

case IfExp(expr_test, expr_body, expr_orelse):
t = self.convert_to_closures_exp(expr_test)
Expand Down Expand Up @@ -1198,6 +1202,7 @@ def convert_to_closures_exp(self, e):

return c
case Uninitialized(ty):
e.my_extra_type = ty
return e
case _:
raise Exception('convert_to_closures_exp: unexpected ' + repr(e))
Expand All @@ -1215,7 +1220,7 @@ def convert_to_closures_stmt(self, stmt):
# l = self.convert_to_closures_exp(l)

v_expr = self.convert_to_closures_exp(value)
# if isinstance(v_expr, Tuple):
# if isinstance(v_expr, Closure):
# breakpoint()
# Subscript single type don't need to be check
if not isinstance(l, Subscript):
Expand Down Expand Up @@ -1762,7 +1767,7 @@ def rco_exp(self, e: expr, need_atomic: bool) -> Tupling[expr, Temporaries]:

# 如果这里不这个name
if body_tmps:
body = Begin([ Assign([name], expr)for name,expr in body_tmps], body)
body = Begin([Assign([name], expr) for name,expr in body_tmps], body)
if orelse_tmp:
orelse_expr = Begin([Assign([name], expr) for name, expr in orelse_tmp], orelse_expr)
return_expr = IfExp(test_expr, body, orelse_expr)
Expand Down Expand Up @@ -1952,7 +1957,7 @@ def explicate_effect(self, e: expr, cont: List[stmt],
new_body = self.explicate_stmt(s, new_body, basic_blocks)
return new_body
case _:
print("......", e)
# print("......", e)
return [] + cont

def explicate_pred(self, cnd: expr, thn: List[stmt], els: List[stmt],
Expand Down Expand Up @@ -2108,7 +2113,7 @@ def select_arg(self, e: expr) -> arg:
return Immediate(0)
case Constant(value):
return Immediate(value)
case Uninitialized(ty) if isinstance(ty, IntType) :
case Uninitialized(ty):
return Immediate(0)
# case FunRef(name, arith):
# # breakpoint()
Expand All @@ -2128,11 +2133,22 @@ def select_compare(self, expr, then_label, else_label) -> List[instr]:
y = self.select_arg(y)
return [
Instr('cmpq', [y, x]),
# Instr('j{}'.format(op_dict[str(op)]), [then_label]),
JumpIf(op_dict[str(op)], then_label),
Jump(else_label)
# Instr('jmp', [else_label])
]
case Name(id):
x = self.select_arg(expr)
y = Immediate(1)
# the result == 1 ?
return [
Instr('cmpq', [y, x]),
JumpIf(op_dict["=="], then_label),
Jump(else_label)
]
case _:
breakpoint()
raise Exception("no match {} ".format(expr))

def select_stmt(self, s: stmt) -> List[instr]:
# YOUR CODE HERE
Expand Down Expand Up @@ -2192,6 +2208,7 @@ def select_stmt(self, s: stmt) -> List[instr]:
case Assign([lhs], ValueOf(e, ty)):
left = self.select_arg(lhs)
value = self.select_arg(e)
# how should I now this is BoolType
if isinstance(ty, IntType) or isinstance(ty, BoolType):
result.append(Instr('movq', [value, left]))
result.append(Instr('sarq', [Immediate(3), left]))
Expand Down Expand Up @@ -2236,6 +2253,8 @@ def select_stmt(self, s: stmt) -> List[instr]:
else:
result.append(Instr('movq', [left_arg, lhs]))
result.append(Instr('addq', [right_arg, lhs]))


case Assign([lhs], BinOp(left, Sub(), right)):
# breakpoint()
left_arg = self.select_arg(left)
Expand Down Expand Up @@ -2305,7 +2324,6 @@ def select_stmt(self, s: stmt) -> List[instr]:
result.append(Instr('cmpq', [l, r]))
result.append(Instr('set{}'.format(op_dict[str(op)]), [ByteReg('bl')]))
result.append(Instr('movzbq', [ByteReg('bl'), lhs]))
pass
case Assign([lhs], Call(Name('len'), [arg])):
arg = self.select_arg(arg)
result.append(Instr('movq', [arg, Reg('r11')]))
Expand Down Expand Up @@ -2348,9 +2366,12 @@ def select_stmt(self, s: stmt) -> List[instr]:
# result.append(Instr('retq', []))
case Goto(label):
result.append(Jump(label))

case If(expr, [Goto(then_label)], [Goto(else_label)]):
# expr can be valueof
if_ = self.select_compare(expr, then_label, else_label)
result.extend(if_)

case Collect(size):
# size = self.select_arg(size)
result.append(Instr('movq', [Reg('r15'), Reg('rdi')]))
Expand Down Expand Up @@ -2613,12 +2634,12 @@ def analyze_dataflow(self, G, transfer, bottom, join):
worklist = deque(G.vertices())
debug = {}
while worklist:
print(worklist)
# print(worklist)
node = worklist.pop()
inputs = [mapping[v] for v in trans_G.adjacent(node)]
input = reduce(join, inputs, bottom)
output = transfer(node, input)
print("node", node, "input", input, "output", output)
# print("node", node, "input", input, "output", output)
if output != mapping[node]:
worklist.extend(G.adjacent(node))
mapping[node] = output
Expand Down Expand Up @@ -2653,7 +2674,7 @@ def build_interference(self, blocks) -> UndirectedAdjList:
# live_before_block[label] = tmp[ss[0]]
# live_after.update(tmp)

print("live_after ", self.live_after)
# print("live_after ", self.live_after)
for label, ss in blocks.items():
for s in ss:
match (s):
Expand Down Expand Up @@ -2719,7 +2740,7 @@ def less(u, v):
# print("handing", u)

adj_colors = {color_map[v] for v in interference_graph.adjacent(u) if v in color_map}
print(u, adj_colors)
# print(u, adj_colors)
if left_color := set(valid_colors) - adj_colors:
color = min(left_color)
if u not in color_map:
Expand Down Expand Up @@ -2774,8 +2795,7 @@ def allocate_registers(self, p: X86Program) -> X86Program:
new_blocks = {}
color_map = self.color_graph(blocks)


print("color_map", color_map)
# print("color_map", color_map)
so_far_rbp = 0
so_far_r15 = 0
cdef.rbp_spill = set()
Expand All @@ -2802,10 +2822,10 @@ def allocate_registers(self, p: X86Program) -> X86Program:

cdef.S = len(cdef.rbp_spill)
if cdef.r15_spill.intersection(cdef.rbp_spill):
print("r15 and rbp have somecolor", )
# print("r15 and rbp have somecolor", )
sys.exit(-1)

print("real_color_map", cdef.real_color_map)
# print("real_color_map", cdef.real_color_map)

for label, ss in blocks.items():
ss = blocks[label]
Expand Down
22 changes: 11 additions & 11 deletions compiler_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def calculate_tag(size, ty, arith=None):
tag[1:7] = tag[1:7] | bitarray.util.int2ba(size, length=6, endian='little')
if arith:
tag[58:63] = bitarray.util.int2ba(arith, length=5, endian='little')
print("tags", bitarray.util.ba2base(2, tag))
# print("tags", bitarray.util.ba2base(2, tag))
return bitarray.util.ba2int(tag)

class Compiler:
Expand Down Expand Up @@ -644,8 +644,7 @@ def convert_to_closures_exp(self, e):
# pass
funref = self.func_map[id]
return Closure(funref.arity, [funref])
# ...
#breakpoint()

return e
else:
# breakpoint()
Expand Down Expand Up @@ -1424,7 +1423,7 @@ def explicate_effect(self, e: expr, cont: List[stmt],
new_body = self.explicate_stmt(s, new_body, basic_blocks)
return new_body
case _:
print("......", e)

return [] + cont

def explicate_pred(self, cnd: expr, thn: List[stmt], els: List[stmt],
Expand Down Expand Up @@ -1603,6 +1602,8 @@ def select_compare(self, expr, then_label, else_label) -> List[instr]:
Jump(else_label)
# Instr('jmp', [else_label])
]
case _:
raise Exception("no match {} ".format(expr))

def select_stmt(self, s: stmt) -> List[instr]:
# YOUR CODE HERE
Expand Down Expand Up @@ -2006,12 +2007,12 @@ def analyze_dataflow(self, G, transfer, bottom, join):
worklist = deque(G.vertices())
debug = {}
while worklist:
print(worklist)
# print(worklist)
node = worklist.pop()
inputs = [mapping[v] for v in trans_G.adjacent(node)]
input = reduce(join, inputs, bottom)
output = transfer(node, input)
print("node", node, "input", input, "output", output)
# print("node", node, "input", input, "output", output)
if output != mapping[node]:
worklist.extend(G.adjacent(node))
mapping[node] = output
Expand Down Expand Up @@ -2046,7 +2047,7 @@ def build_interference(self, blocks) -> UndirectedAdjList:
# live_before_block[label] = tmp[ss[0]]
# live_after.update(tmp)

print("live_after ", self.live_after)
# print("live_after ", self.live_after)
for label, ss in blocks.items():
for s in ss:
match (s):
Expand Down Expand Up @@ -2112,7 +2113,7 @@ def less(u, v):
# print("handing", u)

adj_colors = {color_map[v] for v in interference_graph.adjacent(u) if v in color_map}
print(u, adj_colors)
# print(u, adj_colors)
if left_color := set(valid_colors) - adj_colors:
color = min(left_color)
if u not in color_map:
Expand Down Expand Up @@ -2167,8 +2168,7 @@ def allocate_registers(self, p: X86Program) -> X86Program:
new_blocks = {}
color_map = self.color_graph(blocks)


print("color_map", color_map)
# print("color_map", color_map)
so_far_rbp = 0
so_far_r15 = 0
cdef.rbp_spill = set()
Expand Down Expand Up @@ -2198,7 +2198,7 @@ def allocate_registers(self, p: X86Program) -> X86Program:
print("r15 and rbp have somecolor", )
sys.exit(-1)

print("real_color_map", cdef.real_color_map)
# print("real_color_map", cdef.real_color_map)

for label, ss in blocks.items():
ss = blocks[label]
Expand Down
Loading

0 comments on commit a07b03d

Please sign in to comment.