forked from suragnair/alpha-zero-general
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNeuralNet.py
50 lines (42 loc) · 1.52 KB
/
NeuralNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class NeuralNet:
    """
    Base interface for a neural network used by the self-play framework.

    To define your own neural network, subclass this class and override the
    methods below. The neural network does not consider the current player;
    it only ever deals with the canonical form of the board.

    Every method in this base class is a no-op that returns None, so a
    subclass must override each one it expects to be called.

    See othello/NNet.py for an example implementation.
    """

    def __init__(self, game):
        """
        Create the network for the given game.

        Input:
            game: the game object the network is built for. The base class
                  ignores it; subclasses may use it to size their model.
        """

    def train(self, examples):
        """
        Train the neural network with examples obtained from self-play.

        Input:
            examples: a list of training examples, where each example is of
                      the form (board, pi, v). pi is the MCTS-informed policy
                      vector for the given board, and v is its value. Each
                      example has its board in canonical form.

        The base implementation does nothing and returns None.
        """

    def predict(self, board):
        """
        Evaluate a single board position.

        Input:
            board: current board in its canonical form.

        Returns (in an overriding implementation):
            pi: a policy vector for the current board - a numpy array of
                length game.getActionSize
            v: a float in [-1,1] that gives the value of the current board

        The base implementation does nothing and returns None.
        """

    def save_checkpoint(self, folder, filename):
        """
        Save the current neural network (with its parameters) in
        folder/filename.

        Input:
            folder: directory component of the checkpoint location.
            filename: file name component of the checkpoint location.

        The base implementation does nothing and returns None.
        """

    def load_checkpoint(self, folder, filename):
        """
        Load the parameters of the neural network from folder/filename.

        Input:
            folder: directory component of the checkpoint location.
            filename: file name component of the checkpoint location.

        The base implementation does nothing and returns None.
        """