-
Notifications
You must be signed in to change notification settings - Fork 0
/
expr_constructor.py
191 lines (168 loc) · 8.28 KB
/
expr_constructor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Classes for producing Relay expressions.
These are provided to separate producing expressions from decision-making policies.
None of these should include any kind of decision-making or randomness directly.
"""
import tvm
from tvm import relay
from type_constructor import TypeConstructs as TC
def all_pattern_vars(pat, acc=None):
    """
    Collect every variable bound anywhere inside a Relay match pattern.

    pat: a relay.PatternVar, PatternWildcard, PatternTuple or
         PatternConstructor (tuple/constructor patterns may nest further)
    acc: optional set that receives the variables; it is updated in place
         and returned. When omitted, a new set is created.

    Sub-patterns are visited recursively, left to right.
    Raises TypeError for any other kind of pattern.
    """
    bound = acc if acc is not None else set()
    if isinstance(pat, relay.PatternVar):
        bound.add(pat.var)
        return bound
    if isinstance(pat, relay.PatternWildcard):
        # a wildcard matches anything but binds nothing
        return bound
    if not isinstance(pat, (relay.PatternTuple, relay.PatternConstructor)):
        raise TypeError(f"Unrecognized pattern {pat}")
    for sub_pattern in pat.patterns:
        # the recursive call fills `bound` in place (and returns the same set)
        all_pattern_vars(sub_pattern, acc=bound)
    return bound
class ExprConstructor:
    """
    Builds individual kinds of Relay expressions.

    Every decision (which type, which subexpression, which constructor,
    which op) is delegated to the callables handed to __init__; this class
    only assembles the resulting pieces and opens/closes variable scopes.
    """

    def __init__(self, var_scope, generate_expr, generate_type,
                 choose_ctor, generate_patterns, generate_op):
        """
        var_scope: Responsible for producing variables
                   and tracking what's in scope
                   (lets, ifs, matches, and funcs produce new scopes)
        generate_expr: Function that takes a type and returns an expr of that type
        generate_type: As named (can specify supported types and params)
        choose_ctor: Given an ADT handle, returns a (constructor, instantiated
                     constructor type) pair for it
        generate_patterns: Given a type, generate a set of complete match patterns for it
        generate_op: Given a return type, returns the handler for an op with that return type
        """
        self.var_scope = var_scope
        self.generate_type = generate_type
        self.generate_expr = generate_expr
        self.generate_patterns = generate_patterns
        self.choose_ctor = choose_ctor
        self.generate_op = generate_op

    def construct_tuple_literal(self, member_types):
        """Build a tuple literal with one generated field per entry of member_types."""
        fields = []
        for field_type in member_types:
            fields.append(self.generate_expr(field_type))
        return relay.Tuple(fields)

    def construct_ref_literal(self, inner_type):
        """Build a new reference cell initialised with a generated expr of inner_type."""
        initial_value = self.generate_expr(inner_type)
        return relay.RefCreate(initial_value)

    def construct_func_literal(self, arg_types, ret_type, own_name=None):
        """
        Build a function literal taking arg_types and returning ret_type.

        Parameters are fresh local vars added to a new scope for the body.
        own_name: for a recursive function, the variable the function is bound
        to; it is added to the body's scope so the body can refer to itself.
        """
        with self.var_scope.new_scope():
            params = []
            for param_type in arg_types:
                params.append(
                    self.var_scope.new_local_var(param_type, add_to_scope=True))
            if own_name is not None:
                self.var_scope.add_to_scope(own_name)
            body_expr = self.generate_expr(ret_type)
        return relay.Function(params, body_expr, ret_type=ret_type)

    def construct_adt_literal(self, type_call):
        """Build a constructor call producing a value of the ADT type type_call."""
        ctor, ctor_type = self.choose_ctor(type_call)
        ctor_args = [self.generate_expr(field_type)
                     for field_type in ctor_type.arg_types]
        # The type arguments are annotated explicitly: without them TVM
        # may be unable to infer the type of calls like Nil() or None().
        return relay.Call(ctor, ctor_args, type_args=type_call.args)

    # connectives

    def construct_let_expr(self, ret_type):
        """
        Build `let x = value; body`, where body has type ret_type.

        The binder's type is generated freely. The let variable is explicitly
        added only to the body's scope. When the binder is a function type,
        the variable is also passed to generate_expr as own_name so that a
        recursive function definition can be produced.
        """
        binder_type = self.generate_type()
        let_var = self.var_scope.new_local_var(binder_type, add_to_scope=False)
        recursive_name = let_var if isinstance(binder_type, relay.FuncType) else None
        with self.var_scope.new_scope():
            value_expr = self.generate_expr(binder_type, own_name=recursive_name)
        with self.var_scope.new_scope():
            self.var_scope.add_to_scope(let_var)
            body_expr = self.generate_expr(ret_type)
        return relay.Let(let_var, value_expr, body_expr)

    def construct_tuple_index(self, ret_type, idx):
        """
        Build `tup.idx` of type ret_type.

        The tuple's type is requested with at least idx + 1 fields and with
        field idx constrained to be ret_type.
        """
        tuple_params = {
            "min_arity": idx + 1,
            "constrained": {idx: ret_type},
        }
        tup_type = self.generate_type(gen_params={TC.TUPLE: tuple_params})
        assert isinstance(tup_type, relay.TupleType)
        tuple_expr = self.generate_expr(tup_type)
        return relay.TupleGetItem(tuple_expr, idx)

    def construct_if_branch(self, ret_type):
        """Build `if (cond) {then} else {else}` where both branches have type ret_type."""
        # the branch condition must be a boolean scalar
        condition = self.generate_expr(relay.scalar_type("bool"))
        # each branch is generated in its own fresh scope
        with self.var_scope.new_scope():
            then_expr = self.generate_expr(ret_type)
        with self.var_scope.new_scope():
            else_expr = self.generate_expr(ret_type)
        return relay.If(condition, then_expr, else_expr)

    def construct_function_call(self, ret_type):
        """Build a call to a generated function expr whose return type is ret_type."""
        func_type = self.generate_type(gen_params={TC.FUNC: {"ret_type": ret_type}})
        assert isinstance(func_type, relay.FuncType)
        callee = self.generate_expr(func_type)
        call_args = [self.generate_expr(param_type)
                     for param_type in func_type.arg_types]
        return relay.Call(callee, call_args)

    def construct_match(self, ret_type):
        """
        Build a match expression whose clause bodies all have type ret_type.

        The scrutinee's type is requested as a tuple or ADT type, since
        matching is only defined on those. generate_patterns supplies the
        clause patterns; the variables each pattern binds are put in scope
        only for that clause's body.
        """
        scrutinee_type = self.generate_type(gen_params={TC.TUPLE: {}, TC.ADT: {}})
        scrutinee = self.generate_expr(scrutinee_type)
        clauses = []
        for pattern in self.generate_patterns(scrutinee_type):
            bound_vars = all_pattern_vars(pattern)
            with self.var_scope.new_scope():
                for bound_var in bound_vars:
                    self.var_scope.add_to_scope(bound_var)
                clause_body = self.generate_expr(ret_type)
            clauses.append(relay.Clause(pattern, clause_body))
        return relay.Match(scrutinee, clauses)

    def construct_ref_write(self):
        """
        Build `ref := value` for a generated reference type.

        Ref writes are always of type (), so no return type is taken.
        """
        ref_type = self.generate_type(gen_params={TC.REF: {}})
        assert isinstance(ref_type, relay.RefType)
        target = self.generate_expr(ref_type)
        new_value = self.generate_expr(ref_type.value)
        return relay.RefWrite(target, new_value)

    def construct_ref_read(self, ret_type):
        """Build a read of a generated expr of type Ref[ret_type]."""
        return relay.RefRead(self.generate_expr(relay.RefType(ret_type)))

    def construct_op_call(self, ret_type):
        """
        Build a call to an operator returning ret_type.

        Warning: check that there exists an operator with the given return
        type before calling this.

        Relay ops are very varied: some return tensors, others tuples of
        tensors; some take tensors, others tuples of tensors; some take
        compile-time parameters that are not Relay exprs. For flexibility the
        op handler from generate_op chooses its own argument types and builds
        the call itself; this method only generates the argument exprs.
        """
        handler = self.generate_op(ret_type)
        arg_types, extra_params = handler.generate_arg_types(ret_type)
        args = [self.generate_expr(arg_type) for arg_type in arg_types]
        return handler.produce_call(args, additional_params=extra_params)
# handle pattern generation
class PatternConstructor:
    """
    Builds individual kinds of Relay match patterns.

    Choices (which sub-pattern, which constructor) are delegated to the
    callables handed to __init__; this class only assembles the pieces.
    """

    def __init__(self, var_scope, generate_pattern, choose_ctor):
        """
        var_scope: For generating pattern vars
        generate_pattern: Given a type, generates a pattern that matches that type
        choose_ctor: Given an ADT handle, returns a (constructor, instantiated
                     constructor type) pair for it
        """
        self.choose_ctor = choose_ctor
        self.generate_pattern = generate_pattern
        self.var_scope = var_scope

    def construct_var_pattern(self, var_type):
        """
        Build a pattern binding a fresh local variable of var_type.

        The variable is created with add_to_scope=False, so it is not put
        into the current scope here.
        """
        return relay.PatternVar(
            self.var_scope.new_local_var(var_type, add_to_scope=False))

    def construct_tuple_pattern(self, tup_type):
        """Build a tuple pattern with one generated sub-pattern per field of tup_type."""
        sub_patterns = []
        for field_type in tup_type.fields:
            sub_patterns.append(self.generate_pattern(field_type))
        return relay.PatternTuple(sub_patterns)

    def construct_ctor_pattern(self, type_call):
        """Build a constructor pattern for a constructor chosen for the ADT type type_call."""
        ctor, ctor_type = self.choose_ctor(type_call)
        sub_patterns = [self.generate_pattern(arg_type)
                        for arg_type in ctor_type.arg_types]
        return relay.PatternConstructor(ctor, sub_patterns)