forked from yandexdataschool/nlp_course
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindex.html
435 lines (380 loc) · 20.3 KB
/
index.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
<!doctype html>
<html lang="en">
<!-- this html defines a webpage that serves a neural language model. It assumes that you've already gone through train_and_export.ipynb to obtain the actual model. You can open it with your browser or serve it with SimpleHTTPServer, Twistd or dozens of other googlable frameworks -->
<head>
<meta charset="utf-8">
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tfjs-vis.umd.min.js"></script>
<script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
<style>
:root{}
.lm_window {
width: 100%;
border: 1px solid #ccc;
border-radius: 1px;
margin: 10px 5px;
padding: 3px;
background-color: white;
text-align: center;
}
#lm_window {
box-shadow: 0 2px 4px #377b94, 0 1px 1px ;
}
#lm_window:hover {
box-shadow: 0 6px 12px #377b94, 0 4px 4px #377b94;
}
.prompt_text {
font-size: 24px;
font-family: "Comic Neue", "Arial";
margin-top: 10px;
margin-bottom: 10px;
font-weight: bold;
text-align: center;
background-color: #ebf6fa;
}
.comment_text {
font-size: 16px;
font-family: "Gill Sans", sans-serif;
font-style: italic;
text-align: center;
}
.rect_button {
width: 25%;
background-color: #fafafa;
font-size: 20px;
font-family: "Comic Neue", "Arial";
font-weight: bold;
color: black;
margin-right: 20px;
margin-left: 20px;
margin-top: 10px;
margin-bottom: 10px;
border: 0px solid black;
border-radius: 12px;
padding: 20px;
padding-top: 0px;
padding-bottom: 0px;
text-decoration: none;
display: inline-block;
text-align: center;
box-shadow: 0px 2px 3px #377b94, 0 1px 1px #377b94;
}
.rect_button:hover {
box-shadow: 0 3px 6px #377b94, 0 2px 2px #377b94;
}
.button_text {
margin-top: 3px;
margin-bottom: 3px;
}
.prompt_and_button {
display: grid;
grid-template-columns: auto 30px;
}
.next_button {
width: 0;
height: 0;
border-top: 30px solid transparent;
border-left: 45px solid #1b6f8c;
border-bottom: 30px solid transparent;
}
.next_text {
color: #1b6f8c;
font-style: italic;
font-weight: 600;
font-family: "Comic Neue", "Arial";
margin: 0 auto;
}
.axis path, .axis line {
fill: none;
stroke: black;
shape-rendering: crispEdges; /* Round any decimal pixels so it'll render nicely */
}
</style>
</head>
<body>
<center>
<div class="lm_window" id="lm_window" style="width: 70%;">
<div class="cardContent">
<p class="prompt_text">Let's write a paper!</p>
<hr color="#e9f0f2" style="margin:5px; display: block">
<div id="lm_prompt_screen" style="display: block">
<p class="comment_text">
Every great paper starts with an inciting thought - something only a human can have ...<br>
</p>
<div class="prompt_and_button" style="margin-top: -20px;">
<div style="width: 85%; margin-left: 7%;">
<p class="data_text" style="font-size: 16px; color: gray; text-align:left; margin-bottom: -2px;">
Write a prefix, for example: "Deep Variational" (without quotes)</p>
<textarea class="input_prompt" style="width: 100%; min-height: 72px; resize: vertical;"></textarea>
</div>
<div class="next" style="opacity:0.2; transform: translate(-32px, 35px);">
<div class="next_button"></div>
<p class="next_text">Start!</p>
</div>
</div>
<p class="comment_text status_comment" id="aftertext_jk" style="color:gray; font-size: 16px">
Waiting for the model to load... (this should take 3-5s)
</p>
</div>
<div id="lm_generate_screen" style="display: none">
<div style="width: 85%; margin-left: 7%; float: center; display: block">
<p class="data_text"
style="font-size: 16px; color: gray; text-align:left; margin-bottom: -2px; margin-left: 5px;">
Currently generated text ("@@" is a BPE split, ";" is newline). Click on the circles below to add tokens.</p>
<textarea class="input_preprocessed"
style="width: 100%; resize: vertical; margin-left: 5px" disabled=true></textarea>
</div>
<div id="d3_container" style="width: 85%; display:block; margin: 0 auto;">
<svg id="d3_beamsearch_svg" width=100% height=240px></svg>
</div>
<button class="rect_button button_add_greedy"
onclick="console.warn('generate button should not be available at this point')">
<p class="button_text" style="font-size:16px">Add greedy</p>
</button>
<button class="rect_button button_backspace"
onclick="console.warn('backspace button should not be available at this point')">
<p class="button_text" style="font-size:16px">Backspace</p>
</button>
<button class="rect_button button_restart"
onclick="console.warn('restart button should not be available at this point')">
<p class="button_text" style="font-size:16px">Restart</p>
</button>
</div>
</div>
</div>
</center>
<script> // this script ONLY defines global constants
// Number of top-scoring candidate tokens collected (and drawn) at every lookahead step;
// it is also used to size the vertical axis of the chart.
var beam_size = 4;
// Number of lookahead steps, i.e. how many columns of candidates are generated and drawn.
var search_length = 5;
</script>
<script> // this <script> defines model, vocabularies & auxiliary functions, NOT the interface
// Model handle and vocabulary metadata. Everything from `model` through `unk` stays
// undefined until async_initialize_model() has run; that function also fills
// token_to_ix (token -> position in vocab) and bpe_rule_priority ("left right" -> rule index).
var model, lstm_units, emb_size, voc_size, bpe_rules, vocab, bos, eos, unk,
token_to_ix={}, bpe_rule_priority={}, is_initialized=false; // to be initialized
// Separator pattern for tokenize(). The whole pattern is one capturing group, so
// String.split returns the matched tokens interleaved with the text between matches.
// NOTE(review): the range `A-z` also covers the ASCII characters [ \ ] ^ _ ` that sit
// between 'Z' and 'a' - confirm this matches the tokenizer used when training the model.
let tokenizer_regex = new RegExp(/([A-zÀ-ÿ-]+|[0-9._]+|.|!|\?|'|"|:|;|,|-)/i);
// bpe_terminal is appended to the last character of a token before merges are applied;
// bpe_sep is appended to every segment of a token except the last one.
let bpe_terminal = "</w>", bpe_sep = "@@";
/**
 * Loads the layers model from 'lm/model.json' and the vocabulary bundle from
 * 'voc.json', then fills the lookup tables used by the tokenizer and generator.
 *
 * Side effects: assigns model, lstm_units, emb_size, voc_size, bpe_rules, vocab,
 * bos, eos and unk; adds one entry per vocabulary token to token_to_ix and one
 * entry per BPE rule to bpe_rule_priority; finally sets is_initialized to true.
 *
 * @returns {Promise<void>} settles once the model and vocabulary are in place.
 * @throws {Error} when voc.json answers with a non-OK HTTP status. Failures of
 *     tf.loadLayersModel, fetch or JSON parsing propagate unchanged.
 */
async function async_initialize_model() {
  console.assert(!is_initialized, "Model was initialized more than once!");
  // The two downloads are independent, so start both before waiting on either.
  const [loaded_model, voc_response] = await Promise.all([
    tf.loadLayersModel('lm/model.json'),
    fetch('voc.json'),
  ]);
  if (!voc_response.ok) {
    throw new Error(`Failed to load voc.json: HTTP ${voc_response.status}`);
  }
  model = loaded_model;
  // voc.json holds a flat array; it is unpacked positionally in exactly this order.
  [lstm_units, emb_size, voc_size, bpe_rules, vocab, bos, eos, unk] = await voc_response.json();
  vocab.forEach((token, index) => { token_to_ix[token] = index; });
  // Rules are keyed by "left right"; a lower index means the rule is applied earlier.
  bpe_rules.forEach((pair, index) => { bpe_rule_priority[`${pair[0]} ${pair[1]}`] = index; });
  is_initialized = true;
}
/**
 * Converts raw text into the flat list of BPE segments: the text is tokenized,
 * each token is split by bpeize_token, and the pieces are concatenated in order.
 *
 * @param {string} string raw text.
 * @returns {string[]} BPE segments for the whole text.
 */
function preprocess(string) {
  console.assert(is_initialized, "Model is not initialized!");
  const bpe_segments = [];
  for (const word of tokenize(string)) {
    bpe_segments.push(...bpeize_token(word));
  }
  return bpe_segments;
}
/**
 * Splits raw text into lowercase tokens using tokenizer_regex.
 *
 * Because the regex is a single capturing group, String.split returns both the
 * matched tokens and the text found between matches. That in-between text is
 * either empty or made of characters the pattern cannot match (line breaks),
 * and single spaces/tabs are themselves matched by the `.` alternative. Every
 * whitespace-only piece is therefore discarded, including runs such as "\n\n"
 * or "\r\n".
 *
 * @param {string} string raw text.
 * @returns {string[]} lowercase tokens in their original order.
 */
function tokenize(string) {
  return string.split(tokenizer_regex)
    .filter((piece) => piece.trim() !== '')
    .map((piece) => piece.toLowerCase());
}
/**
 * Splits a single token into BPE segments, following the encoding procedure of
 * https://github.com/rsennrich/subword-nmt v0.2.
 *
 * The token is broken into characters, bpe_terminal is glued onto the last one,
 * and then the highest-priority applicable rule (lowest index in bpe_rules) is
 * applied repeatedly until no rule matches. The terminal marker is stripped
 * again at the end and bpe_sep is appended to every segment except the last.
 *
 * @param {string} token one lowercase token.
 * @returns {string[]} its BPE segments; an empty token yields an empty array.
 */
function bpeize_token(token) {
  // Nothing to segment; this also avoids writing to segments[-1] below.
  if (token.length === 0) return [];

  const segments = token.split('');
  segments[segments.length - 1] += bpe_terminal;

  while (segments.length > 1) {
    // Find the applicable rule with the lowest index.
    let best_rule_index = Infinity;
    for (let i = 0; i < segments.length - 1; i++) {
      const cand = `${segments[i]} ${segments[i + 1]}`;
      if ((cand in bpe_rule_priority) && bpe_rule_priority[cand] < best_rule_index) {
        best_rule_index = bpe_rule_priority[cand];
      }
    }
    if (best_rule_index === Infinity) break;

    // Apply that rule everywhere, scanning left to right so that overlapping
    // candidates (e.g. "a a a" with rule "a a") are resolved the way the
    // reference implementation resolves them.
    const [chosen_left, chosen_right] = bpe_rules[best_rule_index];
    for (let i = 0; i < segments.length - 1; i++) {
      if (segments[i] === chosen_left && segments[i + 1] === chosen_right) {
        segments.splice(i, 2, chosen_left + chosen_right);
      }
    }
  }

  // Don't print </w> symbols.
  const last = segments.length - 1;
  if (segments[last] === bpe_terminal) {
    segments.pop();
  } else if (segments[last].endsWith(bpe_terminal)) {
    segments[last] = segments[last].slice(0, segments[last].length - bpe_terminal.length);
  }

  // Append the bpe separator to all segments except the last.
  for (let i = 0; i < segments.length - 1; i++) {
    segments[i] += bpe_sep;
  }
  return segments;
}
/**
 * Builds the all-zero state used before any token has been fed to the model.
 *
 * @returns {Array} three tensors in the order callers destructure them:
 *     zeros of shape [1, voc_size], then two zeros of shape [1, lstm_units].
 */
function get_initial_state() {
  console.assert(is_initialized, "Model is not initialized!");
  const blank_logits = tf.zeros([1, voc_size]);
  const blank_hidden = tf.zeros([1, lstm_units]);
  const blank_cell = tf.zeros([1, lstm_units]);
  return [blank_logits, blank_hidden, blank_cell];
}
/**
 * Builds the lookahead tree shown in the chart, starting from prev_state.
 *
 * For each of search_length steps, the (up to) beam_size highest-scoring tokens
 * are added as children of the current parent vertex. Only the arg-max token is
 * then fed back into the model, so the tree is a greedy path with the runner-up
 * candidates attached as leaves at every step.
 *
 * The first child pushed at each step is treated as the greedy continuation.
 * This relies on findIndicesOfMax listing the best index first.
 *
 * @param {Array} prev_state [logits, hidden, cell] tensors to start from; these
 *     are not disposed.
 * @returns {Array<Object>} vertices with x, y, token_text, circle_text, edge
 *     (two points: parent position and own position) and parent_index (-1 for
 *     the root).
 */
function generate_beam_search_vertices(prev_state) {
  // tf.tidy releases every tensor created during the lookahead (arg-max results
  // and intermediate model outputs); the returned plot data contains no tensors.
  return tf.tidy(() => {
    let [prev_logits, prev_hid, prev_cell] = prev_state;
    const plot_data = [
      { x: 0.0, y: 0.0, token_text: "<...>", circle_text: "#1",
        edge: [{x: -0.5, y: 0.0}, {x: 0, y: 0}], parent_index: -1},
    ];
    let parent_i = 0;
    for (let t = 0; t < search_length; t++) {
      const token_ix = prev_logits.argMax(-1);
      const best_token_indices = findIndicesOfMax(prev_logits.arraySync()[0], beam_size);
      const parent = plot_data[parent_i];
      // Remember where this step's children start: fewer than beam_size
      // candidates may exist, so the offset must not be derived from beam_size.
      const first_child_i = plot_data.length;
      best_token_indices.forEach((vocab_ix, rank) => {
        plot_data.push({
          x: t + 1, y: 1.5 * rank, token_text: vocab[vocab_ix], circle_text: `#${rank + 1}`,
          edge: [{x: parent.x, y: parent.y},
                 {x: t + 1, y: 1.5 * rank}], // edge from parent to this vertex
          parent_index: parent_i,
        });
      });
      [prev_logits, prev_hid, prev_cell] = model.predict([token_ix.reshape([1, -1]), prev_hid, prev_cell]);
      parent_i = first_child_i;
    }
    return plot_data;
  });
}
// Returns a promise that resolves (with no value) after `ms` milliseconds, so callers can `await` a pause.
function sleep(ms) {
  return new Promise(function (wake_up) {
    setTimeout(wake_up, ms);
  });
}
/**
 * Returns the indices of the `count` largest values of `inp`, best first.
 *
 * Indices are streamed through a small buffer that is trimmed back to `count`
 * entries whenever it overflows. A final sort guarantees descending order even
 * when `inp` has no more than `count` elements and the buffer never overflowed.
 * Equal values keep their original index order.
 *
 * @param {ArrayLike<number>} inp values to rank.
 * @param {number} count maximum number of indices to return.
 * @returns {number[]} up to `count` indices, ordered by descending value.
 */
function findIndicesOfMax(inp, count) {
  const by_value_desc = (a, b) => inp[b] - inp[a];
  const best = [];
  for (let i = 0; i < inp.length; i++) {
    best.push(i);
    if (best.length > count) {
      best.sort(by_value_desc);
      best.pop(); // drop the index of the smallest value in the buffer
    }
  }
  return best.sort(by_value_desc);
}
</script>
<script> // this script ONLY draws d3js charts and nothing else
// Handles to the chart: the <svg> as a d3 selection and its container <div>,
// whose clientWidth/clientHeight define the pixel range of the scales.
var svg = d3.select("#d3_beamsearch_svg");
var d3_div = document.getElementById('d3_container');
// x_scale, y_scale and circle_containers are (re)assigned on every call to
// draw_beam_search_vertices. NOTE(review): beam_search_ready is not read anywhere
// in the code visible here - confirm whether it is still needed.
var x_scale, y_scale, circle_containers, beam_search_ready = false;
// Shape of one vertex record: position (x, y), the two labels, the edge as a pair
// of points, and the index of the parent vertex (-1 for the root).
var example_plot_data = [ { x: 0.0, y: 0.0, token_text: "<...>", circle_text: "#1",
edge: [{x: -0.5, y: 0.5}, {x:0, y:0}], parent_index: -1},
]; // this is an example, the actual data will be filled in by generate_beam_search_vertices
// Attribute maps applied in one call via selection.attr(object).
var circle_attrs = {r: 12, fill: "#87CEEB", stroke: "navy", "stroke-width": 1.5};
var line_attrs = {
fill: "none", stroke: "black", "stroke-width": 1.0,
// `d` is computed per vertex: a line generator maps the vertex's `edge` points
// through the current x_scale / y_scale, which are looked up when it is called.
d: (function(plot_data_elem)
{return d3.svg.line()
.x(function(d_i) { return x_scale(d_i.x); })
.y(function(d_i) { return y_scale(d_i.y); })
(plot_data_elem.edge)})
};
// TODO features: change line width if is_best, change in some other way if selected
// Redraws the whole chart from scratch for the given vertices.
//   plot_data         - array of vertex records (x, y, token_text, circle_text, edge, ...)
//   callback_on_click - called with the clicked vertex's index in plot_data
function draw_beam_search_vertices(plot_data, callback_on_click) {
// start from an empty svg so the selections below only ever create new elements
svg.selectAll("*").remove();
// data coordinates -> pixels; the domains add a margin around the
// search_length columns and beam_size rows (rows are spaced 1.5 apart)
x_scale = d3.scale.linear().domain([-0.25, search_length + 0.35]).range([0, d3_div.clientWidth]);
y_scale = d3.scale.linear().domain([-1.0, 1.5 * (beam_size - 1) + 0.5]).range([0, d3_div.clientHeight]);
// one <path> per vertex, drawn from its `edge` points (see line_attrs.d)
svg.selectAll(".line").data(plot_data).enter().append("path").attr(line_attrs);
// one <g> per vertex, translated to the vertex position and tagged with its index
circle_containers = svg.selectAll("g circleWithText").data(plot_data).enter().append("g")
.attr("transform", function(d){return `translate(${x_scale(d.x)}, ${y_scale(d.y)})`})
.attr("data-entry-index", function(d, i){return i})
circle_containers.append("circle").attr(circle_attrs);
// token label, shifted 16px above the circle centre
circle_containers.append("text").text(function(d){return d.token_text;})
.attr('dy', -16).style("text-anchor", "middle").style("font-size", "16px")
// rank label inside the circle
circle_containers.append("text").text(function(d){return d.circle_text;})
.attr('dy', "0.3em").style("text-anchor", "middle")
.style("font-size", "14px").style("font-weight", "bold")
// hover toggles the group's font weight; click reports the vertex index
circle_containers
.on("mouseover", function(d, i) {d3.select(this).style("font-weight", "bold");})
.on("mouseout", function(d, i) {d3.select(this).style("font-weight", "normal");})
.on("click", function(d, i) {callback_on_click(parseInt(this.dataset.entryIndex))});
// note: this.dataset.entryIndex reads the element's "data-entry-index" attribute that we set above
}
</script>
<script> // this script defines the visual interface components
// Prompt screen: the prefix textarea, the "Start!" control and the status line below it.
var prompt_screen = document.getElementById('lm_prompt_screen');
var prompt_input = prompt_screen.getElementsByClassName('input_prompt')[0];
var prompt_next_button = prompt_screen.getElementsByClassName('next')[0];
var prompt_status = prompt_screen.getElementsByClassName('status_comment')[0];
// Generation screen and the disabled textarea that shows the text produced so far.
var generate_screen = document.getElementById('lm_generate_screen');
var generated_text = generate_screen.getElementsByClassName("input_preprocessed")[0];
// generator state: tokens and RNN states for all tokens that are in the textarea
// current_tokens[0] is the bos token (it is skipped when the text is displayed);
// current_lstm_states[i] is the [logits, hidden, cell] triple obtained after feeding token i.
var current_tokens = [], current_lstm_states = []; // should be of same length
var beam_search_vertices; // current beam search tree from which user can choose
// Load the model once the DOM is ready, then enable the "next" button.
// If loading fails the button stays disabled and the status line says so,
// instead of showing "Waiting for the model to load..." indefinitely.
document.addEventListener('DOMContentLoaded', async function() {
  try {
    await async_initialize_model();
  } catch (err) {
    console.error('Model initialization failed:', err);
    prompt_status.textContent = 'Failed to load the model. See the browser console for details.';
    return;
  }
  prompt_next_button.style.opacity = 1.0;
  prompt_next_button.onclick = on_prompt_next;
  prompt_status.innerHTML = `Just kidding, you can leave it blank. The machine doesn't need you, meatbag.`;
});
/**
 * Click handler for the "Start!" control: starts processing the typed prefix,
 * wires the three buttons of the generation screen and switches screens.
 */
async function on_prompt_next() {
  prompt_status.innerHTML = "Processing...";
  // The returned promise is not awaited, so the screens are swapped while the
  // prefix is still being processed.
  bootstrap_from_prefix(prompt_input.value);

  const button_handlers = [
    ["button_add_greedy", add_token_greedy],
    ["button_backspace", remove_last_token],
    ["button_restart", reset_to_prompt_screen],
  ];
  for (const [button_class, handler] of button_handlers) {
    generate_screen.getElementsByClassName(button_class)[0].onclick = handler;
  }

  prompt_screen.style.display = 'none';
  generate_screen.style.display = 'block';
}
/**
 * Click handler for "Restart": discards the generator state and shows the
 * prompt screen again.
 *
 * Side effects: replaces current_tokens and current_lstm_states with fresh,
 * separate empty arrays, clears beam_search_vertices, updates the status line
 * and toggles which screen is displayed.
 */
async function reset_to_prompt_screen() {
  prompt_status.innerHTML = "Another prompt, perhaps?";
  // Two distinct arrays: a chained assignment would make both names refer to
  // one array, so a push to either would show up in both.
  current_tokens = [];
  current_lstm_states = [];
  beam_search_vertices = undefined;
  prompt_screen.style.display = 'block';
  generate_screen.style.display = 'none';
}
/**
 * Initializes the generator from the user's prefix.
 *
 * The prefix is tokenized and BPE-split, prefixed with the bos token, and fed to
 * the model one token at a time starting from the zero state. current_tokens and
 * current_lstm_states are rebuilt from scratch, with one state triple per token.
 * The chart is redrawn at the end.
 *
 * @param {string} prefix raw text typed by the user (may be empty).
 * @returns {Promise<void>}
 */
async function bootstrap_from_prefix(prefix) {
  current_tokens = [bos, ...preprocess(prefix)];
  // Start from a clean slate so states left over from an earlier run cannot
  // get out of step with current_tokens.
  current_lstm_states = [];
  generated_text.value = current_tokens.slice(1).join(" ");
  // Wait briefly before reading scrollHeight so it is measured after the new
  // text has been laid out.
  await sleep(50);
  generated_text.style.height = `${Math.max(Math.round(generated_text.scrollHeight), 64)}px`;

  // Own-property check: a bare `in` test would also match inherited keys such
  // as "constructor" or "toString" and yield a non-numeric "index".
  const is_known = (tok) => Object.prototype.hasOwnProperty.call(token_to_ix, tok);
  const token_indices = current_tokens.map((tok) => token_to_ix[is_known(tok) ? tok : unk]);

  let [prev_logits, prev_hid, prev_cell] = get_initial_state();
  for (const index of token_indices) {
    const token_ix = tf.tensor([index], null, "int32").reshape([1, -1]);
    [prev_logits, prev_hid, prev_cell] = model.predict([token_ix, prev_hid, prev_cell]);
    current_lstm_states.push([prev_logits, prev_hid, prev_cell]);
  }
  await redraw_user_interface();
}
/**
 * Brings the generation screen in line with the generator state: shows the
 * current tokens (without the leading bos token), scrolls the textarea to the
 * bottom, recomputes the lookahead tree from the most recent model state and
 * redraws the chart. Clicking a vertex calls add_tokens_from_chosen_vertex.
 *
 * Side effects: overwrites beam_search_vertices.
 *
 * @returns {Promise<void>}
 */
async function redraw_user_interface() {
  generated_text.value = current_tokens.slice(1).join(" ");
  generated_text.scrollTop = generated_text.scrollHeight;
  const latest_state = current_lstm_states[current_lstm_states.length - 1];
  const [prev_logits, prev_hid, prev_cell] = latest_state;
  beam_search_vertices = await generate_beam_search_vertices([prev_logits, prev_hid, prev_cell]);
  // The click handler is passed positionally: JavaScript has no keyword
  // arguments, and `name=value` inside a call assigns to a global instead.
  draw_beam_search_vertices(beam_search_vertices, add_tokens_from_chosen_vertex);
}
/**
 * Click handler for "Add greedy": appends the arg-max token of the latest
 * logits, stores the model state obtained by feeding that token, and redraws.
 */
async function add_token_greedy() {
  const latest_state = current_lstm_states[current_lstm_states.length - 1];
  const [logits, hidden, cell] = latest_state;
  const greedy_ix = logits.argMax(-1);
  const model_inputs = [greedy_ix.reshape([1, -1]), hidden, cell];
  current_lstm_states.push(model.predict(model_inputs));
  // The token text is looked up only after the index has been read back from the tensor.
  current_tokens.push(vocab[(await greedy_ix.array())[0]]);
  await redraw_user_interface();
}
/**
 * Click handler for "Backspace": drops the most recent token together with its
 * model state and redraws. Does nothing once only the first token is left.
 */
async function remove_last_token() {
  const has_removable_token = current_tokens.length > 1;
  if (has_removable_token) {
    current_tokens.pop();
    current_lstm_states.pop();
    await redraw_user_interface();
  }
}
/**
 * Click handler for a chart vertex: appends every token on the path from the
 * root of the lookahead tree down to the chosen vertex, feeding each one through
 * the model so that current_lstm_states stays aligned with current_tokens, and
 * then redraws.
 *
 * @param {number} chosen_index index of the clicked vertex in beam_search_vertices.
 *     An index with no matching vertex is ignored with a console warning.
 * @returns {Promise<void>}
 */
async function add_tokens_from_chosen_vertex(chosen_index) {
  let current_vertex = beam_search_vertices[chosen_index];
  if (current_vertex === undefined) {
    console.warn(`No beam search vertex at index ${chosen_index}`);
    return;
  }

  // Walk up to the root (parent_index -1), collecting tokens leaf-first.
  const additional_tokens = [];
  while (current_vertex.parent_index !== -1) {
    additional_tokens.push(current_vertex.token_text);
    current_vertex = beam_search_vertices[current_vertex.parent_index];
  }
  additional_tokens.reverse();

  // Own-property check: a bare `in` test would also match inherited keys such
  // as "constructor" and yield a non-numeric "index".
  const is_known = (tok) => Object.prototype.hasOwnProperty.call(token_to_ix, tok);

  let [prev_logits, prev_hid, prev_cell] = current_lstm_states[current_lstm_states.length - 1];
  for (const token of additional_tokens) {
    current_tokens.push(token);
    const index = token_to_ix[is_known(token) ? token : unk];
    const token_ix = tf.tensor([index], null, "int32").reshape([1, -1]);
    [prev_logits, prev_hid, prev_cell] = model.predict([token_ix, prev_hid, prev_cell]);
    current_lstm_states.push([prev_logits, prev_hid, prev_cell]);
  }
  await redraw_user_interface();
}
</script>
</body>
</html>