Cleaned up chiefinvestigator and made it inherit from investigator
However, the functionality to use the fixedpointfinder is removed for now.
prstolpe committed Mar 20, 2020
1 parent 9ca72af commit 46cfa45
Showing 1 changed file with 18 additions and 25 deletions.
43 changes: 18 additions & 25 deletions analysis/chiefinvestigation.py
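For context, this change swaps composition (a slave_investigator field) for inheritance, so Chiefinvestigator calls inherited Investigator methods such as get_layer_weights and list_layer_names directly on self. Below is a minimal, self-contained sketch of that pattern; the Investigator constructor signature and the placeholder method bodies are assumptions for illustration, not the actual code in analysis/investigation.py. The committed code uses super(Investigator, self).__init__(...), which starts the method lookup after Investigator in the MRO; the sketch uses the common Python 3 form super().__init__(...).

class Investigator:
    def __init__(self, network, distribution, preprocessor):
        # assumed signature: store the policy network and its helpers for inspection
        self.network = network
        self.distribution = distribution
        self.preprocessor = preprocessor

    def list_layer_names(self):
        # placeholder standing in for the real implementation
        return ["policy_recurrent_layer"]


class Chiefinvestigator(Investigator):
    def __init__(self, network, distribution, preprocessor):
        # delegate setup to the parent instead of wrapping a separate investigator
        super().__init__(network, distribution, preprocessor)

    def get_layer_names(self):
        # inherited methods are now called on self, as in the diff below
        return self.list_layer_names()


chief = Chiefinvestigator(network=None, distribution=None, preprocessor=None)
print(chief.get_layer_names())  # ['policy_recurrent_layer']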
@@ -7,11 +7,9 @@
from agent.ppo import PPOAgent
from analysis.investigation import Investigator
from utilities.wrappers import CombiWrapper, StateNormalizationWrapper, RewardNormalizationWrapper
from fixedpointfinder.FixedPointFinder import Adamfixedpointfinder
from fixedpointfinder.plot_utils import plot_fixed_points, plot_velocities
from analysis.rsa.rsa import RSA

class Chiefinvestigator:

class Chiefinvestigator(Investigator):

def __init__(self, agent_id: int, enforce_env_name: str = None):
"""Chiefinvestigator can assign investigator to inspect the model and produce high-level analysis.
@@ -20,7 +18,9 @@ def __init__(self, agent_id: int, enforce_env_name: str = None):
agent_id: ID of agent that will be analyzed.
env_name: Name of the gym environment that the agent was trained in. Default is set to CartPole-v1
"""

self.agent = PPOAgent.from_agent_state(agent_id, from_iteration='best')
super(Investigator, self).__init__(self.agent.policy, self.agent.distribution, self.agent.preprocessor)
self.env = self.agent.env
if enforce_env_name is not None:
print(f"Enforcing environment {enforce_env_name} over agents original environment. If you want to use"
@@ -29,8 +29,7 @@ def __init__(self, agent_id: int, enforce_env_name: str = None):
self.env = gym.make(enforce_env_name)
self.agent.preprocessor = CombiWrapper([StateNormalizationWrapper(self.agent.state_dim),
RewardNormalizationWrapper()]) # dirty fix, TODO remove soon
self.slave_investigator = Investigator.from_agent(self.agent)
self.weights = self.slave_investigator.get_layer_weights('policy_recurrent_layer')
self.weights = self.get_layer_weights('policy_recurrent_layer')
self.n_hidden = self.weights[1].shape[0]
self._get_rnn_type()

@@ -43,7 +42,7 @@ def _get_rnn_type(self):
self.rnn_type = 'lstm'

def get_layer_names(self):
return self.slave_investigator.list_layer_names()
return self.list_layer_names()

def parse_data(self, layer_name: str, previous_layer_name: str):
"""Get state, activation, action, reward over one episode. Parse data to output.
@@ -54,7 +53,7 @@ def parse_data(self, layer_name: str, previous_layer_name: str):
Returns:
activation_data, action_data, state_data, all_rewards
"""
states, activations, rewards, actions = self.slave_investigator.get_activations_over_episode(
states, activations, rewards, actions = self.get_activations_over_episode(
[layer_name, previous_layer_name],
self.env, False)

@@ -89,12 +88,6 @@ def get_data_over_single_run(self, layer_name: str, previous_layer_name: str):
actions = np.vstack(actions)

return activations, inputs, actions
# TODO: build neural network to predict grasp -> we trained a simple prediction model fp(z) containing one hidden
# layer with 64 units and ReLU activation, followed by a sigmoid output.

# TODO: sequence analysis ideas -> sequence pattern and so forth:
# one possibility could be sequence alignment, time series analysis (datacamp), rsa
# TODO: consider making chiefinvestigator and investigator child classes of BaseInvestigator

if __name__ == "__main__":
os.chdir("../") # remove if you want to search for ids in the analysis directory
@@ -113,14 +106,14 @@ def get_data_over_single_run(self, layer_name: str, previous_layer_name: str):
layer_names[1])

# employ fixedpointfinder
adamfpf = Adamfixedpointfinder(chiefinvesti.weights, chiefinvesti.rnn_type,
q_threshold=1e-06,
epsilon=0.01,
alr_decayr=1e-04,
max_iters=5000)
states, sampled_inputs = adamfpf.sample_inputs_and_states(activations_over_all_episodes,
inputs_over_all_episodes,
1000, 0.2)
sampled_inputs = np.zeros((states.shape[0], chiefinvesti.n_hidden))
fps = adamfpf.find_fixed_points(states, sampled_inputs)
plot_fixed_points(activations_over_all_episodes, fps, 4000, 1)
# adamfpf = Adamfixedpointfinder(chiefinvesti.weights, chiefinvesti.rnn_type,
# q_threshold=1e-06,
# epsilon=0.01,
# alr_decayr=1e-04,
# max_iters=5000)
# states, sampled_inputs = adamfpf.sample_inputs_and_states(activations_over_all_episodes,
# inputs_over_all_episodes,
# 1000, 0.2)
# sampled_inputs = np.zeros((states.shape[0], chiefinvesti.n_hidden))
# fps = adamfpf.find_fixed_points(states, sampled_inputs)
# plot_fixed_points(activations_over_all_episodes, fps, 4000, 1)
