-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform.py
65 lines (47 loc) · 2.35 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import chess
import torch
def board_to_input(board: chess.Board) -> torch.Tensor:
    """
    Convert the given board position into the planes required as input
    for the neural network.

    Returns a float32 tensor of shape ``(P + 7, 8, 8)``, where
    ``P = len(chess.COLORS) * len(chess.PIECE_TYPES)`` (12 with python-chess).
    The planes are, in order:

    * ``P`` piece planes, one per (color, piece type), iterating
      ``chess.COLORS`` in the outer loop and ``chess.PIECE_TYPES`` in the
      inner loop. A square occupied by that piece is 1.0, all others 0.0.
      Square ``s`` is written at row ``7 - s // 8``, column ``s % 8``.
    * side to move: all ones when ``board.turn`` is truthy (python-chess uses
      ``True`` for White), otherwise all zeros.
    * move count: every cell holds the raw ``board.fullmove_number``
      (not normalized).
    * four castling-rights planes (all ones / all zeros), in this order:
      White kingside, White queenside, Black kingside, Black queenside.
    * draw claim: all ones when ``board.can_claim_draw()`` is true.
    """

    def constant_plane(value: float) -> torch.Tensor:
        # Pass a float fill value so the plane is float32. An int fill value
        # makes torch.full infer int64, which would leave torch.cat
        # concatenating mixed dtypes.
        return torch.full((1, 8, 8), float(value))

    def flag_plane(flag: bool) -> torch.Tensor:
        return constant_plane(1.0 if flag else 0.0)

    pieces = []
    for color in chess.COLORS:
        for piece_type in chess.PIECE_TYPES:
            plane = torch.zeros(1, 8, 8)
            # board.pieces() returns an iterable SquareSet; no list needed.
            for square in board.pieces(piece_type, color):
                # Flip the rank so that higher square numbers land in the
                # top rows of the plane.
                plane[0, 7 - square // 8, square % 8] = 1.0
            pieces.append(plane)

    state_planes = [
        flag_plane(board.turn),
        constant_plane(board.fullmove_number),
        flag_plane(board.has_kingside_castling_rights(chess.WHITE)),
        flag_plane(board.has_queenside_castling_rights(chess.WHITE)),
        flag_plane(board.has_kingside_castling_rights(chess.BLACK)),
        flag_plane(board.has_queenside_castling_rights(chess.BLACK)),
        flag_plane(board.can_claim_draw()),
    ]
    return torch.cat(pieces + state_planes, dim=0)
def move_to_output(move: chess.Move) -> int:
    """
    Map a move to an index in the network's 73-way move-type output.

    Only the move's geometry relative to its origin square is encoded, not
    the origin square itself. "Rank" below means ``square // 8`` and "file"
    means ``square % 8``. Returned indices (0-based):

    * 0-55: sliding/step moves (also used for plain pawn moves and queen
      promotions), in eight groups of seven. Within a group the offset is
      distance - 1:

      - 0-6: same file, to_square > from_square
      - 7-13: same file, to_square < from_square
      - 14-20: same rank, toward higher file
      - 21-27: same rank, toward lower file
      - 28-34: toward lower file, to_square > from_square
      - 35-41: toward lower file, to_square < from_square
      - 42-48: toward higher file, to_square > from_square
      - 49-55: toward higher file, to_square < from_square

    * 56-63: knight moves, ordered by the square offsets in ``knight_moves``.
    * 64-72: under-promotions. There are three groups of three, for knight,
      bishop and rook. Within a group the index is
      ``abs(to_square - from_square) - 7`` (0, 1 or 2).

    A null move (``from_square == to_square``) yields -1, as before.
    """
    # to_square - from_square for each knight jump. The list position is the
    # offset within the knight block of the output.
    knight_moves = (-15, -17, -6, 10, 15, 17, 6, -10)

    diff = move.to_square - move.from_square
    from_rank, from_file = divmod(move.from_square, 8)
    to_rank, to_file = divmod(move.to_square, 8)
    rank_delta = to_rank - from_rank
    file_delta = to_file - from_file

    if move.promotion is not None and move.promotion != chess.QUEEN:
        # Under-promotion: one block of three per piece (knight, bishop, rook).
        # abs(diff) is 7, 8 or 9 for the three possible pawn directions.
        return 64 + 3 * (move.promotion - chess.KNIGHT) + (abs(diff) - 7)

    # Detect knight jumps from the rank/file deltas, not from the raw square
    # offset. A six-square slide along a rank also has diff == +/-6, which
    # would otherwise collide with two of the knight offsets.
    if {abs(rank_delta), abs(file_delta)} == {1, 2}:
        return 56 + knight_moves.index(diff)

    if file_delta == 0:
        base = 0   # along a file
    elif rank_delta == 0:
        base = 14  # along a rank
    elif file_delta < 0:
        base = 28  # diagonal toward lower files
    else:
        base = 42  # diagonal toward higher files
    if diff < 0:
        # Second half of each group: moves toward lower square numbers.
        base += 7

    # Rank-wise moves have rank_delta == 0, so fall back to the file distance.
    # For diagonals the two deltas are equal in size.
    distance = abs(rank_delta) or abs(file_delta)
    return base + distance - 1