-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdatabazel.bzl
163 lines (140 loc) · 4.8 KB
/
databazel.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Utility functions
def mk_param_summary(hyperparams):
    """Build a compact summary string for one set of hyperparameter choices.

    Each (name, value) pair is rendered as `name_value`, and the pieces are
    joined with `__` in the dict's iteration order, e.g.
    {"lr": "0.1", "bs": "32"} -> "lr_0.1__bs_32".

    Args:
      hyperparams: dict mapping hyperparameter name to its (string) value.

    Returns:
      The joined summary; "" for an empty dict.
    """
    pieces = [name + "_" + value for name, value in hyperparams.items()]
    return "__".join(pieces)
def insert_param_summary(filename, param_summary):
    """Insert `__<param_summary>` into a filename just before its extension.

    Examples:
      ("model.pkl", "lr_0.1")   -> "model__lr_0.1.pkl"
      ("a.tar.gz", "lr_0.1")    -> "a.tar__lr_0.1.gz"     (last dot wins)
      ("model", "lr_0.1")       -> "model__lr_0.1"        (no extension)
      ("out.d/model", "lr_0.1") -> "out.d/model__lr_0.1"  (dot only in a dir)

    Args:
      filename: file name or relative path, optionally with an extension.
      param_summary: text to insert, typically from mk_param_summary().

    Returns:
      The new file name.
    """
    last_dot = filename.rfind(".")
    last_slash = filename.rfind("/")

    # The final path component has no extension. Splitting on the last dot
    # would either fail (no dot at all) or cut inside a directory name, so
    # just append the summary.
    if last_dot <= last_slash:
        return filename + "__" + param_summary

    return filename[:last_dot] + "__" + param_summary + filename[last_dot:]
def combinations(ll):
    """Return the cartesian product of a sequence of lists, as lists.

    Built iteratively because Starlark forbids recursion. Each pass extends
    every partial combination with every choice from the next list. The
    choice varies in the outer loop, so the first list's elements cycle
    fastest:
      [[1, 2], [3, 4]] -> [[1, 3], [2, 3], [1, 4], [2, 4]]

    Args:
      ll: sequence of lists of choices.

    Returns:
      A list of combinations; [[]] if `ll` is empty, and [] if any inner
      list is empty.
    """
    product = [[]]
    for choices in ll:
        product = [prefix + [choice] for choice in choices for prefix in product]
    return product
# Internal functions
def _model_internal(data, model_output, hyperparams, ctx):
    """Register an action that runs the training executable.

    The executable is invoked as:
      train_executable --data <data> --model-output-path <model_output>
                       --hyperparams <hyperparams as JSON>

    Args:
      data: File with the training data; the action's only declared input.
      model_output: File the training script is expected to write.
      hyperparams: dict of hyperparameter name -> value, serialized to JSON
        via struct(...).to_json().
      ctx: rule context; must provide `ctx.executable.train_executable`.
    """
    # NOTE(review): struct.to_json() is deprecated in newer Bazel releases in
    # favour of json.encode(); confirm the targeted Bazel version before
    # switching.
    hyperparams_json = struct(**hyperparams).to_json()

    args = ["--data", data.path]
    args += ["--model-output-path", model_output.path]
    args += ["--hyperparams", hyperparams_json]

    ctx.actions.run(
        executable = ctx.executable.train_executable,
        inputs = [data],
        outputs = [model_output],
        arguments = args,
        progress_message = "Running training script with args %s" % args,
    )
def _eval_internal(data, model, output, ctx):
    """Register an action that runs the evaluation executable.

    The executable is invoked as:
      eval_executable --data-path <data> --model-path <model>
                      --output-file <output>

    Args:
      data: File with the evaluation data.
      model: File with the trained model.
      output: the single File the evaluation script is expected to write.
      ctx: rule context; must provide `ctx.executable.eval_executable`.
    """
    flag_files = [
        ("--data-path", data),
        ("--model-path", model),
        ("--output-file", output),
    ]
    args = []
    for flag, f in flag_files:
        args.extend([flag, f.path])

    ctx.actions.run(
        executable = ctx.executable.eval_executable,
        inputs = [data, model],
        outputs = [output],
        arguments = args,
        progress_message = "Running evaluation script with args %s" % args,
    )
# Rule definitions
def _model_impl(ctx):
    """Implementation of the `model` rule: one training action.

    Trains on the `training_data` file with the `hyperparams` attribute and
    writes to the declared `model` output. No providers are returned
    explicitly.
    """
    _model_internal(
        data = ctx.file.training_data,
        model_output = ctx.outputs.model,
        hyperparams = ctx.attr.hyperparams,
        ctx = ctx,
    )
# Trains a single model: runs `train_executable` on `training_data` with
# `hyperparams` and writes the result to the `model` output.
# The attributes the implementation dereferences are mandatory, so omitting
# one is reported at analysis time rather than as a `None` attribute error.
model = rule(
    implementation = _model_impl,
    attrs = {
        # Accepted but not read by _model_impl.
        "deps": attr.label_list(),
        # TODO turn this into a keyed list or something
        "training_data": attr.label(
            allow_single_file = True,
            mandatory = True,
        ),
        # NOTE(review): this tool is run inside a build action; Bazel usually
        # declares such tools with cfg = "exec". Confirm "target" is intended.
        "train_executable": attr.label(
            cfg = "target",
            executable = True,
            mandatory = True,
        ),
        "model": attr.output(mandatory = True),
        # Passed to the training script as JSON via --hyperparams.
        "hyperparams": attr.string_dict(),
    },
)
def _evaluate_impl(ctx):
    """Implementation of the `evaluate` rule: one evaluation action.

    `outputs` is declared with attr.output_list, so ctx.outputs.outputs is a
    list of Files. The evaluation action accepts a single --output-file
    path, so exactly one entry is required. Previously the whole list was
    passed where a single File was expected, and `.path` failed on it.
    """
    outputs = ctx.outputs.outputs
    if len(outputs) != 1:
        fail("evaluate: expected exactly one entry in 'outputs', got %d" % len(outputs))

    _eval_internal(
        data = ctx.file.test_data,
        model = ctx.file.model,
        output = outputs[0],
        ctx = ctx,
    )
# Evaluates a trained model: runs `eval_executable` on `test_data` and
# `model`, writing to the file named in `outputs`.
# The attributes the implementation dereferences are mandatory, so omitting
# one is reported at analysis time rather than as a `None` attribute error.
evaluate = rule(
    implementation = _evaluate_impl,
    attrs = {
        # Accepted but not read by the implementation.
        "deps": attr.label_list(),
        "test_data": attr.label(allow_single_file = True, mandatory = True),
        "model": attr.label(allow_single_file = True, mandatory = True),
        # The evaluation action takes a single --output-file path.
        "outputs": attr.output_list(allow_empty = False),
        # NOTE(review): this tool is run inside a build action; Bazel usually
        # declares such tools with cfg = "exec". Confirm "target" is intended.
        "eval_executable": attr.label(
            cfg = "target",
            executable = True,
            mandatory = True,
        ),
    },
)
def _hyperparam_search_impl(ctx):
    """Train and evaluate one model per point of a hyperparameter grid.

    For every combination in the cartesian product of `hyperparam_values`,
    this function:
      * declares a model file named after `model_name`, with
        `__<param summary>` inserted before its extension;
      * registers a training action that writes that model file;
      * declares an evaluation output named after `eval_output` in the same
        way, and registers an evaluation action that writes it.
    All declared files are returned as the target's default outputs.

    The same `data` file is used for both training and evaluation.
    NOTE(review): the original author flagged this implementation as
    provisional; confirm that sharing the data between train and eval is
    intended.
    """
    grid = ctx.attr.hyperparam_values

    # Fail with a clear message instead of an opaque unpacking error (empty
    # dict) or a silent "builds nothing" result (a key with no values).
    if not grid:
        fail("hyperparam_search: 'hyperparam_values' must not be empty")
    for hp_name, candidates in grid.items():
        if not candidates:
            fail("hyperparam_search: no values given for hyperparameter %r" % hp_name)

    # keys() and values() iterate in the same order, so position i of each
    # combination belongs to hyperparam_names[i].
    hyperparam_names = grid.keys()
    files_to_build = []
    for hyperparam_values in combinations(grid.values()):
        these_values = dict(zip(hyperparam_names, hyperparam_values))
        param_summary = mk_param_summary(these_values)

        # Training action for this grid point.
        new_model_file = ctx.actions.declare_file(
            insert_param_summary(ctx.attr.model_name, param_summary),
        )
        _model_internal(
            data = ctx.file.data,
            model_output = new_model_file,
            hyperparams = these_values,
            ctx = ctx,
        )

        # Evaluation action for the model trained above.
        new_eval_output = ctx.actions.declare_file(
            insert_param_summary(ctx.attr.eval_output, param_summary),
        )
        _eval_internal(
            data = ctx.file.data,
            model = new_model_file,
            output = new_eval_output,
            ctx = ctx,
        )

        files_to_build += [new_eval_output, new_model_file]

    return DefaultInfo(files = depset(files_to_build))
# Grid search: for each combination of `hyperparam_values`, trains a model
# with `train_executable` and evaluates it with `eval_executable`. Output
# files are named from `model_name` / `eval_output` with the hyperparameter
# summary inserted.
# The attributes the implementation dereferences are mandatory, so omitting
# one is reported at analysis time rather than as a `None` attribute error.
hyperparam_search = rule(
    implementation = _hyperparam_search_impl,
    attrs = {
        # Accepted but not read by the implementation.
        "deps": attr.label_list(),
        # Used for both training and evaluation.
        "data": attr.label(allow_single_file = True, mandatory = True),
        "model_name": attr.string(mandatory = True),
        # NOTE(review): both tools below run inside build actions; Bazel
        # usually declares such tools with cfg = "exec". Confirm "target" is
        # intended.
        "eval_executable": attr.label(
            cfg = "target",
            executable = True,
            mandatory = True,
        ),
        "eval_output": attr.string(mandatory = True),
        "train_executable": attr.label(
            cfg = "target",
            executable = True,
            mandatory = True,
        ),
        # Hyperparameter name -> candidate values; the search covers their
        # cartesian product.
        "hyperparam_values": attr.string_list_dict(
            mandatory = True,
            allow_empty = False,
        ),
    },
)