diff --git a/Snake_AI.py b/Snake_AI.py
new file mode 100644
index 0000000..f759ac9
--- /dev/null
+++ b/Snake_AI.py
@@ -0,0 +1,147 @@
+import torch
+import random
+import numpy as np
+from collections import deque
+from game import SnakeGameAI, Direction, Point
+from model import Linear_QNet, QTrainer
+from helper import plot
+import joblib
+import os
+from pathlib import Path
+
+MAX_MEMORY = 100_000
+BATCH_SIZE = 1000
+LR = 0.001
+
+class Agent:
+
+    def __init__(self):
+        self.n_games = 0
+        self.epsilon = 0   # randomness for exploration
+        self.gamma = 0.9   # discount rate for future rewards
+        self.memory = deque(maxlen=MAX_MEMORY)  # oldest transitions are dropped once full
+        self.model = Linear_QNet(11, 512, 3)
+        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
+
+    def get_state(self, game):
+        # points one block (20 px == BLOCK_SIZE in game.py) around the head
+        head = game.snake[0]
+        point_l = Point(head.x - 20, head.y)
+        point_r = Point(head.x + 20, head.y)
+        point_u = Point(head.x, head.y - 20)
+        point_d = Point(head.x, head.y + 20)
+
+        dir_l = game.direction == Direction.LEFT
+        dir_r = game.direction == Direction.RIGHT
+        dir_u = game.direction == Direction.UP
+        dir_d = game.direction == Direction.DOWN
+
+        state = [
+            # danger straight ahead
+            (dir_r and game.is_collision(point_r)) or
+            (dir_l and game.is_collision(point_l)) or
+            (dir_u and game.is_collision(point_u)) or
+            (dir_d and game.is_collision(point_d)),
+
+            # danger to the right of the current heading
+            (dir_u and game.is_collision(point_r)) or
+            (dir_d and game.is_collision(point_l)) or
+            (dir_l and game.is_collision(point_u)) or
+            (dir_r and game.is_collision(point_d)),
+
+            # danger to the left of the current heading
+            (dir_d and game.is_collision(point_r)) or
+            (dir_u and game.is_collision(point_l)) or
+            (dir_r and game.is_collision(point_u)) or
+            (dir_l and game.is_collision(point_d)),
+
+            # current move direction (one-hot)
+            dir_l,
+            dir_r,
+            dir_u,
+            dir_d,
+
+            # food location relative to the head
+            game.food.x < game.head.x,
+            game.food.x > game.head.x,
+            game.food.y < game.head.y,
+            game.food.y > game.head.y
+        ]
+
+        return np.array(state, dtype=int)
+
+    def remember(self, state, action, reward, next_state, done):
+        self.memory.append((state, action, reward, next_state, done))
+
+    def train_long_memory(self):
+        if len(self.memory) > BATCH_SIZE:
+            mini_sample = random.sample(self.memory, BATCH_SIZE)
+        else:
+            mini_sample = self.memory
+        states, actions, rewards, next_states, dones = zip(*mini_sample)
+        self.trainer.train_step(states, actions, rewards, next_states, dones)
+
+    def train_short_memory(self, state, action, reward, next_state, done):
+        self.trainer.train_step(state, action, reward, next_state, done)
+
+    def get_action(self, state):
+        # epsilon-greedy: random moves become rarer as more games are played
+        # and stop entirely after about 80 games
+        self.epsilon = 80 - self.n_games
+        final_move = [0, 0, 0]
+        if random.randint(0, 200) < self.epsilon:
+            move = random.randint(0, 2)
+            final_move[move] = 1
+        else:
+            # exploit: pick the move with the highest predicted Q value
+            state0 = torch.tensor(state, dtype=torch.float)
+            prediction = self.model(state0)
+            move = torch.argmax(prediction).item()
+            final_move[move] = 1
+
+        return final_move
+
+
+def train():
+    plot_scores = []
+    plot_mean_scores = []
+    total_score = 0
+    record = 0
+    agent = Agent()
+
+    ans = input("Train a new model (0) or load the saved one (1): ")
+    while ans not in ("0", "1"):
+        print("Invalid input, please enter 0 or 1.")
+        ans = input("Train a new model (0) or load the saved one (1): ")
+
+    if ans == "1":
+        agent = joblib.load(os.fspath(Path(__file__).resolve().parent / "best.pkl"))
+
+    game = SnakeGameAI()
+
+    while True:
+        if game.save:
+            # the game sets save when the user presses 'u'; persist the agent once
+            joblib.dump(agent, os.fspath(Path(__file__).resolve().parent / "model.pkl"))
+            game.save = False
+        state_old = agent.get_state(game)
+        final_move = agent.get_action(state_old)
+        reward, done, score = game.play_step(final_move)
+        state_new = agent.get_state(game)
+        agent.train_short_memory(state_old, final_move, reward, state_new, done)
+        agent.remember(state_old, final_move, reward, state_new, done)
+
+        if done:
+            game.reset()
+            agent.n_games += 1
+            agent.train_long_memory()
+
+            if score > record:
+                record = score
+                # keep both the pickled agent and the raw network weights for the new record
+                joblib.dump(agent, os.fspath(Path(__file__).resolve().parent / "model.pkl"))
+                agent.model.save()
+
+            print('Game', agent.n_games, 'Score', score, 'Record:', record)
+
+            plot_scores.append(score)
+            total_score += score
+            mean_score = total_score / agent.n_games
+            plot_mean_scores.append(mean_score)
+            plot(plot_scores, plot_mean_scores)
+
+
+if __name__ == '__main__':
+    train()
\ No newline at end of file
diff --git a/best.pkl b/best.pkl
new file mode 100644
index 0000000..6ecb941
Binary files /dev/null and b/best.pkl differ
diff --git a/game.py b/game.py
new file mode 100644
index 0000000..30ba5c3
--- /dev/null
+++ b/game.py
@@ -0,0 +1,141 @@
+import pygame
+import random
+from enum import Enum
+from collections import namedtuple
+import numpy as np
+
+
+pygame.init()
+
+font = pygame.font.SysFont('calibri', 20)
+
+class Direction(Enum):
+    RIGHT = 1
+    LEFT = 2
+    UP = 3
+    DOWN = 4
+
+Point = namedtuple('Point', 'x, y')
+
+WHITE = (255, 255, 255)
+RED = (248, 143, 147)
+GREEN = (186, 217, 181)
+FONT = (66, 12, 20)
+BACK = (239, 247, 207)
+
+BLOCK_SIZE = 20
+SPEED = 20
+
+class SnakeGameAI:
+
+    def __init__(self, w=640, h=480):
+        self.w = w
+        self.h = h
+        self.display = pygame.display.set_mode((self.w, self.h))
+        pygame.display.set_caption('Snake')
+        self.clock = pygame.time.Clock()
+        self.reset()
+        self.save = False  # set to True in play_step when the user presses 'u'
+
+    def reset(self):
+        self.direction = Direction.RIGHT
+
+        self.head = Point(self.w/2, self.h/2)
+        self.snake = [self.head,
+                      Point(self.head.x - BLOCK_SIZE, self.head.y),
+                      Point(self.head.x - (2 * BLOCK_SIZE), self.head.y)]
+
+        self.score = 0
+        self.food = None
+        self._place_food()
+        self.loop_count = 0
+
+    def _place_food(self):
+        x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
+        y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
+        self.food = Point(x, y)
+        if self.food in self.snake:
+            # re-roll if the food landed on the snake
+            self._place_food()
+
+    def play_step(self, action):
+        self.loop_count += 1  # frames since the last reset, used to cut off endless games
+        for event in pygame.event.get():
+            if event.type == pygame.QUIT:
+                pygame.quit()
+                quit()
+        keys_pressed = pygame.key.get_pressed()
+
+        if keys_pressed[pygame.K_u]:
+            print("saving")
+            self.save = True
+        self._move(action)
+        self.snake.insert(0, self.head)
+        reward = 0
+        game_over = False
+        # end the episode on a collision or if the snake wanders too long without eating
+        if self.is_collision() or self.loop_count > 100 * len(self.snake):
+            reward = -10
+            game_over = True
+            return reward, game_over, self.score
+
+        if self.head == self.food:
+            self.score += 1
+            reward = 10
+            self._place_food()
+        else:
+            self.snake.pop()
+
+        self._update_ui()
+        self.clock.tick(SPEED)
+        return reward, game_over, self.score
+
+    def is_collision(self, pt=None):
+        if pt is None:
+            pt = self.head
+        # hits the boundary
+        if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
+            return True
+        # hits its own body
+        if pt in self.snake[1:]:
+            return True
+
+        return False
+
+    def _update_ui(self):
+        self.display.fill(BACK)
+
+        for pt in self.snake:
+            pygame.draw.rect(self.display, GREEN, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
+
+        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))
+
+        text = font.render(str(self.score), True, GREEN)
+        self.display.blit(text, [320, 5])
+        pygame.display.flip()
+
+    def _move(self, action):
+        # action is one-hot: [straight, right turn, left turn]
+        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
+        idx = clock_wise.index(self.direction)
+
+        if np.array_equal(action, [1, 0, 0]):
+            new_dir = clock_wise[idx]  # no change
+        elif np.array_equal(action, [0, 1, 0]):
+            next_idx = (idx + 1) % 4
+            new_dir = clock_wise[next_idx]  # right turn: r -> d -> l -> u
+        else:
+            next_idx = (idx - 1) % 4
+            new_dir = clock_wise[next_idx]  # left turn: r -> u -> l -> d
+        self.direction = new_dir
+
+        x = self.head.x
+        y = self.head.y
+        if self.direction == Direction.RIGHT:
+            x += BLOCK_SIZE
+        elif self.direction == Direction.LEFT:
+            x -= BLOCK_SIZE
+        elif self.direction == Direction.DOWN:
+            y += BLOCK_SIZE
+        elif self.direction == Direction.UP:
+            y -= BLOCK_SIZE
+
+        self.head = Point(x, y)
\ No newline at end of file
diff --git a/helper.py b/helper.py
new file mode 100644
index 0000000..c323ff3
--- /dev/null
+++ b/helper.py
@@ -0,0 +1,20 @@
+import matplotlib.pyplot as plt
+from IPython import display
+
+plt.ion()
+
+def plot(scores, mean_scores):
+    display.clear_output(wait=True)
+    display.display(plt.gcf())
+    plt.clf()
+    plt.style.use('ggplot')
+    plt.title('Training...')
+    plt.xlabel('Number of Games')
+    plt.ylabel('Score')
+    plt.plot(scores)
+    plt.plot(mean_scores, linestyle=":")
+    plt.ylim(ymin=0)
+    plt.text(len(scores) - 1, scores[-1], str(scores[-1]))
+    plt.text(len(mean_scores) - 1, mean_scores[-1], str(mean_scores[-1]))
+    plt.show(block=False)
+    plt.pause(.1)
\ No newline at end of file
diff --git a/model.pkl b/model.pkl
new file mode 100644
index 0000000..6ecb941
Binary files /dev/null and b/model.pkl differ
diff --git a/model.pth b/model.pth
new file mode 100644
index 0000000..c4689f7
Binary files /dev/null and b/model.pth differ
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..0f7d227
--- /dev/null
+++ b/model.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import os
+from pathlib import Path
+
+class Linear_QNet(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.linear1 = nn.Linear(input_size, hidden_size)
+        self.linear2 = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = F.relu(self.linear1(x))
+        x = self.linear2(x)
+        return x
+
+    def save(self, file_name='model.pth'):
+        torch.save(self.state_dict(), os.fspath(Path(__file__).resolve().parent / file_name))
+
+    def load(self, file_name='model.pth'):
+        # save() stores a state_dict, so load it back into this module;
+        # rebinding the local name `self` would have no effect
+        self.load_state_dict(torch.load(os.fspath(Path(__file__).resolve().parent / file_name)))
+
+
+class QTrainer:
+    def __init__(self, model, lr, gamma):
+        self.lr = lr
+        self.gamma = gamma
+        self.model = model
+        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
+        self.criterion = nn.MSELoss()
+
+    def train_step(self, state, action, reward, next_state, done):
+        state = torch.tensor(state, dtype=torch.float)
+        next_state = torch.tensor(next_state, dtype=torch.float)
+        action = torch.tensor(action, dtype=torch.long)
+        reward = torch.tensor(reward, dtype=torch.float)
+
+        if len(state.shape) == 1:
+            # single sample: add a batch dimension of size 1
+            state = torch.unsqueeze(state, 0)
+            next_state = torch.unsqueeze(next_state, 0)
+            action = torch.unsqueeze(action, 0)
+            reward = torch.unsqueeze(reward, 0)
+            done = (done, )
+
+        # predicted Q values for the current states
+        pred = self.model(state)
+
+        # Q-learning target: r, or r + gamma * max Q(next_state) when the episode continues
+        target = pred.clone()
+        for idx in range(len(done)):
+            Q_new = reward[idx]
+            if not done[idx]:
+                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
+
+            target[idx][torch.argmax(action[idx]).item()] = Q_new
+
+        self.optimizer.zero_grad()
+        loss = self.criterion(target, pred)
+        loss.backward()
+        self.optimizer.step()
\ No newline at end of file
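
Usage note: a minimal sketch, not part of the diff above, of how the network weights written by agent.model.save() might be restored for greedy play. It assumes model.pth sits next to model.py, that the hidden size matches the 512 used in Agent.__init__, and that the helper name choose_move is hypothetical.

    import torch
    from model import Linear_QNet

    # rebuild the architecture with the sizes used during training, then pull the
    # saved weights back in (load() reads model.pth via load_state_dict)
    net = Linear_QNet(11, 512, 3)
    net.load()
    net.eval()

    def choose_move(state_vector):
        # state_vector is the 11-element array produced by Agent.get_state()
        state0 = torch.tensor(state_vector, dtype=torch.float)
        with torch.no_grad():
            prediction = net(state0)
        final_move = [0, 0, 0]  # one-hot: [straight, right turn, left turn]
        final_move[torch.argmax(prediction).item()] = 1
        return final_move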