Commit 9add70f
Add files via upload
MartinGurasvili authored Nov 10, 2022
1 parent 7cdfe5d
Showing 7 changed files with 373 additions and 0 deletions.
147 changes: 147 additions & 0 deletions Snake_AI.py
@@ -0,0 +1,147 @@
import torch
import random
import numpy as np
from collections import deque
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot
import joblib
import os
from pathlib import Path

MAX_MEMORY = 100_000  # replay buffer capacity (number of stored transitions)
BATCH_SIZE = 1000     # transitions sampled per long-memory training step
LR = 0.001            # learning rate for the Adam optimizer

class Agent:

def __init__(self):
self.n_games = 0
        self.epsilon = 0  # exploration rate, recalculated on every move in get_action()
        self.gamma = 0.9  # discount factor for future rewards
        self.memory = deque(maxlen=MAX_MEMORY)  # replay buffer; oldest transitions drop off automatically
        self.model = Linear_QNet(11, 512, 3)  # 11 state features -> 512 hidden units -> 3 actions
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)


def get_state(self, game):
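        # Build the 11-value binary state vector:
        #   [danger straight, danger right, danger left,
        #    moving left, moving right, moving up, moving down,
        #    food left, food right, food up, food down]
        # "Danger" means the cell one BLOCK_SIZE (20 px) away in that relative
        # direction would be a collision.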
head = game.snake[0]
point_l = Point(head.x - 20, head.y)
point_r = Point(head.x + 20, head.y)
point_u = Point(head.x, head.y - 20)
point_d = Point(head.x, head.y + 20)

dir_l = game.direction == Direction.LEFT
dir_r = game.direction == Direction.RIGHT
dir_u = game.direction == Direction.UP
dir_d = game.direction == Direction.DOWN

state = [

(dir_r and game.is_collision(point_r)) or
(dir_l and game.is_collision(point_l)) or
(dir_u and game.is_collision(point_u)) or
(dir_d and game.is_collision(point_d)),

(dir_u and game.is_collision(point_r)) or
(dir_d and game.is_collision(point_l)) or
(dir_l and game.is_collision(point_u)) or
(dir_r and game.is_collision(point_d)),

(dir_d and game.is_collision(point_r)) or
(dir_u and game.is_collision(point_l)) or
(dir_r and game.is_collision(point_u)) or
(dir_l and game.is_collision(point_d)),

dir_l,
dir_r,
dir_u,
dir_d,

game.food.x < game.head.x,
game.food.x > game.head.x,
game.food.y < game.head.y,
game.food.y > game.head.y
]

return np.array(state, dtype=int)

def remember(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))

def train_long_memory(self):
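        # Experience replay: train on a random batch of up to BATCH_SIZE stored
        # transitions so updates are not dominated by the most recent moves.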
if len(self.memory) > BATCH_SIZE:
mini_sample = random.sample(self.memory, BATCH_SIZE)
else:
mini_sample = self.memory
states, actions, rewards, next_states, dones = zip(*mini_sample)
self.trainer.train_step(states, actions, rewards, next_states, dones)

def train_short_memory(self, state, action, reward, next_state, done):
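        # Immediate single-step update on the transition that just happened.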
self.trainer.train_step(state, action, reward, next_state, done)

def get_action(self, state):
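        # Epsilon-greedy exploration: for roughly the first 80 games the agent
        # sometimes picks a random move (the probability shrinks as n_games grows);
        # after that it always takes the argmax of the predicted Q-values.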
self.epsilon = 80 - self.n_games
final_move = [0,0,0]
if random.randint(0, 200) < self.epsilon:
move = random.randint(0, 2)
final_move[move] = 1
else:
state0 = torch.tensor(state, dtype=torch.float)
prediction = self.model(state0)
move = torch.argmax(prediction).item()
final_move[move] = 1

return final_move


def train():
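    # Main training loop: read the state, pick a move, play one step, train on
    # that single transition, store it for replay, and on game over reset the
    # game, run a replay-batch update, and refresh the score plot.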

plot_scores = []
plot_mean_scores = []
total_score = 0
record = 0
agent = Agent()
    ans = int(input("Train a new model (0) or load the saved one (1): "))

    while ans not in [0, 1]:
        print("Invalid input")
        ans = int(input("Train a new model (0) or load the saved one (1): "))

if ans == 1:
agent = joblib.load((os.fspath(Path(__file__).resolve().parent / "best.pkl")))

game = SnakeGameAI()

while True:
        if game.save:
joblib.dump(agent, (os.fspath(Path(__file__).resolve().parent / "model.pkl")))
state_old = agent.get_state(game)
final_move = agent.get_action(state_old)
reward, done, score = game.play_step(final_move)
state_new = agent.get_state(game)
agent.train_short_memory(state_old, final_move, reward, state_new, done)
agent.remember(state_old, final_move, reward, state_new, done)

if done:
game.reset()
agent.n_games += 1
agent.train_long_memory()

if score > record:
record = score
joblib.dump(agent,(os.fspath(Path(__file__).resolve().parent / "model.pkl")))
agent.model.save()

print('Game', agent.n_games, 'Score', score, 'Record:', record)

plot_scores.append(score)
total_score += score
mean_score = total_score / agent.n_games
plot_mean_scores.append(mean_score)
plot(plot_scores, plot_mean_scores)


if __name__ == '__main__':
train()

Binary file added best.pkl
141 changes: 141 additions & 0 deletions game.py
@@ -0,0 +1,141 @@
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np


pygame.init()

font = pygame.font.SysFont('calibri', 20)

class Direction(Enum):
RIGHT = 1
LEFT = 2
UP = 3
DOWN = 4

Point = namedtuple('Point', 'x, y')

WHITE = (255, 255, 255)
RED = (248,143,147)
GREEN = (186,217,181)
FONT = (66,12,20)
BACK = (239,247,207)

BLOCK_SIZE = 20
SPEED = 20

class SnakeGameAI:

def __init__(self, w=640, h=480):
self.w = w
self.h = h
self.display = pygame.display.set_mode((self.w, self.h))
pygame.display.set_caption('Snake')
self.clock = pygame.time.Clock()
self.reset()
self.save = False

def reset(self):
self.direction = Direction.RIGHT

self.head = Point(self.w/2, self.h/2)
self.snake = [self.head,
Point(self.head.x-BLOCK_SIZE, self.head.y),
Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]

self.score = 0
self.food = None
self._place_food()
self.loop_count = 0

def _place_food(self):
x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
self.food = Point(x, y)
if self.food in self.snake:
self._place_food()

def play_step(self, action):
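        # Advance the game by one frame for the given action.
        # Returns (reward, game_over, score): +10 for eating food,
        # -10 for a collision or for looping too long without eating.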
self.loop_count += 1
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
quit()
keys_pressed = pygame.key.get_pressed()
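        # Pressing 'u' sets self.save, which the training loop in Snake_AI.py
        # uses as a signal to pickle the current agent.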

if keys_pressed[pygame.K_u]:
print("saving")
self.save = True
self._move(action)
self.snake.insert(0, self.head)
reward = 0
game_over = False
if self.is_collision() or self.loop_count > 100*len(self.snake):
reward = -10
game_over = True
return reward, game_over, self.score

if self.head == self.food:
self.score += 1
reward = 10
self._place_food()
else:
self.snake.pop()

self._update_ui()
self.clock.tick(SPEED)
        return reward, game_over, self.score

def is_collision(self, pt=None):
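        # A point collides if it is outside the board or on the snake's body
        # (the head itself, snake[0], is excluded).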
if pt is None:
pt = self.head

if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
return True
if pt in self.snake[1:]:
return True

return False

def _update_ui(self):
self.display.fill(BACK)

for pt in self.snake:
pygame.draw.rect(self.display, GREEN, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))

pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

text = font.render( str(self.score), True, GREEN)
self.display.blit(text, [320, 5])
pygame.display.flip()

def _move(self, action):
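        # Actions are relative to the current heading:
        #   [1, 0, 0] keep going straight
        #   [0, 1, 0] turn right (next direction clockwise)
        #   [0, 0, 1] turn left  (previous direction, counter-clockwise)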
        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
idx = clock_wise.index(self.direction)

if np.array_equal(action, [1,0,0]):
new_dir = clock_wise[idx]

elif np.array_equal(action, [0,1,0]):
next_idx = (idx+1) % 4
new_dir = clock_wise[next_idx]
else:
next_idx = (idx-1) % 4
new_dir = clock_wise[next_idx]
self.direction = new_dir

x = self.head.x
y = self.head.y
if self.direction == Direction.RIGHT:
x += BLOCK_SIZE
elif self.direction == Direction.LEFT:
x -= BLOCK_SIZE
elif self.direction == Direction.DOWN:
y += BLOCK_SIZE
elif self.direction == Direction.UP:
y -= BLOCK_SIZE

self.head = Point(x, y)

20 changes: 20 additions & 0 deletions helper.py
@@ -0,0 +1,20 @@
import matplotlib.pyplot as plt
from IPython import display

plt.ion()

def plot(scores, mean_scores):
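    # Live-updating training plot: per-game score plus the running mean score.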
display.clear_output(wait=True)
display.display(plt.gcf())
plt.clf()
plt.style.use('ggplot')
plt.title('Training...')
plt.xlabel('Number of Games')
plt.ylabel('Score')
plt.plot(scores)
    plt.plot(mean_scores, linestyle=":")
    plt.ylim(bottom=0)
plt.text(len(scores)-1, scores[-1], str(scores[-1]))
plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))
plt.show(block=False)
plt.pause(.1)
Binary file added model.pkl
Binary file added model.pth
65 changes: 65 additions & 0 deletions model.py
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from pathlib import Path

class Linear_QNet(nn.Module):
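    # Minimal feed-forward Q-network: a single hidden ReLU layer mapping the
    # state vector to one Q-value per action (instantiated as 11 -> 512 -> 3
    # in Snake_AI.py).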
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.linear1 = nn.Linear(input_size, hidden_size)
self.linear2 = nn.Linear(hidden_size, output_size)

def forward(self, x):
x = F.relu(self.linear1(x))
x = self.linear2(x)
return x

def save(self, file_name='model.pth'):
torch.save(self.state_dict(), (os.fspath(Path(__file__).resolve().parent / file_name)))

    def load(self):
        # Restore previously saved weights into this network (see save() above).
        self.load_state_dict(torch.load(os.fspath(Path(__file__).resolve().parent / 'model.pth')))


class QTrainer:
def __init__(self, model, lr, gamma):
self.lr = lr
self.gamma = gamma
self.model = model
self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
self.criterion = nn.MSELoss()

def train_step(self, state, action, reward, next_state, done):
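        # Q-learning update: for each transition the target for the action taken
        # is r + gamma * max_a' Q(next_state, a'), or just r on terminal steps;
        # the loss is the MSE between predicted and target Q-values.
        # Accepts a single transition or a batch (1-D inputs are unsqueezed below).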
state = torch.tensor(state, dtype=torch.float)
next_state = torch.tensor(next_state, dtype=torch.float)
action = torch.tensor(action, dtype=torch.long)
reward = torch.tensor(reward, dtype=torch.float)

if len(state.shape) == 1:
state = torch.unsqueeze(state, 0)
next_state = torch.unsqueeze(next_state, 0)
action = torch.unsqueeze(action, 0)
reward = torch.unsqueeze(reward, 0)
done = (done, )
pred = self.model(state)

target = pred.clone()
for idx in range(len(done)):
Q_new = reward[idx]
if not done[idx]:
Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

target[idx][torch.argmax(action[idx]).item()] = Q_new

self.optimizer.zero_grad()
loss = self.criterion(target, pred)
loss.backward()

self.optimizer.step()




