-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from KN-GEST-ongit/dev
Dev
- Loading branch information
Showing
24 changed files
with
245 additions
and
147 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
Binary file added
BIN
+151 KB
ready-models/ppo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+151 KB
ready-models/trpo/WebsocketFlappyBird-v0/WebsocketFlappyBird-v0_200000_steps.zip
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from stable_baselines3 import DQN, PPO, A2C | ||
from sb3_contrib import TRPO, QRDQN | ||
from scripts.paths.pong_paths import dqn_pong_path, ppo_pong_path, a2c_pong_path, trpo_pong_path, qrdqn_pong_path | ||
from scripts.paths.flappy_bird_paths import ppo_fb_path, trpo_fb_path | ||
|
||
# --- Pong -------------------------------------------------------------------
# One model per algorithm, each restored from its checkpoint path at import
# time; importing this module therefore reads every checkpoint from disk.
dqn_pong = DQN.load(dqn_pong_path)
ppo_pong = PPO.load(ppo_pong_path)
a2c_pong = A2C.load(a2c_pong_path)
trpo_pong = TRPO.load(trpo_pong_path)
qrdqn_pong = QRDQN.load(qrdqn_pong_path)


# --- Flappy Bird ------------------------------------------------------------
# Only PPO and TRPO checkpoint paths are imported for this game.
ppo_fb = PPO.load(ppo_fb_path)
trpo_fb = TRPO.load(trpo_fb_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import os | ||
|
||
# Directory (relative to the working directory) that holds one sub-directory
# per algorithm.
_MODELS_ROOT = 'ready-models'

dqn_path = os.path.join(_MODELS_ROOT, 'dqn')
ppo_path = os.path.join(_MODELS_ROOT, 'ppo')
a2c_path = os.path.join(_MODELS_ROOT, 'a2c')
trpo_path = os.path.join(_MODELS_ROOT, 'trpo')
qrdqn_path = os.path.join(_MODELS_ROOT, 'qrdqn')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from scripts.paths.algos_paths import ppo_path, trpo_path | ||
|
||
import os | ||
|
||
|
||
# Checkpoint location relative to an algorithm directory:
# <environment id>/<environment id>_<step count>_steps.zip
env_path = os.path.join('WebsocketFlappyBird-v0', 'WebsocketFlappyBird-v0_200000_steps.zip')

# Same relative checkpoint, resolved under each algorithm's directory.
ppo_fb_path, trpo_fb_path = (
    os.path.join(algo_dir, env_path) for algo_dir in (ppo_path, trpo_path)
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from scripts.paths.algos_paths import a2c_path, dqn_path, ppo_path, trpo_path, qrdqn_path | ||
|
||
import os | ||
|
||
# Checkpoint location relative to an algorithm directory:
# <environment id>/<environment id>_<step count>_steps.zip
env_path = os.path.join('WebsocketPong-v0', 'WebsocketPong-v0_200000_steps.zip')

# Same relative checkpoint, resolved under each algorithm's directory.
# The order of the targets matches the order of the directories.
dqn_pong_path, ppo_pong_path, a2c_pong_path, trpo_pong_path, qrdqn_pong_path = (
    os.path.join(algo_dir, env_path)
    for algo_dir in (dqn_path, ppo_path, a2c_path, trpo_path, qrdqn_path)
)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from abc import ABC, abstractmethod | ||
from collections import deque | ||
from typing import final | ||
from stable_baselines3.common.base_class import BaseAlgorithm | ||
|
||
import numpy as np | ||
|
||
|
||
class WebsocketAgent(ABC):
    """Abstract base for agents that turn game-state messages into predictions.

    The base class owns a fixed-length history of recent observations and
    exposes it as a single flat vector through :meth:`state_stack`.
    Subclasses provide the game-specific parts: :meth:`prepare_observation`
    and :meth:`return_prediction`.
    """

    def __init__(self, model: "BaseAlgorithm", history_length: int = 1):
        """Store the model and create an empty observation history.

        Args:
            model: Model object kept on ``self.model`` for subclasses to query.
            history_length: Number of most recent observations kept in the
                history.

        Raises:
            ValueError: If ``history_length`` is less than 1.
        """
        if history_length < 1:
            raise ValueError("history_length must be an integer greater than or equal to 1")
        self.model = model
        # Bounded deque: appending to a full history drops the oldest entry.
        self.states = deque(maxlen=history_length)

    def start(self) -> bool:
        """Not implemented here; always raises ``NotImplementedError``.

        Unlike the two abstract methods, subclasses are not forced to
        override this one.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_observation(self, data: dict) -> np.ndarray:
        """Build the observation for one incoming message.

        Args:
            data: Incoming message; its schema is defined by the subclass.

        Returns:
            The observation array for the current message.
        """
        raise NotImplementedError

    @final
    def state_stack(self, observation: np.ndarray) -> np.ndarray:
        """Push ``observation`` into the history and return the flattened history.

        Args:
            observation: Newest observation.

        Returns:
            A 1-D array holding every stored observation, oldest first.
        """
        if not self.states:
            # Empty history: pad every slot with the first observation so the
            # returned vector has its full length from the very first call.
            self.states.extend([observation] * self.states.maxlen)
        else:
            self.states.append(observation)
        return np.array(self.states).flatten()

    @abstractmethod
    def return_prediction(self, data: dict) -> dict:
        """Produce the response for one incoming message.

        Args:
            data: Incoming message; its schema is defined by the subclass.

        Returns:
            The response dictionary; its keys are defined by the subclass.
        """
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from src.agents.web_env import WebsocketAgent | ||
from stable_baselines3.common.base_class import BaseAlgorithm | ||
|
||
import numpy as np | ||
|
||
|
||
class FlappyBirdAgent(WebsocketAgent):
    """Agent that answers Flappy Bird state messages with a jump decision."""

    def __init__(self, model: BaseAlgorithm, history_length: int):
        """Set up the normalisation bounds and the start flag.

        Args:
            model: Model queried for actions in :meth:`return_prediction`.
            history_length: Forwarded to the base class constructor.
        """
        super().__init__(model, history_length)
        # Per-feature bounds in observation order: birdY, birdSpeedY, gravity,
        # jumpPowerY, nearest obstacle distanceX, nearest obstacle centerGapY.
        self.min_values = np.array([0, -20, 0.5, 5, -50, 100], dtype=np.float32)
        self.max_values = np.array([600, 90, 1, 15, 1900, 500], dtype=np.float32)
        # True while the last message seen reported the game as not started.
        self.should_start = False

    def prepare_observation(self, data: dict) -> np.ndarray:
        """Build, normalise and stack the observation for one state message.

        Also updates ``self.should_start`` and clears the observation history
        whenever the message reports that the game has not started.

        Args:
            data: Message whose ``'state'`` entry provides ``isGameStarted``,
                ``obstacles`` (items with ``distanceX`` and ``centerGapY``),
                ``birdY``, ``birdSpeedY``, ``gravity`` and ``jumpPowerY``.

        Returns:
            The flattened observation history from ``state_stack``.

        Raises:
            ValueError: If no obstacle has a positive ``distanceX``.
        """
        state = data['state']
        self.should_start = not state['isGameStarted']
        if self.should_start:
            self.states.clear()

        # Obstacle with the smallest positive horizontal distance.
        nearest_obstacle = min(
            (o for o in state['obstacles'] if o["distanceX"] > 0),
            key=lambda o: o["distanceX"],
            default=None,
        )
        if nearest_obstacle is None:
            raise ValueError("state['obstacles'] contains no obstacle with distanceX > 0")

        # An explicit float dtype is required: if every value arrives as an
        # int, NumPy would infer an integer array and the normalised values
        # would be truncated. float64 matches what NumPy infers for float input.
        curr_observation = np.array([
            state['birdY'],
            state['birdSpeedY'],
            state['gravity'],
            state['jumpPowerY'],
            nearest_obstacle['distanceX'],
            nearest_obstacle['centerGapY'],
        ], dtype=np.float64)

        return self.state_stack(self._normalize(curr_observation))

    def _normalize(self, raw: np.ndarray) -> np.ndarray:
        """Scale one raw six-feature observation with the stored bounds."""
        # Default scaling: (value - min) / (max - min) for every feature.
        scaled = (raw - self.min_values) / (self.max_values - self.min_values)

        # Obstacle distance is stretched to twice the range and shifted by -1.
        scaled[4] = 2 * scaled[4] - 1

        # Vertical speed is scaled separately on each side of zero; the
        # constants 20 and 90 mirror the bounds stored at index 1.
        speed = raw[1]
        scaled[1] = (speed + 20) / 20 - 1 if speed < 0 else speed / 90
        return scaled

    def return_prediction(self, data: dict) -> dict:
        """Return the jump decision for one state message.

        Args:
            data: Message passed on to :meth:`prepare_observation`.

        Returns:
            ``{'jump': 1}`` while the game has not started (presumably the
            jump is what starts it; confirm against the game server),
            otherwise ``{'jump': <model action as int>}``.
        """
        obs = self.prepare_observation(data)
        if self.should_start:
            return {'jump': 1}
        action, _ = self.model.predict(
            observation=obs,
            deterministic=True
        )
        return {'jump': int(action)}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from src.agents.web_env import WebsocketAgent | ||
from stable_baselines3.common.base_class import BaseAlgorithm | ||
|
||
import numpy as np | ||
|
||
|
||
class PongAgent(WebsocketAgent):
    """Agent that answers Pong state messages with a paddle move."""

    def __init__(self, model: BaseAlgorithm, history_length: int):
        """Set up the normalisation bounds and the action lookup table.

        Args:
            model: Model queried for actions in :meth:`return_prediction`.
            history_length: Forwarded to the base class constructor.
        """
        super().__init__(model, history_length)
        # Per-feature bounds in observation order: own paddle Y, other paddle
        # Y, ball X, ball Y, ball speed X, ball speed Y.
        self.min_values = np.array([0, 0, 0, 0, -100, -100], dtype=np.float32)
        self.max_values = np.array([600, 600, 1000, 600, 100, 100], dtype=np.float32)
        # Model action index -> value sent back under the 'move' key.
        self.action_map = {0: -1, 1: 0, 2: 1}

    def prepare_observation(self, data: dict) -> np.array:
        """Build, normalise and stack the observation for one state message.

        For any ``playerId`` other than 0 the state is mirrored: the paddles
        swap places, ``ballX`` becomes ``1000 - ballX`` and ``ballSpeedX`` is
        negated.

        Args:
            data: Message with ``'playerId'`` and a ``'state'`` entry holding
                the paddle and ball values.

        Returns:
            The flattened observation history from ``state_stack``.
        """
        player_id = data['playerId']
        state = data['state']

        if player_id == 0:
            features = [
                state['leftPaddleY'],
                state['rightPaddleY'],
                state['ballX'],
                state['ballY'],
                state['ballSpeedX'],
                state['ballSpeedY'],
            ]
        else:
            features = [
                state['rightPaddleY'],
                state['leftPaddleY'],
                1000 - state['ballX'],
                state['ballY'],
                -state['ballSpeedX'],
                state['ballSpeedY'],
            ]
        observation = np.array(features, dtype=np.float32)

        # (value - min) / (max - min) for every feature ...
        lows, highs = self.min_values, self.max_values
        scaled = (observation - lows) / (highs - lows)
        # ... then the two speed features are stretched by 2 and shifted by -1.
        scaled[4:] = 2 * scaled[4:] - 1

        return self.state_stack(scaled)

    def return_prediction(self, data: dict) -> dict:
        """Return the paddle move for one state message.

        Args:
            data: Message passed on to :meth:`prepare_observation`.

        Returns:
            ``{'move': <mapped action>, 'start': 1}``; ``'start'`` is always 1.
        """
        observation = self.prepare_observation(data)
        action, _ = self.model.predict(observation=observation, deterministic=True)
        return {'move': self.action_map[int(action)], 'start': 1}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.