Skip to content

Commit

Permalink
sarsa update and flake8 error
Browse files Browse the repository at this point in the history
  • Loading branch information
amsks committed Nov 9, 2023
1 parent 7d8d969 commit efdea61
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rl_exercises/week_4/sarsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(self, env: gym.Env, policy: EpsilonGreedyPolicy, alpha: float = 0.5

def predict_action(self, state: np.array, info: dict = {}, evaluate: bool = False) -> Any: # type: ignore # noqa
"""Predict the action for a given state"""
action = info
action = self.policy(self.Q, state, eval=evaluate) # type: ignore
info = {}
return action

def save(self, path: str) -> Any: # type: ignore
Expand All @@ -72,7 +72,7 @@ def load(self, path) -> Any: # type: ignore
"""
self.Q = np.load(path)

def update( # type: ignore
def update_agent( # type: ignore
self,
transition: list[np.array], # type: ignore
next_action: int,
Expand Down

0 comments on commit efdea61

Please sign in to comment.