Skip to content

Change rendering to use pygame #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Output
logs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# PyCharm project settings
.idea

# mkdocs documentation
/site

# mypy
.mypy_cache/
79 changes: 79 additions & 0 deletions gym_snake/envs/snake/game_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pygame.locals import *
import pygame
import time
import numpy as np

class GameRender:
window_width = 1000
window_height = 1000

def __init__(self, grid_size):

self.grid_size = grid_size
self._block_size_width = self.window_width / grid_size[0]
self._block_size_height = self.window_height / grid_size[1]

self._running = True
self._display_surf = None
self._snake_surf = None
self._apple_surf = None

self.on_init()

def on_init(self):
pygame.init()
self._display_surf = pygame.display.set_mode((self.window_width, self.window_height), pygame.HWSURFACE)
pygame.display.set_caption('Gym-Snake')
self._running = True

self._snake_head_surf = pygame.Surface([self._block_size_width, self._block_size_height])
self._snake_head_surf.fill((0, 0, 255))

self._snake__body_surf = pygame.Surface([self._block_size_width, self._block_size_height])
self._snake__body_surf.fill((255, 255, 0))

self._apple_surf = pygame.Surface([self._block_size_width, self._block_size_height])
self._apple_surf.fill((0, 255, 0))

self._space_surf = pygame.Surface([self._block_size_width, self._block_size_height])
self._space_surf.fill((255, 255, 255))

def on_event(self, event):
if event.type == QUIT:
self._running = False

def render(self, grid):
self._display_surf.fill((0, 0, 0))

grid_shape = grid.shape
M = grid_shape[0]
N = grid_shape[1]

for i in range(M):
for j in range(N):
if grid[i][j] != 0:
if grid[i][j] == 1:
self._display_surf.blit(self._space_surf, (i * self._block_size_width, j * self._block_size_height))
elif grid[i][j] == -1:
self._display_surf.blit(self._snake__body_surf, (i * self._block_size_width, j * self._block_size_height))
elif grid[i][j] == -2:
self._display_surf.blit(self._snake_head_surf, (i * self._block_size_width, j * self._block_size_height))
elif grid[i][j] == 2:
self._display_surf.blit(self._apple_surf, (i * self._block_size_width, j * self._block_size_height))
pygame.display.flip()

def cleanup(self):
pygame.quit()

if __name__ == "__main__":

game = GameRender(grid_size=(10,10))
grid = np.zeros((10, 10))
for i in range(10):
grid[i][i] = 1
grid[4][7] = 10
print(grid.shape)
game.render(grid)
for i in range(5):
time.sleep(1)
game.on_cleanup()
11 changes: 4 additions & 7 deletions gym_snake/envs/snake/view.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from gym_snake.envs.snake import Snake
from gym_snake.envs.snake import Grid
import copy
import matplotlib.pyplot as plt
import numpy as np
import time

class BaseView:
def __init__(self, grid):
self.grid = grid
self.prev_action = Snake.DOWN

def get(self, action, snake_id):
return grid
def get(self, offset=(0, 0), action=None):
self.prev_action = action
return self.grid.grid

class LocalView:
def __init__(self, grid):
#super.__init__(BaseView, grid)
self.grid = grid
self.prev_action = Snake.DOWN

Expand Down
25 changes: 9 additions & 16 deletions gym_snake/envs/snake_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
from gym.utils import seeding
from gym_snake.envs.snake import Controller
from gym_snake.envs.snake.view import LocalView
from gym_snake.envs.snake.game_render import GameRender
import time

import numpy as np

try:
import matplotlib.pyplot as plt
except ImportError as e:
raise error.DependencyNotInstalled("{}. (HINT: see matplotlib documentation for installation https://matplotlib.org/faq/installing_faq.html#installation".format(e))

class SnakeEnv(gym.Env):
metadata = {'render.modes': ['human']}

Expand All @@ -26,8 +23,9 @@ def __init__(self, grid_size=[15,15], unit_size=10, unit_gap=1, snake_size=1, n_
self.action_space = spaces.Discrete(3)
self.observation_space = spaces.Box(low=-2, high=2,
shape=(self.grid_size[0]*2+1, self.grid_size[1]*2+1, 1), dtype=np.float32)
# self.observation_space = spaces.Box(low=-2, high=2,
# shape=([27*27]), dtype=np.int8)


self.viewer = GameRender(self.observation_space.shape)
self.random_init = random_init
self.action_transformer = action_transformer

Expand Down Expand Up @@ -55,17 +53,12 @@ def reset(self):
return self.last_obs

def render(self, mode='human', close=False):
if self.viewer is None:
self.viewer = plt.imshow(np.squeeze(self.last_obs), interpolation='none')
#self.viewer = plt.imshow(self.controller.grid.grid, interpolation='none')
else:
self.viewer.set_data(np.squeeze(self.last_obs))
#self.viewer.set_data(np.squeeze(self.controller.grid.grid))

plt.pause(1.4)
plt.draw()
self.viewer.render(np.squeeze(self.last_obs))
#time.sleep(0.1)

def seed(self, x):
pass

def close(self):
self.viewer.cleanup()