Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 6 #14

Merged
merged 6 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
Binary file not shown.
File renamed without changes.
Binary file not shown.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ setuptools==66
wheel==0.38.0
opencv-python==4.10.0.84
gym==0.21.0
stable-baselines3[extra]==1.5.0
stable-baselines3[extra]==1.5.0
sb3-contrib==1.5.0
protobuf==3.20.*
numpy==1.26.4
tornado~=6.4.1
Expand Down
65 changes: 20 additions & 45 deletions scripts/define_routes.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,43 @@
import os
from typing import List, Tuple, Type

from stable_baselines3 import DQN, PPO, A2C

from src.api import RoutesHandler
from src.bots import PongBot, FlappybirdBot, SkijumpBot
from src.env_sim.web_pong import prepare_pong_obs
from src.handlers import AiHandler
from src.agents.web_pong import PongAgent
from src.agents.web_flappy_bird import FlappyBirdAgent
from scripts.load_models import (
dqn_pong,
ppo_pong,
a2c_pong,
ppo_fb,
trpo_fb
)


def define_routes() -> List[Tuple[str, Type, dict]]:
dqn_path = os.path.join('trained-agents', 'dqn')
ppo_path = os.path.join('trained-agents', 'ppo')
a2c_path = os.path.join('trained-agents', 'a2c')

dqn_pong_path = os.path.join(
dqn_path,
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)
dqn_pong = DQN.load(path=dqn_pong_path)

ppo_pong_path = os.path.join(
ppo_path,
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)
ppo_pong = PPO.load(path=ppo_pong_path)

a2c_pong_path = os.path.join(
a2c_path,
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)
a2c_pong = A2C.load(path=a2c_pong_path)

routes = []
pong_routes = [
(r"/ws/pong/pong-dqn/", AiHandler, dict(
model=dqn_pong,
obs_funct=prepare_pong_obs,
move_first=-1,
move_last=1,
history_length=3
agent=PongAgent(dqn_pong, 3)
)),
(r"/ws/pong/pong-ppo/", AiHandler, dict(
model=ppo_pong,
obs_funct=prepare_pong_obs,
move_first=-1,
move_last=1,
history_length=3
agent=PongAgent(ppo_pong, 3)
)),
(r"/ws/pong/pong-a2c/", AiHandler, dict(
model=a2c_pong,
obs_funct=prepare_pong_obs,
move_first=-1,
move_last=1,
history_length=3
agent=PongAgent(a2c_pong, 3)
)),
(r"/ws/pong/pong-bot/", PongBot),
]

flappybird_routes = [
(r"/ws/flappybird/flappybird-ppo/", AiHandler, dict(
agent=FlappyBirdAgent(ppo_fb, 3)
)),
(r"/ws/flappybird/flappybird-trpo/", AiHandler, dict(
agent=FlappyBirdAgent(trpo_fb, 3)
)),
(r"/ws/flappybird/flappybird-bot/", FlappybirdBot),
]

skijump_routes = [
(r"/ws/skijump/skijump-bot/", SkijumpBot)
]
Expand Down
13 changes: 13 additions & 0 deletions scripts/load_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from stable_baselines3 import DQN, PPO, A2C
from sb3_contrib import TRPO
from scripts.paths.pong_paths import dqn_pong_path, ppo_pong_path, a2c_pong_path
from scripts.paths.flappy_bird_paths import ppo_fb_path, trpo_fb_path

# Trained agents are deserialised eagerly at import time so that route
# definitions can reference ready-to-use model objects.
# NOTE(review): importing this module therefore performs file I/O — confirm
# the checkpoint files exist relative to the process working directory.

# Pong
dqn_pong = DQN.load(path=dqn_pong_path)
ppo_pong = PPO.load(path=ppo_pong_path)
a2c_pong = A2C.load(path=a2c_pong_path)

# Flappy Bird
ppo_fb = PPO.load(path=ppo_fb_path)
trpo_fb = TRPO.load(path=trpo_fb_path)
6 changes: 6 additions & 0 deletions scripts/paths/algos_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os

# Root directory that holds the pre-trained model checkpoints, with one
# sub-directory per algorithm.
_MODELS_ROOT = 'ready-models'

# Per-algorithm checkpoint directories (e.g. 'ready-models/dqn').
dqn_path = os.path.join(_MODELS_ROOT, 'dqn')
ppo_path = os.path.join(_MODELS_ROOT, 'ppo')
a2c_path = os.path.join(_MODELS_ROOT, 'a2c')
trpo_path = os.path.join(_MODELS_ROOT, 'trpo')
12 changes: 12 additions & 0 deletions scripts/paths/flappy_bird_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from scripts.paths.algos_paths import ppo_path, trpo_path

import os


# Environment id shared by the directory name and the checkpoint file name.
_ENV_ID = 'WebsocketFlappyBird-v0'

# Checkpoint location relative to an algorithm directory.
env_path = os.path.join(_ENV_ID, f'{_ENV_ID}_200000_steps.zip')

# Full paths to the PPO and TRPO Flappy Bird checkpoints.
ppo_fb_path = os.path.join(ppo_path, env_path)
trpo_fb_path = os.path.join(trpo_path, env_path)
12 changes: 12 additions & 0 deletions scripts/paths/pong_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from scripts.paths.algos_paths import a2c_path, dqn_path, ppo_path, trpo_path

import os

# Environment id shared by the directory name and the checkpoint file name.
_ENV_ID = 'WebsocketPong-v0'

# Checkpoint location relative to an algorithm directory.
env_path = os.path.join(_ENV_ID, f'{_ENV_ID}_200000_steps.zip')

# Full paths to the DQN, PPO and A2C Pong checkpoints.
dqn_pong_path = os.path.join(dqn_path, env_path)
ppo_pong_path = os.path.join(ppo_path, env_path)
a2c_pong_path = os.path.join(a2c_path, env_path)
Empty file removed src/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions src/agents/web_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from collections import deque
from typing import final
from stable_baselines3.common.base_class import BaseAlgorithm

import numpy as np


class WebsocketAgent(ABC):
    """Base class for websocket game agents wrapping a trained SB3 model.

    Keeps a rolling window of the last ``history_length`` observations so
    that models trained on stacked observations receive the input length
    they expect.
    """

    def __init__(self, model: BaseAlgorithm, history_length: int = 1):
        """Store the trained *model* and size the observation history.

        :param model: trained stable-baselines3 algorithm used for inference
        :param history_length: number of past observations to stack (>= 1)
        :raises ValueError: if ``history_length`` is less than 1
        """
        if history_length < 1:
            raise ValueError("history_length must be an integer greater than or equal to 1")
        self.model = model
        # Bounded FIFO of the most recent observations; the oldest entry is
        # dropped automatically once maxlen is reached.
        self.states = deque(maxlen=history_length)

    def start(self) -> bool:
        # Optional hook; deliberately NOT @abstractmethod so subclasses
        # without a "start" concept remain instantiable.
        raise NotImplementedError

    @abstractmethod
    def prepare_observation(self, data: dict) -> np.ndarray:
        """Convert a raw websocket message into a model observation."""
        raise NotImplementedError

    @final
    def state_stack(self, observation: np.ndarray) -> np.ndarray:
        """Append *observation* to the history and return the flat stack.

        On the very first call the deque is seeded with ``maxlen`` copies of
        the same observation, so the flattened output always has the full
        stacked length.
        """
        if len(self.states) == 0:
            for _ in range(self.states.maxlen):
                self.states.append(observation)
        else:
            self.states.append(observation)
        return np.array(self.states).flatten()

    @abstractmethod
    def return_prediction(self, data: dict) -> dict:
        """Return the action payload to send back over the websocket."""
        raise NotImplementedError
55 changes: 55 additions & 0 deletions src/agents/web_flappy_bird.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from src.agents.web_env import WebsocketAgent
from stable_baselines3.common.base_class import BaseAlgorithm

import numpy as np


class FlappyBirdAgent(WebsocketAgent):
    """Websocket agent for the Flappy Bird game.

    Normalises the raw game state into the feature ranges the model was
    trained on and stacks the last ``history_length`` observations.
    """

    def __init__(self, model: BaseAlgorithm, history_length: int):
        """Store *model* and the per-feature normalisation bounds.

        Feature order: birdY, birdSpeedY, gravity, jumpPowerY,
        nearest-obstacle distanceX, nearest-obstacle centerGapY.
        """
        super().__init__(model, history_length)
        # Lower/upper bounds used for min-max scaling of each feature.
        # NOTE(review): index 1 (birdSpeedY) is normalised with hard-coded
        # constants 20 and 90 in prepare_observation, not with these
        # bounds — confirm that is intentional.
        self.min_values = np.array([0, -20, 0.5, 5, -50, 100], dtype=np.float32)
        self.max_values = np.array([600, 90, 1, 15, 1900, 500], dtype=np.float32)
        # Set while the game is waiting to start; forces a "jump" answer.
        self.should_start = False

    def prepare_observation(self, data: dict) -> np.ndarray:
        """Turn a websocket game-state message into a stacked observation.

        Clears the observation history whenever the game has not started,
        so each new round begins with a freshly seeded state stack.
        """
        state = data['state']
        if not state['isGameStarted']:
            self.should_start = True
            self.states.clear()
        else:
            self.should_start = False

        # Only the closest obstacle ahead is relevant to the decision.
        nearest_obstacle = min(state['obstacles'], key=lambda o: o["distanceX"])
        # NOTE(review): unlike PongAgent no explicit dtype is given here, so
        # this array defaults to float64 — confirm the model accepts that.
        curr_observation = np.array([
            state['birdY'],
            state['birdSpeedY'],
            state['gravity'],
            state['jumpPowerY'],
            nearest_obstacle['distanceX'],
            nearest_obstacle['centerGapY']
        ])

        # Per-feature scaling:
        #   i == 1 (vertical speed): piecewise — negative values map via
        #       (v + 20) / 20 - 1 (i.e. v / 20), positive via v / 90,
        #       presumably because fall and climb speeds span different
        #       ranges.
        #   i == 4 (obstacle distance): min-max scaled to [-1, 1].
        #   otherwise: min-max scaled to [0, 1].
        for i in range(6):
            if i == 1:
                if curr_observation[i] < 0:
                    curr_observation[i] = (curr_observation[i] + 20) / 20 - 1
                else:
                    curr_observation[i] = curr_observation[i] / 90
            elif i == 4:
                curr_observation[i] = 2 * ((curr_observation[i] - self.min_values[i]) /
                                           (self.max_values[i] - self.min_values[i])) - 1
            else:
                curr_observation[i] = ((curr_observation[i] - self.min_values[i]) /
                                       (self.max_values[i] - self.min_values[i]))

        return self.state_stack(curr_observation)

    def return_prediction(self, data: dict) -> dict:
        """Return ``{'jump': 0|1}`` for the incoming game state.

        While the game has not started, always answer with a jump — this
        presumably triggers the game start on the client side; verify.
        """
        obs = self.prepare_observation(data)
        if not self.should_start:
            action, _states = self.model.predict(
                observation=obs,
                deterministic=True
            )
            return {'jump': int(action)}
        return {'jump': 1}
51 changes: 51 additions & 0 deletions src/agents/web_pong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from src.agents.web_env import WebsocketAgent
from stable_baselines3.common.base_class import BaseAlgorithm

import numpy as np


class PongAgent(WebsocketAgent):
    """Websocket agent for the Pong game.

    Observations are mirrored for the right-hand player so the model always
    "sees" itself as the left paddle, then min-max normalised before being
    stacked into the observation history.
    """

    def __init__(self, model: BaseAlgorithm, history_length: int):
        """Store *model* and the per-feature normalisation bounds.

        Feature order: own paddle Y, opponent paddle Y, ball X, ball Y,
        ball speed X, ball speed Y.
        """
        super().__init__(model, history_length)
        # Bounds for min-max scaling: the first four features are positions,
        # the last two are speed components.
        self.min_values = np.array([0, 0, 0, 0, -100, -100], dtype=np.float32)
        self.max_values = np.array([600, 600, 1000, 600, 100, 100], dtype=np.float32)
        # Discrete model action -> paddle move direction (-1 up/0 stay/1 down
        # or vice versa; direction semantics defined by the game client).
        self.action_map = {0: -1, 1: 0, 2: 1}

    def prepare_observation(self, data: dict) -> np.ndarray:
        """Build a normalised, history-stacked observation from *data*."""
        player = data['playerId']
        state = data['state']
        if player == 0:
            curr_observation = np.array([
                state['leftPaddleY'],
                state['rightPaddleY'],
                state['ballX'],
                state['ballY'],
                state['ballSpeedX'],
                state['ballSpeedY'],
            ], dtype=np.float32)
        else:
            # Mirror the board horizontally (field width presumably 1000 —
            # matches max ballX bound) so player 1 is treated as if playing
            # on the left side.
            curr_observation = np.array([
                state['rightPaddleY'],
                state['leftPaddleY'],
                1000 - state['ballX'],
                state['ballY'],
                -state['ballSpeedX'],
                state['ballSpeedY'],
            ], dtype=np.float32)

        # Positions -> [0, 1]; speeds -> [-1, 1].
        curr_observation[:4] = ((curr_observation[:4] - self.min_values[:4]) /
                                (self.max_values[:4] - self.min_values[:4]))

        curr_observation[4:] = 2 * ((curr_observation[4:] - self.min_values[4:]) /
                                    (self.max_values[4:] - self.min_values[4:])) - 1

        return self.state_stack(curr_observation)

    def return_prediction(self, data: dict) -> dict:
        """Return ``{'move': -1|0|1, 'start': 1}`` for the incoming state."""
        obs = self.prepare_observation(data)
        action, _ = self.model.predict(
            observation=obs,
            deterministic=True
        )
        move = self.action_map[int(action)]
        return {'move': move, 'start': 1}
3 changes: 2 additions & 1 deletion src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def on_close(self):

self.after_close()

@abstractmethod
def after_close(self):
pass
raise NotImplementedError

@final
def on_message(self, message): # do not override
Expand Down
Empty file removed src/env_sim/__init__.py
Empty file.
18 changes: 0 additions & 18 deletions src/env_sim/general.py

This file was deleted.

33 changes: 0 additions & 33 deletions src/env_sim/web_pong.py

This file was deleted.

Loading
Loading