Skip to content

Commit

Permalink
Merge pull request #15 from KN-GEST-ongit/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
bkrowka authored Dec 11, 2024
2 parents d128346 + 56e3401 commit 0134546
Show file tree
Hide file tree
Showing 24 changed files with 245 additions and 147 deletions.
File renamed without changes.
File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ setuptools==66
wheel==0.38.0
opencv-python==4.10.0.84
gym==0.21.0
stable-baselines3[extra]==1.5.0
stable-baselines3[extra]==1.5.
sb3-contrib==1.5.0
protobuf==3.20.*
numpy==1.26.4
tornado~=6.4.1
Expand Down
77 changes: 30 additions & 47 deletions scripts/define_routes.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,51 @@
import os
from typing import List, Tuple, Type

from stable_baselines3 import DQN, PPO, A2C

from src.api import RoutesHandler
from src.bots import PongBot, FlappybirdBot, SkijumpBot
from src.env_sim.web_pong import prepare_pong_obs
from src.handlers import AiHandler
from src.agents.web_pong import PongAgent
from src.agents.web_flappy_bird import FlappyBirdAgent
from scripts.load_models import (
dqn_pong,
ppo_pong,
a2c_pong,
trpo_pong,
qrdqn_pong,
ppo_fb,
trpo_fb
)


def define_routes() -> List[Tuple[str, Type, dict]]:
dqn_path = os.path.join('trained-agents', 'dqn')
ppo_path = os.path.join('trained-agents', 'ppo')
a2c_path = os.path.join('trained-agents', 'a2c')

dqn_pong_path = os.path.join(
dqn_path,
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)
dqn_pong = DQN.load(path=dqn_pong_path)

ppo_pong_path = os.path.join(
ppo_path,
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)
ppo_pong = PPO.load(path=ppo_pong_path)

a2c_pong_path = os.path.join(
a2c_path,
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)
a2c_pong = A2C.load(path=a2c_pong_path)

routes = []
pong_routes = [
(r"/ws/pong/pong-bot/", PongBot),
(r"/ws/pong/pong-a2c/", AiHandler, dict(
agent=PongAgent(a2c_pong, 3)
)),
(r"/ws/pong/pong-dqn/", AiHandler, dict(
model=dqn_pong,
obs_funct=prepare_pong_obs,
move_first=-1,
move_last=1,
history_length=3
agent=PongAgent(dqn_pong, 3)
)),
(r"/ws/pong/pong-ppo/", AiHandler, dict(
model=ppo_pong,
obs_funct=prepare_pong_obs,
move_first=-1,
move_last=1,
history_length=3
agent=PongAgent(ppo_pong, 3)
)),
(r"/ws/pong/pong-a2c/", AiHandler, dict(
model=a2c_pong,
obs_funct=prepare_pong_obs,
move_first=-1,
move_last=1,
history_length=3
(r"/ws/pong/pong-qrdqn/", AiHandler, dict(
agent=PongAgent(qrdqn_pong, 3)
)),
(r"/ws/pong/pong-bot/", PongBot),
(r"/ws/pong/pong-trpo/", AiHandler, dict(
agent=PongAgent(trpo_pong, 3)
))
]

flappybird_routes = [
(r"/ws/flappybird/flappybird-bot/", FlappybirdBot),
(r"/ws/flappybird/flappybird-ppo/", AiHandler, dict(
agent=FlappyBirdAgent(ppo_fb, 3)
)),
(r"/ws/flappybird/flappybird-trpo/", AiHandler, dict(
agent=FlappyBirdAgent(trpo_fb, 3)
))
]

skijump_routes = [
(r"/ws/skijump/skijump-bot/", SkijumpBot)
]
Expand Down
16 changes: 16 additions & 0 deletions scripts/load_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from stable_baselines3 import DQN, PPO, A2C
from sb3_contrib import TRPO, QRDQN
from scripts.paths.pong_paths import dqn_pong_path, ppo_pong_path, a2c_pong_path, trpo_pong_path, qrdqn_pong_path
from scripts.paths.flappy_bird_paths import ppo_fb_path, trpo_fb_path

# Pong
dqn_pong = DQN.load(path=dqn_pong_path)
ppo_pong = PPO.load(path=ppo_pong_path)
a2c_pong = A2C.load(path=a2c_pong_path)
trpo_pong = TRPO.load(path=trpo_pong_path)
qrdqn_pong = QRDQN.load(path=qrdqn_pong_path)


# Flappy Bird
ppo_fb = PPO.load(path=ppo_fb_path)
trpo_fb = TRPO.load(path=trpo_fb_path)
7 changes: 7 additions & 0 deletions scripts/paths/algos_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

dqn_path = os.path.join('ready-models', 'dqn')
ppo_path = os.path.join('ready-models', 'ppo')
a2c_path = os.path.join('ready-models', 'a2c')
trpo_path = os.path.join('ready-models', 'trpo')
qrdqn_path = os.path.join('ready-models', 'qrdqn')
12 changes: 12 additions & 0 deletions scripts/paths/flappy_bird_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from scripts.paths.algos_paths import ppo_path, trpo_path

import os


env_path = os.path.join(
'WebsocketFlappyBird-v0',
'WebsocketFlappyBird-v0_200000_steps.zip'
)

ppo_fb_path = os.path.join(ppo_path, env_path)
trpo_fb_path = os.path.join(trpo_path, env_path)
14 changes: 14 additions & 0 deletions scripts/paths/pong_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from scripts.paths.algos_paths import a2c_path, dqn_path, ppo_path, trpo_path, qrdqn_path

import os

env_path = os.path.join(
'WebsocketPong-v0',
'WebsocketPong-v0_200000_steps.zip'
)

dqn_pong_path = os.path.join(dqn_path, env_path)
ppo_pong_path = os.path.join(ppo_path, env_path)
a2c_pong_path = os.path.join(a2c_path, env_path)
trpo_pong_path = os.path.join(trpo_path, env_path)
qrdqn_pong_path = os.path.join(qrdqn_path, env_path)
Empty file removed src/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions src/agents/web_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from collections import deque
from typing import final
from stable_baselines3.common.base_class import BaseAlgorithm

import numpy as np


class WebsocketAgent(ABC):
def __init__(self, model: BaseAlgorithm, history_length: int = 1):
if history_length < 1:
raise ValueError("history_length must be an integer greater than or equal to 1")
self.model = model
self.states = deque(maxlen=history_length)

def start(self) -> bool:
raise NotImplementedError

@abstractmethod
def prepare_observation(self, data: dict) -> np.array:
raise NotImplementedError

@final
def state_stack(self, observation: np.array) -> np.array:
if len(self.states) == 0:
for _ in range(self.states.maxlen):
self.states.append(observation)
else:
self.states.append(observation)
return np.array(self.states).flatten()

@abstractmethod
def return_prediction(self, data: dict) -> dict:
raise NotImplementedError
58 changes: 58 additions & 0 deletions src/agents/web_flappy_bird.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from src.agents.web_env import WebsocketAgent
from stable_baselines3.common.base_class import BaseAlgorithm

import numpy as np


class FlappyBirdAgent(WebsocketAgent):
def __init__(self, model: BaseAlgorithm, history_length: int):
super().__init__(model, history_length)
self.min_values = np.array([0, -20, 0.5, 5, -50, 100], dtype=np.float32)
self.max_values = np.array([600, 90, 1, 15, 1900, 500], dtype=np.float32)
self.should_start = False

def prepare_observation(self, data: dict) -> np.array:
state = data['state']
if not state['isGameStarted']:
self.should_start = True
self.states.clear()
else:
self.should_start = False

nearest_obstacle = min(
(o for o in state['obstacles'] if o["distanceX"] > 0),
key=lambda o: o["distanceX"]
)
curr_observation = np.array([
state['birdY'],
state['birdSpeedY'],
state['gravity'],
state['jumpPowerY'],
nearest_obstacle['distanceX'],
nearest_obstacle['centerGapY']
])

for i in range(6):
if i == 1:
if curr_observation[i] < 0:
curr_observation[i] = (curr_observation[i] + 20) / 20 - 1
else:
curr_observation[i] = curr_observation[i] / 90
elif i == 4:
curr_observation[i] = 2 * ((curr_observation[i] - self.min_values[i]) /
(self.max_values[i] - self.min_values[i])) - 1
else:
curr_observation[i] = ((curr_observation[i] - self.min_values[i]) /
(self.max_values[i] - self.min_values[i]))

return self.state_stack(curr_observation)

def return_prediction(self, data: dict) -> dict:
obs = self.prepare_observation(data)
if not self.should_start:
action, _states = self.model.predict(
observation=obs,
deterministic=True
)
return {'jump': int(action)}
return {'jump': 1}
51 changes: 51 additions & 0 deletions src/agents/web_pong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from src.agents.web_env import WebsocketAgent
from stable_baselines3.common.base_class import BaseAlgorithm

import numpy as np


class PongAgent(WebsocketAgent):
def __init__(self, model: BaseAlgorithm, history_length: int):
super().__init__(model, history_length)
self.min_values = np.array([0, 0, 0, 0, -100, -100], dtype=np.float32)
self.max_values = np.array([600, 600, 1000, 600, 100, 100], dtype=np.float32)
self.action_map = {0: -1, 1: 0, 2: 1}

def prepare_observation(self, data: dict) -> np.array:
player = data['playerId']
state = data['state']
if player == 0:
curr_observation = np.array([
state['leftPaddleY'],
state['rightPaddleY'],
state['ballX'],
state['ballY'],
state['ballSpeedX'],
state['ballSpeedY'],
], dtype=np.float32)
else:
curr_observation = np.array([
state['rightPaddleY'],
state['leftPaddleY'],
1000 - state['ballX'],
state['ballY'],
-state['ballSpeedX'],
state['ballSpeedY'],
], dtype=np.float32)

curr_observation[:4] = ((curr_observation[:4] - self.min_values[:4]) /
(self.max_values[:4] - self.min_values[:4]))

curr_observation[4:] = 2 * ((curr_observation[4:] - self.min_values[4:]) /
(self.max_values[4:] - self.min_values[4:])) - 1

return self.state_stack(curr_observation)

def return_prediction(self, data: dict) -> dict:
obs = self.prepare_observation(data)
action, _states = self.model.predict(
observation=obs,
deterministic=True
)
move = self.action_map[int(action)]
return {'move': move, 'start': 1}
2 changes: 1 addition & 1 deletion src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def on_close(self):
self.after_close()

def after_close(self):
pass
raise NotImplementedError

@final
def on_message(self, message): # do not override
Expand Down
27 changes: 12 additions & 15 deletions src/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
class PongBot(BaseHandler):
def send_message(self, message):
data = json.loads(message)
player = data['playerId']
state = data['state']

if data['playerId'] == 0:
if data['state']['ballY'] < data['state']['leftPaddleY'] + 50:
if player == 0:
if state['ballY'] < state['leftPaddleY'] + 50:
move = 1
else:
move = -1
else:
if data['state']['ballY'] < data['state']['rightPaddleY'] + 50:
if state['ballY'] < state['rightPaddleY'] + 50:
move = 1
else:
move = -1
Expand All @@ -23,22 +25,17 @@ def send_message(self, message):

class FlappybirdBot(BaseHandler):
def send_message(self, message):
data = json.loads(message)['state']
state = json.loads(message)['state']
jump = 0

if not data['isGameStarted']:
if not state['isGameStarted']:
jump = 1
else:
obstacles = data['obstacles']
lowestDist = 1000
lowestDistIndex = 0
for i in range(0, len(obstacles)):
dist = obstacles[i]['distanceX']
if lowestDist > dist > 60:
lowestDist = dist
lowestDistIndex = i

if data['birdY'] > obstacles[lowestDistIndex]['centerGapY']:
nearest_obstacle = min(
(o for o in state['obstacles'] if o["distanceX"] > 0),
key=lambda o: o["distanceX"]
)
if state['birdY'] > nearest_obstacle['centerGapY']:
jump = 1

self.write_message(json.dumps({'jump': jump}))
Expand Down
Empty file removed src/env_sim/__init__.py
Empty file.
18 changes: 0 additions & 18 deletions src/env_sim/general.py

This file was deleted.

Loading

0 comments on commit 0134546

Please sign in to comment.