Skip to content

Commit

Permalink
[fix] targetを複数追加できるように
Browse files Browse the repository at this point in the history
  • Loading branch information
nomutin committed May 29, 2024
1 parent ddf453e commit 7dd5897
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/pinpad/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
from gymnasium import Env, spaces
from numpy.random import Generator, RandomState, default_rng

from pinpad.custom_types import (
ACTION,
Expand Down Expand Up @@ -49,16 +50,16 @@ class PinPad(Env[ACTION, OBSERVATION]):
action_space = spaces.Discrete(5)
observation_space = spaces.Box(0, 255, shape=(3,), dtype=np.uint8)
render_mode = "rgb_array"
_np_random: np.random.Generator = np.random.default_rng()
_np_random: Generator = default_rng()

def __init__(self, layout: str, target: list[str]) -> None:
def __init__(self, layout: str, targets: list[list[str]]) -> None:
"""Initialize the environment."""
super().__init__()
self.layout = np.array([list(line) for line in layout.split("\n")]).T
self.pads = set(self.layout.flatten().tolist()) - set("* #\n")
self.target = target
self.targets = targets

self.random = np.random.RandomState()
self.random = RandomState()
self.spawns = []
for (x, y), char in np.ndenumerate(self.layout):
if char != "#":
Expand Down Expand Up @@ -98,7 +99,7 @@ def step(
contains player position and sequence
"""
reward = 0.0
move = [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)][action]
move = [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)][action] # type: ignore[index]
x = np.clip(self.player[0] + move[0], 0, 15)
y = np.clip(self.player[1] + move[1], 0, 13)
tile = self.layout[x][y]
Expand All @@ -108,8 +109,8 @@ def step(
is_new_tile = not self.sequence or self.sequence[-1] != tile
if tile in self.pads and is_new_tile:
self.sequence.append(tile[0])
if self.sequence[-7:] == self.target:
reward += 10.0
if self.sequence in self.targets:
reward = 1.0

observation = self.render()
terminated = False
Expand All @@ -135,13 +136,13 @@ def reset(
Returns
-------
observation : Observation
observation : OBSERVATION
the initial observation of the space.
info : INFO
contains NEW player position and PREVIOUS sequence
"""
if seed is not None:
self._np_random = np.random.default_rng(seed)
self._np_random = default_rng(seed)
self.player = tuple(self._np_random.choice(self.spawns))
if options is not None and "player" in options:
self.player = options["player"]
Expand Down Expand Up @@ -209,5 +210,4 @@ def make(cls, layout: LAYOUTS) -> PinPad:
A new instance of PinPad with the specified layout.
"""
path = Path(__file__).parent / "layouts" / f"{layout}.txt"
with path.open(mode="r") as f:
return cls(layout=f.read().strip("\n"), target=[])
return cls(layout=path.read_text().strip("\n"), targets=[])

0 comments on commit 7dd5897

Please sign in to comment.