-
Notifications
You must be signed in to change notification settings - Fork 4
/
generate_random_boards_to_hdf5
executable file
·171 lines (147 loc) · 5.28 KB
/
generate_random_boards_to_hdf5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# This file is part of the gunyanza.
# Copyright (C) 2015- Tasuku SUENAGA <[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals, print_function
import shogi
import numpy
import sys
import os
import multiprocessing
import itertools
import random
import h5py
import json
from utils import board2arrays
import sys
# Compatibility shim: when sys.stdout exposes no ``buffer`` attribute
# (presumably the Python 2 file-based stdout -- confirm), wrap it in a codec
# writer for the locale's preferred encoding so unicode text can be printed.
if not hasattr(sys.stdout, 'buffer'):
    import codecs
    import locale
    sys.stdout = codecs.getwriter(locale.getpreferredencoding())(sys.stdout)
# Games that produce fewer recorded positions than this are skipped entirely
# by sample_positions().
MOVE_COUNT_THRESHOLD = 35
# TODO: should positions near the endgame be sampled more heavily?
# Number of positions drawn at random (without replacement) from each game.
SAMPLE_POSITIONS_PER_GAME = 30
# Upper bound on the number of random legal moves tried from each sampled
# position's parent when generating the "random alternative" boards.
SAMPLE_MOVES_PER_POSITION = 10
# When True, parse_dir() converts files one by one in the current process
# instead of handing them to a multiprocessing.Pool.
DISABLE_MULTI_PROCESSING = False
def read_games(fn):
    """Yield game records from a file holding one JSON document per line.

    Lines that cannot be parsed as JSON are reported on stdout (with their
    1-based line number) and skipped.  Iteration stops at the first line
    whose decoded value is falsy (e.g. ``{}`` or ``null``).

    Args:
        fn: Path of the input file.

    Yields:
        The decoded JSON value of each successfully parsed line.
    """
    with open(fn) as f:
        # enumerate keeps the reported line number correct even when earlier
        # lines were skipped because they failed to parse.
        for line_no, line in enumerate(f, 1):
            try:
                # No encoding argument: a positional second argument is
                # rejected by json.loads on Python 3.6+, which previously
                # made *every* line fail and be silently skipped.
                match = json.loads(line)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass on Python 3;
                # Python 2 raises plain ValueError.
                print('Cannot load line %d' % line_no)
                continue
            if not match:
                break
            yield match
def sample_positions(g):
    """Yield randomly sampled training tuples from a single game record.

    The game is replayed from ``g['sfen']`` using the USI moves in
    ``g['moves']``.  ``SAMPLE_POSITIONS_PER_GAME`` positions (never the
    initial one) are then drawn at random, and for each of them up to
    ``SAMPLE_MOVES_PER_POSITION`` random legal moves are played from the
    preceding position to produce alternative boards.

    Games with an unrecognised ``g['win']`` value, or with too few positions
    (see ``MOVE_COUNT_THRESHOLD`` / ``SAMPLE_POSITIONS_PER_GAME``), yield
    nothing.

    Args:
        g: Game record (mapping) with the keys ``'win'`` (``'b'``, ``'w'``
            or ``'-'``), ``'sfen'`` and ``'moves'``.

    Yields:
        Tuples ``(x, x_parent, x_random, moves_left, y)`` where ``x`` is the
        encoded sampled position, ``x_parent`` the encoded preceding
        position, ``x_random`` the encoded position reached from the parent
        by a random legal move other than the one actually played,
        ``moves_left`` the number of moves remaining in the game at the
        sampled position, and ``y`` the game result label, negated when the
        sample is flipped.
    """
    result_to_label = {'b': 1, 'w': -1, '-': 0}
    result = g['win']
    if result not in result_to_label:
        return
    y = result_to_label[result]
    # print('result:', y, file=sys.stderr)

    # Generate the board and replay the game, recording every position
    # together with its parent position and the move that led to it.
    board = shogi.Board(g['sfen'])
    # FIXME: check board is end state or not
    boards = []
    moves_left = len(g['moves'])
    parent_board_sfen = None
    last_move = None
    for move_str in g['moves']:
        board_sfen = board.sfen()
        boards.append((moves_left, board_sfen, board.turn == shogi.BLACK,
                       parent_board_sfen, last_move))
        move = shogi.Move.from_usi(move_str)
        board.push(move)
        parent_board_sfen = board_sfen
        last_move = move
        moves_left -= 1

    if len(boards) < MOVE_COUNT_THRESHOLD or len(boards) < SAMPLE_POSITIONS_PER_GAME:
        return

    # Remove the initial state: it has no parent position to compare with.
    del boards[0]
    if len(boards) < SAMPLE_POSITIONS_PER_GAME:
        # Only reachable if the constants are retuned; keeps random.sample
        # from raising when the population is one short.
        return

    for moves_left, board_sfen, flip, parent_board_sfen, last_move in random.sample(boards, SAMPLE_POSITIONS_PER_GAME):
        board = shogi.Board(board_sfen)
        b_parent = shogi.Board(parent_board_sfen)

        x = board2arrays(board, flip=flip)
        x_parent = board2arrays(b_parent, flip=(not flip))

        # Derive the label for *this* sample only.  Negating ``y`` in place
        # would leak the sign change into all later samples, making each
        # label depend on the random order of the earlier ones.
        label = -y if flip else y

        # Choose random moves and obtain the boards they lead to.
        moves = list(b_parent.legal_moves)
        if len(moves) == 1:
            # The only legal move is the one actually played, so there is no
            # alternative to sample (originally a guard against looping
            # forever).
            continue

        for move in random.sample(moves, min(len(moves), SAMPLE_MOVES_PER_POSITION)):
            # The random move must differ from the move actually played.
            if move == last_move:
                continue

            b_parent.push(move)
            x_random = board2arrays(b_parent, flip=flip)
            b_parent.pop()

            yield (x, x_parent, x_random, moves_left, label)
def read_all_games(fn_in, fn_out):
    """Convert one game-record file into an HDF5 file of sampled positions.

    Every tuple produced by ``sample_positions`` for every game returned by
    ``read_games(fn_in)`` is appended to five resizable datasets in
    ``fn_out``: ``x`` (sampled position), ``xr`` (random alternative),
    ``xp`` (parent position), ``y`` (result label) and ``m`` (moves left).

    Args:
        fn_in: Path of the input file, one JSON game record per line.
        fn_out: Path of the HDF5 file to create (opened with mode ``'w'``).
    """
    # The context manager guarantees the HDF5 file is closed even when
    # reading or sampling raises part-way through.
    with h5py.File(fn_out, 'w') as g:
        # Each row: board 9x9 squares + pieces in hand (mochigoma) 7x2.
        row_width = 9 * 9 + 7 * 2
        X, Xr, Xp = [
            g.create_dataset(d, (0, row_width), dtype='b',
                             maxshape=(None, row_width), chunks=True)
            for d in ['x', 'xr', 'xp']
        ]
        # NOTE(review): dtype 'b' looks like a signed 8-bit integer, so a
        # moves_left value above 127 would not fit in 'm' -- confirm whether
        # such long games can occur before relying on this dataset.
        Y, M = [
            g.create_dataset(d, (0,), dtype='b', maxshape=(None,), chunks=True)
            for d in ['y', 'm']
        ]
        datasets = (X, Xr, Xp, Y, M)

        size = 0
        index = 0
        for game in read_games(fn_in):
            for x, x_parent, x_random, moves_left, y in sample_positions(game):
                if index + 1 >= size:
                    # Grow geometrically so the number of resizes stays
                    # logarithmic in the number of rows written.
                    g.flush()
                    size = 2 * size + 1
                    print('resizing to', size)
                    for d in datasets:
                        d.resize(size=size, axis=0)
                X[index] = x
                Xr[index] = x_random
                Xp[index] = x_parent
                Y[index] = y
                M[index] = moves_left
                index += 1

        # Trim the over-allocated tail down to the rows actually written.
        for d in datasets:
            d.resize(size=index, axis=0)
def read_all_games_2(a):
    """Call ``read_all_games`` with its arguments packed in one sequence.

    Single-argument adapter so the conversion can be used with
    ``multiprocessing.Pool.map``, which passes exactly one item per call.

    Args:
        a: Sequence of positional arguments, e.g. ``(fn_in, fn_out)``.

    Returns:
        Whatever ``read_all_games`` returns.
    """
    outcome = read_all_games(*a)
    return outcome
def parse_dir(input_dir, output_dir):
    """Convert every ``*.jsons`` file in ``input_dir`` to HDF5.

    For each input file ``NAME.jsons`` the output ``NAME.hdf5`` is written to
    ``output_dir``.  Inputs whose output file already exists are skipped with
    a message on stderr.  The conversions run in a ``multiprocessing.Pool``
    unless ``DISABLE_MULTI_PROCESSING`` is set.

    Args:
        input_dir: Directory scanned (non-recursively) for ``*.jsons`` files.
        output_dir: Directory that receives the ``*.hdf5`` files.
    """
    in_suffix = '.jsons'
    out_suffix = '.hdf5'

    files = []
    for fn_in in os.listdir(input_dir):
        if not fn_in.endswith(in_suffix):
            continue
        # Swap only the trailing suffix; str.replace would also rewrite any
        # earlier '.jsons' occurring inside the file name.
        fn_out = os.path.join(output_dir, fn_in[:-len(in_suffix)] + out_suffix)
        fn_in = os.path.join(input_dir, fn_in)
        print(fn_in, fn_out)
        if os.path.exists(fn_out):
            print('Output file already exists', file=sys.stderr)
        else:
            files.append((fn_in, fn_out))

    if not files:
        # Nothing to convert; avoid spawning worker processes for no work.
        return

    if DISABLE_MULTI_PROCESSING:
        for fn_in, fn_out in files:
            read_all_games(fn_in, fn_out)
    else:
        pool = multiprocessing.Pool()
        try:
            pool.map(read_all_games_2, files)
        finally:
            # Always release the worker processes, even if a task raised.
            pool.close()
            pool.join()
# Script entry point: expects exactly two arguments, the input directory and
# the output directory; otherwise a usage line is printed.
if __name__ == '__main__':
    if len(sys.argv) == 3:
        parse_dir(sys.argv[1], sys.argv[2])
    else:
        print('Usage: {0} input_dir output_dir'.format(sys.argv[0]))