diff --git a/trained-agents/a2c/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip b/ready-models/a2c/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
old mode 100755
new mode 100644
similarity index 100%
rename from trained-agents/a2c/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
rename to ready-models/a2c/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
diff --git a/trained-agents/dqn/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip b/ready-models/dqn/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
old mode 100755
new mode 100644
similarity index 100%
rename from trained-agents/dqn/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
rename to ready-models/dqn/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
diff --git a/ready-models/ppo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip b/ready-models/ppo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip
new file mode 100644
index 0000000..f2890d7
Binary files /dev/null and b/ready-models/ppo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip differ
diff --git a/trained-agents/ppo/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip b/ready-models/ppo/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
old mode 100755
new mode 100644
similarity index 100%
rename from trained-agents/ppo/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
rename to ready-models/ppo/WebsocketPong-v0/WebsocketPong-v0_200000_steps.zip
diff --git a/ready-models/trpo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip b/ready-models/trpo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip
new file mode 100644
index 0000000..f2890d7
Binary files /dev/null and b/ready-models/trpo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip differ
diff --git a/requirements.txt b/requirements.txt
index 7cf17d2..04c8f37 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,8 @@ setuptools==66
 wheel==0.38.0
 opencv-python==4.10.0.84
 gym==0.21.0
-stable-baselines3[extra]==1.5.0
+stable-baselines3[extra]==1.5.0
+sb3-contrib==1.5.0
 protobuf==3.20.*
 numpy==1.26.4
 tornado~=6.4.1
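Note on the dependency addition: `sb3-contrib` supplies the TRPO implementation that the new Flappy Bird routes load, and pinning it to the same version as `stable-baselines3` keeps the two packages API-compatible. A quick sanity check (illustrative only, not part of this diff):

```python
# Illustrative check that the pinned pair resolves together; the printed
# versions are simply what the pins above request.
import stable_baselines3
import sb3_contrib
from sb3_contrib import TRPO  # the algorithm the new routes depend on

print(stable_baselines3.__version__, sb3_contrib.__version__)  # 1.5.0 1.5.0
```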
diff --git a/scripts/define_routes.py b/scripts/define_routes.py
index 9e21b21..fd1ddae 100755
--- a/scripts/define_routes.py
+++ b/scripts/define_routes.py
@@ -1,68 +1,43 @@
-import os
 from typing import List, Tuple, Type
-
-from stable_baselines3 import DQN, PPO, A2C
-
 from src.api import RoutesHandler
 from src.bots import PongBot, FlappybirdBot, SkijumpBot
-from src.env_sim.web_pong import prepare_pong_obs
 from src.handlers import AiHandler
+from src.agents.web_pong import PongAgent
+from src.agents.web_flappy_bird import FlappyBirdAgent
+from scripts.load_models import (
+    dqn_pong,
+    ppo_pong,
+    a2c_pong,
+    ppo_fb,
+    trpo_fb
+)
 
 
 def define_routes() -> List[Tuple[str, Type, dict]]:
-    dqn_path = os.path.join('trained-agents', 'dqn')
-    ppo_path = os.path.join('trained-agents', 'ppo')
-    a2c_path = os.path.join('trained-agents', 'a2c')
-
-    dqn_pong_path = os.path.join(
-        dqn_path,
-        'WebsocketPong-v0',
-        'WebsocketPong-v0_200000_steps.zip'
-    )
-    dqn_pong = DQN.load(path=dqn_pong_path)
-
-    ppo_pong_path = os.path.join(
-        ppo_path,
-        'WebsocketPong-v0',
-        'WebsocketPong-v0_200000_steps.zip'
-    )
-    ppo_pong = PPO.load(path=ppo_pong_path)
-
-    a2c_pong_path = os.path.join(
-        a2c_path,
-        'WebsocketPong-v0',
-        'WebsocketPong-v0_200000_steps.zip'
-    )
-    a2c_pong = A2C.load(path=a2c_pong_path)
-
     routes = []
 
     pong_routes = [
         (r"/ws/pong/pong-dqn/", AiHandler, dict(
-            model=dqn_pong,
-            obs_funct=prepare_pong_obs,
-            move_first=-1,
-            move_last=1,
-            history_length=3
+            agent=PongAgent(dqn_pong, 3)
        )),
         (r"/ws/pong/pong-ppo/", AiHandler, dict(
-            model=ppo_pong,
-            obs_funct=prepare_pong_obs,
-            move_first=-1,
-            move_last=1,
-            history_length=3
+            agent=PongAgent(ppo_pong, 3)
        )),
         (r"/ws/pong/pong-a2c/", AiHandler, dict(
-            model=a2c_pong,
-            obs_funct=prepare_pong_obs,
-            move_first=-1,
-            move_last=1,
-            history_length=3
+            agent=PongAgent(a2c_pong, 3)
        )),
         (r"/ws/pong/pong-bot/", PongBot),
     ]
 
+    flappybird_routes = [
+        (r"/ws/flappybird/flappybird-ppo/", AiHandler, dict(
+            agent=FlappyBirdAgent(ppo_fb, 3)
+        )),
+        (r"/ws/flappybird/flappybird-trpo/", AiHandler, dict(
+            agent=FlappyBirdAgent(trpo_fb, 3)
+        )),
         (r"/ws/flappybird/flappybird-bot/", FlappybirdBot),
     ]
 
+    skijump_routes = [
         (r"/ws/skijump/skijump-bot/", SkijumpBot)
     ]
diff --git a/scripts/load_models.py b/scripts/load_models.py
new file mode 100644
index 0000000..7eb1be6
--- /dev/null
+++ b/scripts/load_models.py
@@ -0,0 +1,13 @@
+from stable_baselines3 import DQN, PPO, A2C
+from sb3_contrib import TRPO
+from scripts.paths.pong_paths import dqn_pong_path, ppo_pong_path, a2c_pong_path
+from scripts.paths.flappy_bird_paths import ppo_fb_path, trpo_fb_path
+
+# Pong
+dqn_pong = DQN.load(path=dqn_pong_path)
+ppo_pong = PPO.load(path=ppo_pong_path)
+a2c_pong = A2C.load(path=a2c_pong_path)
+
+# Flappy Bird
+ppo_fb = PPO.load(path=ppo_fb_path)
+trpo_fb = TRPO.load(path=trpo_fb_path)
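For context, `define_routes()` returns `(pattern, handler, init_kwargs)` tuples, which is exactly the shape `tornado.web.Application` accepts, and `scripts/load_models.py` now loads every model once at import time so all routes share the same instances. A minimal sketch of the server entrypoint, which is assumed here and not shown in this diff (module name and port are hypothetical):

```python
# Hypothetical entrypoint sketch; not part of this change.
import tornado.ioloop
import tornado.web

from scripts.define_routes import define_routes

if __name__ == "__main__":
    # Each (pattern, handler, kwargs) tuple becomes a URLSpec; the kwargs
    # dict is passed to the handler's initialize(), i.e. agent=... for AiHandler.
    app = tornado.web.Application(define_routes())
    app.listen(8888)
    tornado.ioloop.IOLoop.current().start()
```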
diff --git a/scripts/paths/algos_paths.py b/scripts/paths/algos_paths.py
new file mode 100644
index 0000000..37edff0
--- /dev/null
+++ b/scripts/paths/algos_paths.py
@@ -0,0 +1,6 @@
+import os
+
+dqn_path = os.path.join('ready-models', 'dqn')
+ppo_path = os.path.join('ready-models', 'ppo')
+a2c_path = os.path.join('ready-models', 'a2c')
+trpo_path = os.path.join('ready-models', 'trpo')
diff --git a/scripts/paths/flappy_bird_paths.py b/scripts/paths/flappy_bird_paths.py
new file mode 100644
index 0000000..1422c81
--- /dev/null
+++ b/scripts/paths/flappy_bird_paths.py
@@ -0,0 +1,12 @@
+from scripts.paths.algos_paths import ppo_path, trpo_path
+
+import os
+
+
+env_path = os.path.join(
+    'WebsocketFlappyBird-v0',
+    'WebsocketFlappyBird-v0_200000_steps.zip'
+)
+
+ppo_fb_path = os.path.join(ppo_path, env_path)
+trpo_fb_path = os.path.join(trpo_path, env_path)
diff --git a/scripts/paths/pong_paths.py b/scripts/paths/pong_paths.py
new file mode 100644
index 0000000..be1c43b
--- /dev/null
+++ b/scripts/paths/pong_paths.py
@@ -0,0 +1,12 @@
+from scripts.paths.algos_paths import a2c_path, dqn_path, ppo_path, trpo_path
+
+import os
+
+env_path = os.path.join(
+    'WebsocketPong-v0',
+    'WebsocketPong-v0_200000_steps.zip'
+)
+
+dqn_pong_path = os.path.join(dqn_path, env_path)
+ppo_pong_path = os.path.join(ppo_path, env_path)
+a2c_pong_path = os.path.join(a2c_path, env_path)
diff --git a/src/__init__.py b/src/__init__.py
deleted file mode 100755
index e69de29..0000000
diff --git a/src/agents/web_env.py b/src/agents/web_env.py
new file mode 100644
index 0000000..d727259
--- /dev/null
+++ b/src/agents/web_env.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from collections import deque
+from typing import final
+from stable_baselines3.common.base_class import BaseAlgorithm
+
+import numpy as np
+
+
+class WebsocketAgent(ABC):
+    def __init__(self, model: BaseAlgorithm, history_length: int = 1):
+        if history_length < 1:
+            raise ValueError("history_length must be an integer greater than or equal to 1")
+        self.model = model
+        self.states = deque(maxlen=history_length)
+
+    def start(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare_observation(self, data: dict) -> np.array:
+        raise NotImplementedError
+
+    @final
+    def state_stack(self, observation: np.array) -> np.array:
+        if len(self.states) == 0:
+            for _ in range(self.states.maxlen):
+                self.states.append(observation)
+        else:
+            self.states.append(observation)
+        return np.array(self.states).flatten()
+
+    @abstractmethod
+    def return_prediction(self, data: dict) -> dict:
+        raise NotImplementedError
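`state_stack` implements frame stacking over the websocket stream: the first observation after a reset fills the whole window, so the model never sees zero padding, and later observations slide through the deque. A standalone illustration of the behaviour (not part of the diff):

```python
# Standalone sketch of WebsocketAgent.state_stack with history_length=3.
from collections import deque

import numpy as np

states = deque(maxlen=3)

def state_stack(observation: np.ndarray) -> np.ndarray:
    if len(states) == 0:
        # First frame after a reset: replicate it across the whole window.
        for _ in range(states.maxlen):
            states.append(observation)
    else:
        states.append(observation)
    return np.array(states).flatten()

print(state_stack(np.array([1.0, 2.0])))  # [1. 2. 1. 2. 1. 2.]
print(state_stack(np.array([3.0, 4.0])))  # [1. 2. 1. 2. 3. 4.]
```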
diff --git a/src/agents/web_flappy_bird.py b/src/agents/web_flappy_bird.py
new file mode 100644
index 0000000..f21012e
--- /dev/null
+++ b/src/agents/web_flappy_bird.py
@@ -0,0 +1,55 @@
+from src.agents.web_env import WebsocketAgent
+from stable_baselines3.common.base_class import BaseAlgorithm
+
+import numpy as np
+
+
+class FlappyBirdAgent(WebsocketAgent):
+    def __init__(self, model: BaseAlgorithm, history_length: int):
+        super().__init__(model, history_length)
+        self.min_values = np.array([0, -20, 0.5, 5, -50, 100], dtype=np.float32)
+        self.max_values = np.array([600, 90, 1, 15, 1900, 500], dtype=np.float32)
+        self.should_start = False
+
+    def prepare_observation(self, data: dict) -> np.array:
+        state = data['state']
+        if not state['isGameStarted']:
+            self.should_start = True
+            self.states.clear()
+        else:
+            self.should_start = False
+
+        nearest_obstacle = min(state['obstacles'], key=lambda o: o["distanceX"])
+        curr_observation = np.array([
+            state['birdY'],
+            state['birdSpeedY'],
+            state['gravity'],
+            state['jumpPowerY'],
+            nearest_obstacle['distanceX'],
+            nearest_obstacle['centerGapY']
+        ])
+
+        for i in range(6):
+            if i == 1:
+                if curr_observation[i] < 0:
+                    curr_observation[i] = (curr_observation[i] + 20) / 20 - 1
+                else:
+                    curr_observation[i] = curr_observation[i] / 90
+            elif i == 4:
+                curr_observation[i] = 2 * ((curr_observation[i] - self.min_values[i]) /
+                                           (self.max_values[i] - self.min_values[i])) - 1
+            else:
+                curr_observation[i] = ((curr_observation[i] - self.min_values[i]) /
+                                       (self.max_values[i] - self.min_values[i]))
+
+        return self.state_stack(curr_observation)
+
+    def return_prediction(self, data: dict) -> dict:
+        obs = self.prepare_observation(data)
+        if not self.should_start:
+            action, _states = self.model.predict(
+                observation=obs,
+                deterministic=True
+            )
+            return {'jump': int(action)}
+        return {'jump': 1}
diff --git a/src/agents/web_pong.py b/src/agents/web_pong.py
new file mode 100644
index 0000000..1e4c69f
--- /dev/null
+++ b/src/agents/web_pong.py
@@ -0,0 +1,51 @@
+from src.agents.web_env import WebsocketAgent
+from stable_baselines3.common.base_class import BaseAlgorithm
+
+import numpy as np
+
+
+class PongAgent(WebsocketAgent):
+    def __init__(self, model: BaseAlgorithm, history_length: int):
+        super().__init__(model, history_length)
+        self.min_values = np.array([0, 0, 0, 0, -100, -100], dtype=np.float32)
+        self.max_values = np.array([600, 600, 1000, 600, 100, 100], dtype=np.float32)
+        self.action_map = {0: -1, 1: 0, 2: 1}
+
+    def prepare_observation(self, data: dict) -> np.array:
+        player = data['playerId']
+        state = data['state']
+        if player == 0:
+            curr_observation = np.array([
+                state['leftPaddleY'],
+                state['rightPaddleY'],
+                state['ballX'],
+                state['ballY'],
+                state['ballSpeedX'],
+                state['ballSpeedY'],
+            ], dtype=np.float32)
+        else:
+            curr_observation = np.array([
+                state['rightPaddleY'],
+                state['leftPaddleY'],
+                1000 - state['ballX'],
+                state['ballY'],
+                -state['ballSpeedX'],
+                state['ballSpeedY'],
+            ], dtype=np.float32)
+
+        curr_observation[:4] = ((curr_observation[:4] - self.min_values[:4]) /
+                                (self.max_values[:4] - self.min_values[:4]))
+
+        curr_observation[4:] = 2 * ((curr_observation[4:] - self.min_values[4:]) /
+                                    (self.max_values[4:] - self.min_values[4:])) - 1
+
+        return self.state_stack(curr_observation)
+
+    def return_prediction(self, data: dict) -> dict:
+        obs = self.prepare_observation(data)
+        action, _states = self.model.predict(
+            observation=obs,
+            deterministic=True
+        )
+        move = self.action_map[int(action)]
+        return {'move': move, 'start': 1}
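Both agents normalize positions to [0, 1] and velocities to [-1, 1]; `PongAgent` additionally mirrors the board for player 1 (x is flipped around the 1000-pixel width and the horizontal ball speed is negated), so one left-side policy serves both paddles. A worked numeric check (illustrative only):

```python
# Player 1 sees the ball at x=300 moving right at +50; mirroring turns this
# into the left-player view (x=700, speed -50) before scaling.
import numpy as np

min_v = np.array([0, 0, 0, 0, -100, -100], dtype=np.float32)
max_v = np.array([600, 600, 1000, 600, 100, 100], dtype=np.float32)

obs = np.array([200, 400, 1000 - 300, 150, -50, 20], dtype=np.float32)
obs[:4] = (obs[:4] - min_v[:4]) / (max_v[:4] - min_v[:4])          # positions -> [0, 1]
obs[4:] = 2 * (obs[4:] - min_v[4:]) / (max_v[4:] - min_v[4:]) - 1  # speeds -> [-1, 1]
print(obs)  # [0.3333 0.6667 0.7 0.25 -0.5 0.2]
```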
diff --git a/src/api.py b/src/api.py
index cd14291..e3eb226 100755
--- a/src/api.py
+++ b/src/api.py
@@ -75,8 +75,9 @@ def on_close(self):
         self.after_close()
 
+    @abstractmethod
     def after_close(self):
-        pass
+        raise NotImplementedError
 
     @final
     def on_message(self, message):  # do not override
diff --git a/src/env_sim/__init__.py b/src/env_sim/__init__.py
deleted file mode 100755
index e69de29..0000000
diff --git a/src/env_sim/general.py b/src/env_sim/general.py
deleted file mode 100755
index d5e41d0..0000000
--- a/src/env_sim/general.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from collections import deque
-
-import numpy as np
-
-
-def action_map(action: int, first: int, last: int) -> int:
-    num_actions = last - first + 1
-    return first + action % num_actions
-
-
-def state_stack(obs: dict, states: deque) -> np.array:
-    if len(states) == 0:
-        for _ in range(states.maxlen):
-            states.append(obs)
-    else:
-        states.append(obs)
-
-    return np.array(states).flatten()
diff --git a/src/env_sim/web_pong.py b/src/env_sim/web_pong.py
deleted file mode 100755
index f513424..0000000
--- a/src/env_sim/web_pong.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import numpy as np
-
-
-def prepare_pong_obs(data: dict) -> np.array:
-    min_values = np.array([0, 0, 0, 0, -100, -100], dtype=np.float32)
-    max_values = np.array([600, 600, 1000, 600, 100, 100], dtype=np.float32)
-
-    player = data['playerId']
-    state = data['state']
-    if player == 0:
-        curr_observation = np.array([
-            state['leftPaddleY'],
-            state['rightPaddleY'],
-            state['ballX'],
-            state['ballY'],
-            state['ballSpeedX'],
-            state['ballSpeedY'],
-        ], dtype=np.float32)
-    else:
-        curr_observation = np.array([
-            state['rightPaddleY'],
-            state['leftPaddleY'],
-            1000 - state['ballX'],
-            state['ballY'],
-            -state['ballSpeedX'],
-            state['ballSpeedY'],
-        ], dtype=np.float32)
-
-    curr_observation[:4] = ((curr_observation[:4] - min_values[:4]) / (max_values[:4] - min_values[:4]))
-
-    curr_observation[4:] = 2 * ((curr_observation[4:] - min_values[4:]) / (max_values[4:] - min_values[4:])) - 1
-
-    return curr_observation
diff --git a/src/handlers.py b/src/handlers.py
index 74b1447..06ede87 100755
--- a/src/handlers.py
+++ b/src/handlers.py
@@ -1,48 +1,24 @@
-import json
+from src.api import BaseHandler
+from src.agents.web_env import WebsocketAgent
 from collections import deque
 from typing import Callable
 
 import numpy as np
-from stable_baselines3.common.base_class import BaseAlgorithm
-
-from src.api import BaseHandler
-from src.env_sim.general import state_stack, action_map
+import json
 
 
 class AiHandler(BaseHandler):
     def initialize(
         self,
-        model: BaseAlgorithm,
-        obs_funct: Callable[[dict], np.array],
-        move_first: int,
-        move_last: int,
-        history_length: int = 1
+        agent: WebsocketAgent,
     ):
-        if history_length < 1:
-            raise ValueError("history_length must be an integer greater than or equal to 1")
-
-        self.model = model
-        self.prepare_obs = obs_funct
-        self.states = deque(maxlen=history_length)
-        self.move_first = move_first
-        self.move_last = move_last
+        self.agent = agent
 
     def after_close(self):
-        self.states.clear()
+        self.agent.states.clear()
 
     def send_message(self, message):
         data = json.loads(message)
-        obs = self.prepare_obs(data)
-        obs = state_stack(obs=obs, states=self.states)
-        action, _states = self.model.predict(
-            observation=obs,
-            deterministic=True
-        )
-        move = action_map(
-            action=int(action),
-            first=self.move_first,
-            last=self.move_last
-        )
-        if move is not None:
-            self.write_message(json.dumps({'move': move, 'start': 1}))
+        action = self.agent.return_prediction(data)
+        self.write_message(json.dumps(action))
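End to end, `AiHandler.send_message` now just parses the incoming frame, delegates to the agent, and writes the agent's dict back as JSON. A hypothetical smoke test against a locally running server (the URL and port are assumptions, not part of this diff):

```python
# Drives the pong-ppo route with a single frame and prints the reply,
# which PongAgent formats as {"move": -1|0|1, "start": 1}.
import asyncio
import json

from tornado.websocket import websocket_connect

async def main():
    conn = await websocket_connect("ws://localhost:8888/ws/pong/pong-ppo/")
    await conn.write_message(json.dumps({
        "playerId": 0,
        "state": {
            "leftPaddleY": 300, "rightPaddleY": 300,
            "ballX": 500, "ballY": 300,
            "ballSpeedX": 60, "ballSpeedY": -30,
        },
    }))
    print(await conn.read_message())

asyncio.run(main())
```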