# env.py
# -*- coding: utf-8 -*-
# MIT License
#
# Copyright (c) 2017 Kai Arulkumaran
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# ==============================================================================
from collections import deque
import random
import atari_py
import cv2
import torch
class Env():
    """Atari environment wrapper around the Arcade Learning Environment (ALE).

    Screens are converted to 84x84 greyscale torch tensors. The most recent
    ``args.history_length`` observations are kept in a bounded buffer and
    returned stacked along dim 0. Every agent step repeats the chosen action
    4 times. While in training mode, losing a life is reported as ``done``.
    """

    def __init__(self, args):
        """Configure the emulator, load the ROM and initialise bookkeeping.

        Args:
            args: Options object. The attributes read here are ``device``,
                ``seed``, ``max_episode_length``, ``game`` and
                ``history_length``.
        """
        self.device = args.device
        self.ale = atari_py.ALEInterface()
        self.ale.setInt('random_seed', args.seed)
        self.ale.setInt('max_num_frames_per_episode', args.max_episode_length)
        self.ale.setFloat('repeat_action_probability', 0)  # Disable sticky actions
        self.ale.setInt('frame_skip', 0)
        self.ale.setBool('color_averaging', False)
        # ROM loading must be done after setting options
        self.ale.loadROM(atari_py.get_game_path(args.game))
        # Map contiguous agent action indices (0..n-1) to raw ALE actions
        self.actions = dict(enumerate(self.ale.getMinimalActionSet()))
        self.lives = 0  # Life counter (used in DeepMind training)
        self.life_termination = False  # Used to check if resetting only from loss of life
        self.window = args.history_length  # Number of frames to concatenate
        self.state_buffer = deque([], maxlen=args.history_length)
        self.training = True  # Consistent with model training mode

    def _get_state(self):
        """Return the current screen as an 84x84 float32 tensor on ``self.device``.

        The greyscale screen is resized with bilinear interpolation and divided
        by 255 (presumably mapping 8-bit pixel values into [0, 1]).
        """
        state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR)
        return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255)

    def _reset_buffer(self):
        """Fill the state buffer with ``self.window`` all-zero 84x84 frames."""
        for _ in range(self.window):
            self.state_buffer.append(torch.zeros(84, 84, device=self.device))

    def reset(self):
        """Start a new episode, or resume the game after a loss of life.

        If the previous ``step`` ended only because a life was lost, the game
        is not restarted: a single no-op is issued and the frame history is
        kept. Otherwise the frame history is zeroed, the game is reset and a
        random number (0 to 29) of no-ops is played.

        Returns:
            The buffered observations stacked along dim 0.
        """
        if self.life_termination:
            self.life_termination = False  # Reset flag
            self.ale.act(0)  # Use a no-op after loss of life
        else:
            # Reset internals
            self._reset_buffer()
            self.ale.reset_game()
            # Perform up to 30 random no-ops before starting
            for _ in range(random.randrange(30)):
                self.ale.act(0)  # Assumes raw action 0 is always no-op
                if self.ale.game_over():
                    self.ale.reset_game()
        # Process and return "initial" state
        observation = self._get_state()
        self.state_buffer.append(observation)
        self.lives = self.ale.lives()
        return torch.stack(list(self.state_buffer), 0)

    def step(self, action):
        """Apply ``action`` for 4 emulator steps and return the new state.

        Args:
            action: Index into the minimal action set, i.e. a value in
                ``range(self.action_space())``.

        Returns:
            A tuple ``(state, reward, done)``: the buffered observations
            stacked along dim 0, the reward summed over the repeated steps,
            and the terminal flag.

        Raises:
            ValueError: If ``action`` is not a known action index.
        """
        # Validate up front so an unknown index fails loudly instead of
        # passing None to the emulator.
        try:
            raw_action = self.actions[action]
        except KeyError:
            raise ValueError(
                'Invalid action index {!r}; expected an integer in [0, {})'.format(
                    action, len(self.actions))) from None
        # Repeat action 4 times, max pool over last 2 frames
        frame_buffer = torch.zeros(2, 84, 84, device=self.device)
        reward, done = 0, False
        for t in range(4):
            reward += self.ale.act(raw_action)
            if t == 2:
                frame_buffer[0] = self._get_state()
            elif t == 3:
                frame_buffer[1] = self._get_state()
            done = self.ale.game_over()
            if done:
                # Frames not captured yet stay zero, so a game that ends
                # before the third repeat yields an all-zero observation.
                break
        observation = frame_buffer.max(0)[0]  # Element-wise max over the two frames
        self.state_buffer.append(observation)
        # Detect loss of life as terminal in training mode
        if self.training:
            lives = self.ale.lives()
            if lives < self.lives and lives > 0:  # Lives > 0 for Q*bert
                self.life_termination = not done  # Only set flag when not truly done
                done = True
            self.lives = lives
        # Return state, reward, done
        return torch.stack(list(self.state_buffer), 0), reward, done

    def train(self):
        """Use loss of life as a terminal signal."""
        self.training = True

    def eval(self):
        """Use only the standard (game over) terminal signal."""
        self.training = False

    def action_space(self):
        """Return the number of available actions."""
        return len(self.actions)

    def render(self):
        """Show the current RGB screen in an OpenCV window named 'screen'."""
        # Reverse the channel axis: OpenCV displays images in BGR order
        cv2.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1])
        cv2.waitKey(1)

    def close(self):
        """Close all OpenCV windows."""
        cv2.destroyAllWindows()