trainValueAgent.py
# mainExperiment at checkpoint 1
import sys
import os
# add path that contains the dominoes package
mainPath = os.path.dirname(os.path.abspath(__file__)) + "/.."
sys.path.append(mainPath)
# standard imports
import argparse
from pathlib import Path
from tqdm import tqdm
import numpy as np
from scipy.signal import savgol_filter
import torch.cuda as torchCuda
import matplotlib.pyplot as plt
# dominoes package
from dominoes import gameplay as dg
from dominoes import agents as da
# registries mapping command-line names to agent classes (add new agents here to expose them on the CLI)
valueAgents = {
    'basicValueAgent': da.basicValueAgent,
    'lineValueAgent': da.lineValueAgent,
    'lineValueAgentSmall': da.lineValueAgentSmall,
}
opponents = {
    'dominoeAgent': da.dominoeAgent,
    'greedyAgent': da.greedyAgent,
    'stupidAgent': da.stupidAgent,
    'doubleAgent': da.doubleAgent,
    'persistentLineAgent': da.persistentLineAgent,
}
# can edit this for each machine it's being used on
savePath = Path('.') / 'experiments' / 'savedNetworks'
resPath = Path('.') / 'experiments' / 'savedResults'
prmsPath = Path('.') / 'experiments' / 'savedParameters'
figsPath = Path(mainPath) / 'docs' / 'media'
for path in (resPath, prmsPath, figsPath, savePath):
    path.mkdir(parents=True, exist_ok=True)
device = "cuda" if torchCuda.is_available() else "cpu"
print(f"Using device: {device}")
def parseArgs():
    parser = argparse.ArgumentParser(description='Run dominoes experiment.')
    parser.add_argument('-n', '--num-players', type=int, default=4, help='the number of agents in the game of dominoes')
    parser.add_argument('-hd', '--highest-dominoe', type=int, default=9, help='the highest dominoe in the set')
    # note: argparse's type=bool treats any non-empty string as True, so parse the flag explicitly
    parser.add_argument('-s', '--shuffle-agents', type=lambda x: x.lower() in ('true', '1', 'yes'), default=True, help='whether to shuffle the order of the agents each hand')
    parser.add_argument('-tg', '--train-games', type=int, default=3000, help='the number of training games')
    parser.add_argument('-tr', '--train-rounds', type=int, default=None, help='the number of rounds per training game (defaults to highest dominoe + 1)')
    parser.add_argument('-op', '--opponent', type=str, default='dominoeAgent', help='which opponent to play the value agent against for training and testing')
    parser.add_argument('-va', '--value-agent', type=str, default='basicValueAgent', help='which value agent to use')
    parser.add_argument('--noreplay', default=False, action='store_true', help='if used, will turn off replay in the value agents')
    parser.add_argument('--justplot', default=False, action='store_true', help='if used, will only plot saved results (results must already have been run and saved)')
    parser.add_argument('--nosave', default=False, action='store_true', help='if used, will skip saving parameters, results, and figures')
    args = parser.parse_args()
    assert args.value_agent in valueAgents, f"requested value agent ({args.value_agent}) is not in the list of possibilities!"
    assert args.opponent in opponents, f"requested opponent ({args.opponent}) is not in the list of possible opponents!"
    return args
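# Example invocation (hypothetical flag values; any agent/opponent names registered
# above will work), assuming the script is run from the repository root:
#   python experiments/trainValueAgent.py -va lineValueAgent -op greedyAgent -tg 5000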
# method for returning the name of the saved network parameters (different save for each possible opponent)
def getFileName():
    replayString = '' if args.noreplay else 'withReplay_'
    return f"trainValueAgent_{args.value_agent}_{replayString}against_{args.opponent}"
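# e.g. the defaults produce "trainValueAgent_basicValueAgent_withReplay_against_dominoeAgent"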
# method for training agent
def trainValueAgent(numPlayers, highestDominoe, shuffleAgents, trainGames, trainRounds):
    # open a game with the selected value agent in seat 0; the None seats are
    # filled with the chosen opponent type (the tuple assumes the default 4 players)
    agents = (valueAgents[args.value_agent], None, None, None)
    game = dg.dominoeGame(highestDominoe, numPlayers=numPlayers, shuffleAgents=shuffleAgents, agents=agents, defaultAgent=opponents[args.opponent], device=device)
    game.getAgent(0).setLearning(True)
    game.getAgent(0).setReplay(not args.noreplay)
    # run training games
    trainWinnerCount = np.zeros(numPlayers)
    trainHandWinnerCount = np.zeros((trainGames, numPlayers))
    trainScoreTally = np.zeros((trainGames, numPlayers))
    for gameIdx in tqdm(range(trainGames)):
        game.playGame(rounds=trainRounds)
        trainWinnerCount[game.currentWinner] += 1
        trainHandWinnerCount[gameIdx] += np.sum(game.score == 0, axis=0)
        trainScoreTally[gameIdx] += game.currentScore
    results = {
        'trainWinnerCount': trainWinnerCount,
        'trainHandWinnerCount': trainHandWinnerCount,
        'trainScoreTally': trainScoreTally,
    }
    # save agent parameters, experiment parameters, and results if requested
    if not args.nosave:
        description = f"{args.value_agent} trained against {args.opponent}"
        game.getAgent(0).saveAgentParameters(savePath, modelName=getFileName(), description=description)
        np.save(prmsPath / getFileName(), vars(args))
        np.save(resPath / getFileName(), results)
    # return results for plotting
    return results
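# shapes of the returned arrays: trainWinnerCount is (numPlayers,) counting full-game
# wins per agent index; trainHandWinnerCount and trainScoreTally are (trainGames, numPlayers)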
# And a function for plotting results
def plotResults(results):
    # smooth the curves with a Savitzky-Golay filter (odd window length, as savgol_filter expects)
    smooth = lambda x: savgol_filter(x, 21, 1)
    trainRounds = args.train_rounds if args.train_rounds is not None else highestDominoe + 1
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    ax[0].plot(range(args.train_games),
               smooth(results['trainScoreTally'][:, 0] / trainRounds),
               c='b', label=args.value_agent)
    ax[0].plot(range(args.train_games),
               smooth(np.mean(results['trainScoreTally'][:, 1:], axis=1) / trainRounds),
               c='k', label=f"{args.opponent}")
    ax[0].set_ylim(0)
    ax[0].set_xlabel('Training Games')
    ax[0].set_ylabel('Training Score Per Hand')
    ax[0].legend(loc='best')
    ax[1].plot(range(args.train_games),
               smooth(results['trainHandWinnerCount'][:, 0]),
               c='b', label=args.value_agent)
    ax[1].plot(range(args.train_games),
               smooth(np.mean(results['trainHandWinnerCount'][:, 1:], axis=1)),
               c='k', label=f"{args.opponent}")
    ax[1].set_ylim(0)
    ax[1].set_xlabel('Training Games')
    ax[1].set_ylabel('Training Num Won Hands')
    ax[1].legend(loc='best')
    if not args.nosave:
        plt.savefig(str(figsPath / getFileName()))
    plt.show()
# Main script
if __name__ == '__main__':
    args = parseArgs()
    # unpack arguments into module-level variables (getFileName and plotResults read these directly)
    numPlayers = args.num_players
    highestDominoe = args.highest_dominoe
    shuffleAgents = args.shuffle_agents
    trainGames = args.train_games
    trainRounds = args.train_rounds if args.train_rounds is not None else highestDominoe + 1
    # if just plotting, load saved data; otherwise, run training
    if not args.justplot:
        results = trainValueAgent(numPlayers, highestDominoe, shuffleAgents, trainGames, trainRounds)
    else:
        print("Warning: not checking whether the current args match the saved args!")
        results = np.load(resPath / (getFileName() + '.npy'), allow_pickle=True).item()
    # Print results of experiment
    print("Train winner count: ", results['trainWinnerCount'])
    tenPercent = int(np.ceil(trainGames * 0.1))
    avgScore = np.round(np.mean(results['trainScoreTally'][-tenPercent:] / trainRounds, axis=0), 1)
    print(f"Average score per round in last 10% of training: {avgScore}")
    # Plot results of experiment (and save if requested)
    plotResults(results)
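# To re-plot previously saved results without retraining, re-run with the same
# agent and opponent flags plus --justplot (so the save file name matches), e.g.:
#   python experiments/trainValueAgent.py -va basicValueAgent -op dominoeAgent --justplot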