-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
55 lines (47 loc) · 1.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from pathlib import Path
import torch
import yaml
from data import get_encoder_decoder_fn, load_vocab
from model import TransformerDecoderModel
from util import get_device, load_model
if __name__ == "__main__":
    # Inference script: load the trained Transformer decoder and print a
    # 500-token sample of generated text.

    # --- Load config ---------------------------------------------------
    # Resolve config relative to this file so the script works from any CWD.
    script_dir = Path(__file__).parent
    with open(script_dir / "config.yaml", "r") as f:
        # safe_load: the config is plain scalars/mappings; FullLoader's
        # arbitrary Python-object construction is unnecessary and riskier.
        config = yaml.safe_load(f)

    # Only the architecture/checkpoint settings are needed at inference
    # time; training-only keys (seed, lr, batch_size) are deliberately
    # not read here.
    block_size = config["block_size"]
    n_heads = config["n_heads"]
    head_size = config["head_size"]
    n_layers = config["n_layers"]
    dropout = config["dropout"]
    embed_size = config["embed_size"]
    save_path = config["save_path"]

    # --- Vocab and decoder --------------------------------------------
    vocab = load_vocab()
    _, decode = get_encoder_decoder_fn(vocab)  # encoder unused for sampling

    # --- Build model and load checkpoint ------------------------------
    device = get_device()
    model = TransformerDecoderModel(
        vocab_size=len(vocab),
        block_size=block_size,
        n_layers=n_layers,
        n_heads=n_heads,
        head_size=head_size,
        dropout=dropout,
        embed_size=embed_size,
    ).to(device)
    model = load_model(model, save_path, device)

    # --- Generate text -------------------------------------------------
    num_tokens = 500  # length of the generated sample
    model.eval()  # disable dropout for inference
    with torch.inference_mode():
        # Seed generation with a single token id 0 — presumably the
        # vocab's start/padding token; confirm against data.py.
        x = torch.zeros((1, 1), dtype=torch.long, device=device)
        out = model.generate(x, num_tokens)[0].tolist()
    print("Generated text:")
    print("=" * 20)
    print(decode(out).strip())
    print("=" * 20)