-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.m
34 lines (29 loc) · 914 Bytes
/
generate.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
% Run one forward pass of a GPT-style decoder over a sequence of input
% token ids and pick the single most likely next token (greedy argmax).
% NOTE(review): this is a script, not a function -- it reads its inputs
% (inputTokens, wte_weight, wpe_weight, weights_layer_NN, NUM_LAYERS,
% ln_f_weight, ln_f_bias, lm_head_weight) from the caller's workspace and
% leaves its result in Script_output; confirm against the calling code.
%
% GPT token ids are counted from 0, but MATLAB indexes from 1,
% so the input ids should be shifted up by 1 before any table lookup.
Script_input = inputTokens + 1;
% Token embedding: look up the embedding row for each input token id.
coded = TokenEmbedding(wte_weight, Script_input);
% Positional embedding: one row per position in the input sequence.
poscode = PostionEmbedding(wpe_weight, length(Script_input));
% The transformer input is the token embedding plus the positional embedding.
final_input = coded + poscode;
% Free the intermediates; only their sum is needed from here on.
clear coded
clear poscode
% Decoder: a stack of NUM_LAYERS transformer blocks, each with its own weights.
temp = final_input;
for i = 1:NUM_LAYERS
% Build the name of this layer's weight variable, e.g. "weights_layer_00"
% (the layer numbering embedded in the variable names is 0-based).
weight_name = sprintf("weights_layer_%02d", i - 1);
% Apply one transformer block; eval resolves the per-layer weight variable
% by its constructed name from the workspace.
temp = Block(eval(weight_name), temp);
% NOTE(review): snapshot of the activations after the 3rd block -- looks
% like a debugging/inspection tap; confirm whether outof1 is read elsewhere
% before removing it.
if (i == 3)
outof1 = temp;
end
end
% Final layer norm, then project onto the vocabulary via the LM head.
temp = LayerNorm(ln_f_weight, ln_f_bias, temp);
final_output = temp * lm_head_weight';
% Only the logits at the last sequence position matter for next-token
% prediction.
logi = final_output(end, :);
%%
% Greedy decoding: softmax the last-position logits and take the argmax.
prob = softmax(logi');
[~, idx] = max(prob);
% Convert back from MATLAB's 1-based index to the 0-based GPT token id.
Script_output = idx - 1;