# train.py (forked from fairy-stockfish/variant-nnue-pytorch)
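"""Trains a (variant) NNUE network with PyTorch Lightning.

Reads training and validation data (.bin or .binpack) through the C++
data loader in nnue_dataset and fits the model defined in model.NNUE.
"""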
import argparse
import os

import torch
from torch import set_num_threads as t_set_num_threads
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

import features
import model as M
import nnue_dataset


def make_data_loaders(train_filename, val_filename, feature_set, num_workers, batch_size,
                      filtered, random_fen_skipping, main_device, epoch_size, val_size):
    features_name = feature_set.name
    train_infinite = nnue_dataset.SparseBatchDataset(features_name, train_filename, batch_size,
                                                     num_workers=num_workers, filtered=filtered,
                                                     random_fen_skipping=random_fen_skipping,
                                                     device=main_device)
    val_infinite = nnue_dataset.SparseBatchDataset(features_name, val_filename, batch_size,
                                                   filtered=filtered,
                                                   random_fen_skipping=random_fen_skipping,
                                                   device=main_device)
    # num_workers has to be 0 for sparse and 1 for dense; the loader currently
    # cannot run in parallel mode, but it should not need to.
    # The ceil division gives the number of batches needed to cover one epoch.
    train = DataLoader(nnue_dataset.FixedNumBatchesDataset(train_infinite, (epoch_size + batch_size - 1) // batch_size),
                       batch_size=None, batch_sampler=None)
    val = DataLoader(nnue_dataset.FixedNumBatchesDataset(val_infinite, (val_size + batch_size - 1) // batch_size),
                     batch_size=None, batch_sampler=None)
    return train, val
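
# Example: with the default --epoch-size of 20,000,000 positions and a batch size
# of 16384, the ceil division above yields (20000000 + 16383) // 16384 = 1221
# batches per training epoch.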


def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack).")
    parser.add_argument("val", help="Validation data (.bin or .binpack).")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--lambda", default=1.0, type=float, dest='lambda_',
                        help="lambda=1.0 trains on evaluations, lambda=0.0 trains on game results; values in between interpolate (default=1.0).")
    parser.add_argument("--num-workers", default=1, type=int, dest='num_workers',
                        help="Number of worker threads to use for data loading. Currently only works well for binpack.")
    parser.add_argument("--batch-size", default=-1, type=int, dest='batch_size',
                        help="Number of positions per batch / per iteration (default=16384).")
    parser.add_argument("--threads", default=-1, type=int, dest='threads',
                        help="Number of torch threads to use (default: automatic, one per core).")
    parser.add_argument("--seed", default=42, type=int, dest='seed',
                        help="Torch seed to use.")
    parser.add_argument("--smart-fen-skipping", action='store_true', dest='smart_fen_skipping_deprecated',
                        help="Deprecated and ignored; kept for backwards compatibility. Smart fen skipping is enabled by default.")
    parser.add_argument("--no-smart-fen-skipping", action='store_true', dest='no_smart_fen_skipping',
                        help="Disable smart fen skipping, i.e. do not skip positions that are bad training targets during loading.")
    parser.add_argument("--random-fen-skipping", default=3, type=int, dest='random_fen_skipping',
                        help="Skip fens randomly during loading; on average this many fens are skipped before one is used (default=3).")
    parser.add_argument("--resume-from-model", dest='resume_from_model',
                        help="Initializes training using the weights from the given .pt model.")
    parser.add_argument("--epoch-size", type=int, default=20000000, dest='epoch_size',
                        help="Number of positions per epoch.")
    parser.add_argument("--validation-size", type=int, default=1000000, dest='validation_size',
                        help="Number of positions per validation step.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    if not os.path.exists(args.train):
        raise Exception('{0} does not exist'.format(args.train))
    if not os.path.exists(args.val):
        raise Exception('{0} does not exist'.format(args.val))
    feature_set = features.get_feature_set_from_name(args.features)

    if args.resume_from_model is None:
        nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_)
    else:
        # Resume from an existing .pt model, transplanting it onto the
        # requested feature set and lambda.
        nnue = torch.load(args.resume_from_model)
        nnue.set_feature_set(feature_set)
        nnue.lambda_ = args.lambda_
    nnue.cuda()

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))
    print("Training with {}, validating with {}".format(args.train, args.val))

    pl.seed_everything(args.seed)
    print("Seed {}".format(args.seed))
    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 16384
    print('Using batch size {}'.format(batch_size))
    print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('Limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    logdir = args.default_root_dir if args.default_root_dir else 'logs/'
    print('Using log dir {}'.format(logdir), flush=True)

    tb_logger = pl_loggers.TensorBoardLogger(logdir)
    # save_last=True keeps a rolling 'last' checkpoint; period=1 together with
    # save_top_k=-1 saves and keeps a checkpoint after every epoch.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_last=True, period=1, save_top_k=-1)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=tb_logger)
    main_device = trainer.root_device if trainer.root_gpu is None else 'cuda:' + str(trainer.root_gpu)

    print('Using c++ data loader')
    train, val = make_data_loaders(args.train, args.val, feature_set, args.num_workers, batch_size,
                                   not args.no_smart_fen_skipping, args.random_fen_skipping,
                                   main_device, args.epoch_size, args.validation_size)
    trainer.fit(nnue, train, val)


if __name__ == '__main__':
    main()
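
# Example invocation (illustrative only: the .binpack paths and the --features
# value are placeholders, and --gpus/--max_epochs are among the Trainer flags
# added by pl.Trainer.add_argparse_args above):
#
#   python train.py train.binpack val.binpack --features=HalfKP \
#       --gpus 1 --max_epochs 400 --batch-size 16384 --num-workers 1 \
#       --random-fen-skipping 3 --lambda 1.0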