-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7cdfe5d
commit 9add70f
Showing
7 changed files
with
373 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import torch | ||
import random | ||
import numpy as np | ||
from collections import deque | ||
from game import SnakeGameAI, Direction, Point | ||
from model import Linear_QNet, QTrainer | ||
from helper import plot | ||
import joblib | ||
import os | ||
from pathlib import Path | ||
|
||
# Capacity of the agent's replay deque (passed as ``maxlen``).
MAX_MEMORY = 100_000
# Largest number of stored transitions sampled for one long-memory training pass.
BATCH_SIZE = 1000
# Learning rate handed to QTrainer.
LR = 0.001
|
||
class Agent:
    """Epsilon-greedy Q-learning agent that plays ``SnakeGameAI``.

    The agent owns the Q-network, its trainer and a bounded replay memory of
    ``(state, action, reward, next_state, done)`` transitions.
    """

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration budget; recomputed on every get_action call
        self.gamma = 0.9  # discount factor forwarded to QTrainer
        # Bounded buffer: once full, appending drops the oldest transition.
        self.memory = deque(maxlen=MAX_MEMORY)
        # 11 state features in, 3 possible moves out.
        self.model = Linear_QNet(11, 512, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        """Encode the current game situation as 11 binary features.

        Layout of the returned vector:
        ``[danger straight, danger right, danger left,
        moving left, moving right, moving up, moving down,
        food left, food right, food up, food down]``.

        Args:
            game: object exposing ``snake``, ``head``, ``food``, ``direction``
                and ``is_collision(point)``.

        Returns:
            ``numpy.ndarray`` of 11 integers (0 or 1).
        """
        # Size of one grid cell in pixels; has to match the game's block size.
        step = 20
        head = game.snake[0]
        point_l = Point(head.x - step, head.y)
        point_r = Point(head.x + step, head.y)
        point_u = Point(head.x, head.y - step)  # y grows downwards
        point_d = Point(head.x, head.y + step)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight ahead
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # Danger after a right (clockwise) turn
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # Danger after a left (counter-clockwise) turn
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # Current heading, one-hot
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food position relative to the head
            game.food.x < game.head.x,
            game.food.x > game.head.x,
            game.food.y < game.head.y,
            game.food.y > game.head.y,
        ]

        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        """Train on a batch drawn from the replay memory.

        Uses a random sample of ``BATCH_SIZE`` transitions when more than that
        are stored, otherwise the whole memory. Does nothing when the memory
        is empty, because there is nothing to unpack into a batch.
        """
        if not self.memory:
            return
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)
        else:
            mini_sample = self.memory
        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on the single transition that was just observed."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Choose the next move with an epsilon-greedy policy.

        The exploration budget is ``80 - n_games``; a random move is taken
        when a draw from ``[0, 200]`` falls below it, so exploration stops
        entirely once 80 games have been played.

        Args:
            state: feature vector as produced by ``get_state``.

        Returns:
            One-hot list of length 3 (straight, right turn, left turn).
        """
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            # Pure inference: no autograd graph is needed for picking a move.
            with torch.no_grad():
                prediction = self.model(state0)
            move = torch.argmax(prediction).item()
        final_move[move] = 1
        return final_move
|
||
|
||
def train():
    """Run the interactive training loop; never returns normally.

    Asks whether to start with a fresh agent (0) or load a pickled one (1),
    then plays games forever: every step trains on the latest transition,
    every finished game trains on a replay batch, updates the record and
    refreshes the score plot.
    """
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()

    base_dir = Path(__file__).resolve().parent
    checkpoint_path = os.fspath(base_dir / "model.pkl")

    # Keep asking until the answer is 0 or 1; non-numeric text counts as invalid
    # instead of aborting with ValueError.
    ans = None
    while ans not in (0, 1):
        try:
            ans = int(input("Train or Load Model [0,1]"))
        except ValueError:
            ans = None
        if ans not in (0, 1):
            print("invalid input")

    if ans == 1:
        # NOTE(review): joblib.load unpickles arbitrary objects - only load
        # files you trust. It also reads "best.pkl" while checkpoints are
        # written to "model.pkl"; presumably the file is renamed by hand -
        # confirm.
        agent = joblib.load(os.fspath(base_dir / "best.pkl"))

    game = SnakeGameAI()

    while True:
        if game.save:
            joblib.dump(agent, checkpoint_path)
            # Clear the request so one key press produces one checkpoint
            # rather than a dump on every following frame.
            game.save = False

        state_old = agent.get_state(game)
        final_move = agent.get_action(state_old)
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)
        agent.train_short_memory(state_old, final_move, reward, state_new, done)
        agent.remember(state_old, final_move, reward, state_new, done)

        if not done:
            continue

        game.reset()
        agent.n_games += 1
        agent.train_long_memory()

        if score > record:
            record = score
            joblib.dump(agent, checkpoint_path)
            agent.model.save()

        print('Game', agent.n_games, 'Score', score, 'Record:', record)

        plot_scores.append(score)
        total_score += score
        # Average over the games of this session; a loaded agent already has
        # n_games > 0, which would skew a division by agent.n_games.
        mean_score = total_score / len(plot_scores)
        plot_mean_scores.append(mean_score)
        plot(plot_scores, plot_mean_scores)
|
||
|
||
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    train()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import pygame | ||
import random | ||
from enum import Enum | ||
from collections import namedtuple | ||
import numpy as np | ||
|
||
|
||
# Initialise pygame at import time, before the font is created below.
pygame.init()

# Font used by SnakeGameAI._update_ui to render the score.
font = pygame.font.SysFont('calibri', 20)
|
||
class Direction(Enum):
    """Heading of the snake on the grid."""

    # The code in this file compares members, never their integer values.
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4
|
||
# Pixel coordinate of the top-left corner of a grid cell.
Point = namedtuple('Point', 'x, y')

# RGB colours.
WHITE = (255, 255, 255)  # not referenced in the visible code
RED = (248,143,147)      # food
GREEN = (186,217,181)    # snake body and score text
FONT = (66,12,20)        # not referenced in the visible code
BACK = (239,247,207)     # background fill

# Edge length of one grid cell in pixels; the snake moves one cell per step.
BLOCK_SIZE = 20
# Value passed to Clock.tick() after every rendered step.
SPEED = 20
|
||
class SnakeGameAI:
    """Snake environment that is stepped by an agent's one-hot actions.

    Attributes:
        w, h: window size in pixels.
        save: set to True while the ``U`` key is held during ``play_step``;
            the game never clears it itself.
        direction, head, snake, food, score, loop_count: per-game state,
            re-initialised by ``reset``.
    """

    def __init__(self, w=640, h=480):
        self.w = w
        self.h = h
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake')
        self.clock = pygame.time.Clock()
        self.save = False
        self.reset()

    def reset(self):
        """Start a new game: three-segment snake heading right, score 0."""
        self.direction = Direction.RIGHT

        # Snap the start cell to the grid so the head can coincide with food
        # for any window size (identical to w/2, h/2 for the 640x480 default).
        start_x = (self.w // 2 // BLOCK_SIZE) * BLOCK_SIZE
        start_y = (self.h // 2 // BLOCK_SIZE) * BLOCK_SIZE
        self.head = Point(start_x, start_y)
        self.snake = [
            self.head,
            Point(self.head.x - BLOCK_SIZE, self.head.y),
            Point(self.head.x - (2 * BLOCK_SIZE), self.head.y),
        ]

        self.score = 0
        self.food = None
        self._place_food()
        self.loop_count = 0

    def _place_food(self):
        """Put the food on a random grid cell not occupied by the snake.

        Raises:
            RuntimeError: if the snake already covers every cell.
        """
        max_col = (self.w - BLOCK_SIZE) // BLOCK_SIZE
        max_row = (self.h - BLOCK_SIZE) // BLOCK_SIZE
        if len(set(self.snake)) >= (max_col + 1) * (max_row + 1):
            raise RuntimeError("no free cell left to place food")

        # Rejection sampling in a loop: a recursive retry can exhaust the
        # recursion limit once the board is crowded.
        while True:
            x = random.randint(0, max_col) * BLOCK_SIZE
            y = random.randint(0, max_row) * BLOCK_SIZE
            self.food = Point(x, y)
            if self.food not in self.snake:
                return

    def play_step(self, action):
        """Advance the game by one move.

        Args:
            action: one-hot sequence of length 3 - ``[1,0,0]`` keep heading,
                ``[0,1,0]`` turn clockwise, anything else turn
                counter-clockwise.

        Returns:
            Tuple ``(reward, game_over, score)``. The reward is -10 when the
            game ends, +10 when food is eaten and 0 otherwise.

        Raises:
            SystemExit: when the window is closed.
        """
        self.loop_count += 1
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                # The builtin quit() is a site helper that may be absent;
                # raising SystemExit is what it does anyway.
                raise SystemExit

        keys_pressed = pygame.key.get_pressed()
        if keys_pressed[pygame.K_u]:
            print("saving")
            self.save = True

        self._move(action)
        self.snake.insert(0, self.head)

        reward = 0
        game_over = False
        # End the game on a collision, or when the snake has gone too long
        # without finishing (budget grows with the snake's length).
        if self.is_collision() or self.loop_count > 100 * len(self.snake):
            reward = -10
            game_over = True
            return reward, game_over, self.score

        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            # No food eaten: drop the tail so the length stays the same.
            self.snake.pop()

        self._update_ui()
        self.clock.tick(SPEED)
        return reward, game_over, self.score

    def is_collision(self, pt=None):
        """Return True when ``pt`` is outside the window or on the snake's body.

        Args:
            pt: point to test; defaults to the current head.
        """
        if pt is None:
            pt = self.head

        out_of_bounds = (
            pt.x > self.w - BLOCK_SIZE or pt.x < 0
            or pt.y > self.h - BLOCK_SIZE or pt.y < 0
        )
        if out_of_bounds:
            return True
        # The first segment is the head itself, so it is skipped.
        return pt in self.snake[1:]

    def _update_ui(self):
        """Redraw background, snake, food and score, then flip the display."""
        self.display.fill(BACK)

        for pt in self.snake:
            pygame.draw.rect(self.display, GREEN, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))

        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

        text = font.render(str(self.score), True, GREEN)
        # Horizontally centred start position (320 for the default width).
        self.display.blit(text, [self.w // 2, 5])
        pygame.display.flip()

    def _move(self, action):
        """Update ``direction`` from ``action`` and move ``head`` one cell."""
        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            turn = 0       # keep heading
        elif np.array_equal(action, [0, 1, 0]):
            turn = 1       # clockwise
        else:
            turn = -1      # counter-clockwise
        self.direction = clock_wise[(idx + turn) % 4]

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif self.direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif self.direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif self.direction == Direction.UP:
            y -= BLOCK_SIZE

        self.head = Point(x, y)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import matplotlib.pyplot as plt | ||
from IPython import display | ||
|
||
# Switch matplotlib to interactive mode at import time, before plot() is used.
plt.ion()
|
||
def plot(scores, mean_scores):
    """Redraw the training chart with per-game and mean scores.

    Args:
        scores: score of every finished game, in order; must be non-empty.
        mean_scores: mean score after every game (drawn dotted); must be
            non-empty.
    """
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.style.use('ggplot')
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores, linestyle=":")
    plt.ylim(ymin=0)
    # Label the most recent value of each curve at its last data point.
    for series in (scores, mean_scores):
        latest = series[-1]
        plt.text(len(series) - 1, latest, str(latest))
    plt.show(block=False)
    plt.pause(.1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torch.nn.functional as F | ||
import os | ||
from pathlib import Path | ||
|
||
class Linear_QNet(nn.Module):
    """Feed-forward Q-network with one ReLU hidden layer."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the raw outputs of the second layer for input ``x``."""
        return self.linear2(F.relu(self.linear1(x)))

    @staticmethod
    def _checkpoint_path(file_name):
        """Return the path of ``file_name`` next to this source file."""
        return Path(__file__).resolve().parent / file_name

    def save(self, file_name='model.pth'):
        """Write the network's ``state_dict`` to ``file_name`` beside this file."""
        torch.save(self.state_dict(), os.fspath(self._checkpoint_path(file_name)))

    def load(self, file_name='model.pth'):
        """Restore the weights previously written by ``save``.

        Rebinding ``self`` inside a method has no effect on the caller's
        object, so the stored state dict is copied into this instance
        in place instead.

        Args:
            file_name: checkpoint file beside this source file.

        Raises:
            FileNotFoundError: if the checkpoint does not exist.
        """
        state = torch.load(os.fspath(self._checkpoint_path(file_name)))
        self.load_state_dict(state)
|
||
|
||
class QTrainer:
    """One-step Q-learning trainer (Adam optimiser, MSE loss) for a Q-network."""

    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Run one optimisation step on a single transition or a batch.

        For every sample the target equals the current prediction, except at
        the taken action, where it is ``reward`` if ``done`` else
        ``reward + gamma * max(Q(next_state))``.

        Args:
            state: one state vector, or a sequence of them.
            action: one-hot action, or a sequence of them.
            reward: scalar reward, or a sequence of them.
            next_state: successor state(s), same layout as ``state``.
            done: bool, or a sequence of bools, marking terminal transitions.
        """
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)

        if state.dim() == 1:
            # Single transition: add a batch dimension of size 1.
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = (done, )

        terminal = torch.tensor(done, dtype=torch.bool)
        pred = self.model(state)

        # The Bellman target is a constant for this step: build it without
        # autograd so gradients flow only through `pred`, and evaluate all
        # successor states in one batched forward pass.
        with torch.no_grad():
            next_best = self.model(next_state).max(dim=1).values
            q_new = torch.where(terminal, reward, reward + self.gamma * next_best)
            target = pred.detach().clone()
            rows = torch.arange(target.shape[0])
            target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()
|
||
|
||
|
||
|
||
|