Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions grid_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self, rows, cols):
self.state_actions = {}
self.action_result = {}
self.action_probabilities = {}
self.start_state = (0,0)
self.start_state = (0, 0)
self.current_state = (0, 0)
for i in range(rows):
for j in range(cols):
self.states.append((i,j))
Expand Down Expand Up @@ -98,25 +99,15 @@ def get_list_of_actions_for_state(self, state):
'''retuns a list of actions for a given state'''
return self.actions[state]

def set_current_state(self, state):
    """Move the environment to an explicit state.

    state: (row, col) tuple identifying the cell the environment
        should be placed in; no validation is performed.
    """
    self.current_state = state

def current_state(self):
    # NOTE(review): this method name collides with the `self.current_state`
    # attribute assigned in __init__ / set_current_state. The instance
    # attribute shadows the method, so `env.current_state()` raises
    # TypeError ("'tuple' object is not callable") and the `return`
    # below would recurse on the attribute anyway. Callers should read
    # the attribute directly; renaming this accessor (e.g. to
    # get_current_state) would be an interface change — flagging only.
    '''returns current state of the env'''
    return self.current_state

def take_action(self, action):
    """Apply `action` from the current state.

    action: an action label; must be one of the actions recorded for
        the current state in self.state_actions.

    If the action is not allowed for the current state, a message is
    printed and the environment is left unchanged. Otherwise the
    current state is replaced by the successor recorded in
    self.action_result[(state, action)] (transitions are deterministic).
    """
    # The scraped diff contained both the pre-change lookup
    # (self.actions[current_state]) and the post-change one; this is the
    # post-change version, which also fixes the missing `self.` on the
    # state reference.
    if action not in self.state_actions[self.current_state]:
        print("Action {0} is not allowed for this state.".format(action))
        return
    # Valid action: follow the deterministic transition.
    self.current_state = self.action_result[(self.current_state, action)]

def is_terminal_state(self, state):
    """Return True if `state` is one of the configured terminal states.

    state: (row, col) tuple to test.
    """
    # Bug fix: the original tested `state in self.current_state`, i.e.
    # membership in a single (row, col) coordinate tuple, not in the
    # terminal-state list that init_simple_grid populates and that the
    # play loop checks (`grid.current_state not in grid.terminal_states`).
    return state in self.terminal_states
Expand Down Expand Up @@ -152,6 +143,22 @@ def show_rewards(self):
print("+---",end="")
print("+")

def show_current_state(self):
    """Print an ASCII picture of the grid, marking the current cell with X."""
    print("#####CURRENT STATE#####")
    # One horizontal rule reused for every row boundary.
    border = "+---" * self.cols + "+"
    for r in range(self.rows):
        print(border)
        row_cells = "".join(
            "| X " if (r, c) == self.current_state else "|   "
            for c in range(self.cols)
        )
        print(row_cells + "|")
    print(border)

def str_list(self, list):
s = ""
for t in list: s = s+t
Expand Down Expand Up @@ -198,14 +205,15 @@ def show_env(self):
self.show_policy()
# end of Grid_World class


def init_simple_grid():
g = Grid_World(4,4)
g = Grid_World(4, 4)
#init all states
g.discount = 1
#set starting and terminal states
g.start_state = (0,0)
g.terminal_states.append((g.rows-1,g.cols-1))
#init all rewards except terminal state to be -1
g.terminal_states.append((g.rows-1, g.cols-1))
# init all rewards except terminal state to be -1
for state in g.states:
if state in g.terminal_states:
g.rewards[state] = 0
Expand Down
78 changes: 53 additions & 25 deletions iterative_policy_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,30 @@

from grid_world import init_simple_grid


def show_values(values, grid, iteration):
    """Pretty-print state-value estimates as an ASCII grid.

    values: dict mapping (row, col) -> float value estimate.
    grid: Grid_World-like object exposing .rows and .cols.
    iteration: iteration number shown in the header line.
    """
    print("#####VALUES##### iteration {0}".format(iteration))
    for row in range(grid.rows):
        for col in range(grid.cols):
            print("+------", end="")
        print("+\n|", end="")
        for col in range(grid.cols):
            # {:>6.2f} right-aligns the value in the 6-char cell. The
            # original branched on `value >= 0` but printed the identical
            # format string in both branches; the dead branch is removed.
            print("{:>6.2f}|".format(values[(row, col)]), end="")
        print("")
    for col in range(grid.cols):
        print("+------", end="")
    print("+")


def run_algorithm(show_all_iterations=False):
'''
"""
num_of_states - number of available states
actions - matrix where col num represents a state, row represents action,
element at action,state is the resulting state
Expand All @@ -52,7 +55,8 @@ def run_algorithm(show_all_iterations=False):
policy - matrix probability distribution of actions over states
num of states x num of actions. Each row contains probabilities
of taking some action while being in row state
'''
"""

grid = init_simple_grid()
states = grid.states
gamma = grid.discount
Expand All @@ -61,37 +65,61 @@ def run_algorithm(show_all_iterations=False):
action_result = grid.action_result
action_prob = grid.action_probabilities
threshold = 0.000000000000000001
#initialize dictionary of state to values
# initialize dictionary of state to values
values = {}
#set initial values to be zeros
for s in states: values[s] = 0
#Repeat until convergence
# set initial values to be zeros
for s in states:
values[s] = 0
# Repeat until convergence
iteration = 0
while True:
if show_all_iterations: show_values(values,grid, 0)
if show_all_iterations:
show_values(values, grid, 0)
new_values = {}
delta = 0
for state in states:
old_value = values[state]
#exclude terminal states.
#there are no actions defined for terminal state
if len(grid.state_actions[state])!=0:
# exclude terminal states.
# there are no actions defined for terminal state
if len(grid.state_actions[state]) != 0:
weighted_sum = 0
#accumulate weighted with probabilities values of possible next states
# accumulate weighted with probabilities values of possible next states
for action in actions[state]:
prob = action_prob[(state,action)]
next_state_value = values[action_result[(state,action)]]
prob = action_prob[(state, action)]
next_state_value = values[action_result[(state, action)]]
weighted_sum = weighted_sum + (prob * next_state_value)
new_value = rewards[state] + gamma * weighted_sum
new_values[state] = new_value
delta = max(delta, abs(old_value - new_value))
values = new_values
values[(3,3)]=0 # not part of the algorithm but I set up thigns this way, too lazy to fix
values[(3, 3)] = 0 # not part of the algorithm but I set up things this way, too lazy to fix
iteration = iteration + 1
if show_all_iterations: show_values(values,grid,iteration)
if show_all_iterations:
show_values(values, grid, iteration)
if delta < threshold:
print("policy evaluated")
show_values(values,grid,iteration)
show_values(values, grid, iteration)
break
# return values


# play
#
# NOTE(review): this section executes at import time (it sits above the
# __main__ guard) and references `values` at module scope, which is never
# defined — run_algorithm keeps its values dict local and its
# `return values` line is commented out. The `values[next_state]` lookup
# below therefore raises NameError. The fix is to have run_algorithm
# return its values dict and to run this section (with that return value)
# under the __main__ guard — flagging only, since that change spans both
# blocks.

grid = init_simple_grid()
print('Evaluating performance...')
# Greedy rollout: from the current state, repeatedly take the action whose
# successor state has the highest evaluated value, until a terminal state.
while(grid.current_state not in grid.terminal_states):
    possible_actions = grid.state_actions[grid.current_state]
    best_action = None
    highest_value = float("-inf")
    for action in possible_actions:
        next_state = grid.action_result[(grid.current_state, action)]
        next_state_value = values[next_state]  # NameError — see note above
        if next_state_value > highest_value:
            highest_value = next_state_value
            best_action = action
    grid.show_current_state()
    print(best_action)
    grid.take_action(best_action)
grid.show_current_state()


if __name__ == '__main__':
    run_algorithm(show_all_iterations=True)