diff --git a/custom_trainers/mk2_trainer.py b/custom_trainers/mk2_trainer.py index 9ef89c5..4867176 100644 --- a/custom_trainers/mk2_trainer.py +++ b/custom_trainers/mk2_trainer.py @@ -13,6 +13,7 @@ import argparse import logging import numpy as np +import gc from model_trainer import ModelTrainer from model_vs_game import ModelVsGame @@ -33,7 +34,7 @@ def parse_cmdline(argv): parser.add_argument('--state', type=str, default=None) parser.add_argument('--num_players', type=int, default='1') parser.add_argument('--num_env', type=int, default=16) - parser.add_argument('--num_timesteps', type=int, default=30_000) + parser.add_argument('--num_timesteps', type=int, default=10_000_000) parser.add_argument('--output_basedir', type=str, default='~/OUTPUT') parser.add_argument('--load_p1_model', type=str, default='') parser.add_argument('--model_1', type=str, default='') @@ -55,7 +56,7 @@ def parse_cmdline(argv): game_states= [ 'LiuKangVsBaraka_VeryHard_01', 'LiuKangVsReptile_VeryHard_02', - #'LiuKangVsJax_VeryHard_03', + 'LiuKangVsJax_VeryHard_03', #'LiuKangVsRayden_VeryHard_04', #'LiuKangVsKitana_VeryHard_05', #'LiuKangVsLiuKang_VeryHard_06', @@ -111,6 +112,8 @@ def main(argv): trainer = ModelTrainer(args, logger) p1_model_path = trainer.train() + gc.collect() + # Test model performance #num_test_matchs = NUM_TEST_MATCHS #new_args = args @@ -133,9 +136,10 @@ def main(argv): won_matchs, total_reward = test_model(new_args, num_test_matchs, logger) percentage = won_matchs / num_test_matchs com_print('STATE:%s... WON MATCHS:%d/%d TOTAL REWARDS:%d' % (state, won_matchs, num_test_matchs, total_reward)) + gc.collect() if args.play: - args.state = 'LiuKangVsRayden_VeryHard_04' + args.state = 'LiuKangVsRayden_VeryHard_01' args.model_1 = p1_model_path args.model_2 = '' args.num_timesteps = 0 diff --git a/model_vs_game.py b/model_vs_game.py index bf3cb1b..4aa23c2 100644 --- a/model_vs_game.py +++ b/model_vs_game.py @@ -47,12 +47,7 @@ class ModelVsGame: def __init__(self, args, logger, need_display=True): self.p1_env = init_env(None, 1, args.state, 1, args, True) - print('================================ init_play_env =================================') - traceback.print_stack(file=sys.stdout) - print(need_display) self.display_env = init_play_env(args, 1, False, need_display, False) - print(self.display_env) - print('=================================================================') self.ai_sys = games.wrappers.ai_sys(args, self.p1_env, logger) if args.model_1 != '' or args.model_2 != '':