GAIL is not training with image-based data #864

Open

aha85b opened this issue Jan 8, 2025 · 0 comments

Labels
bug Something isn't working
aha85b commented Jan 8, 2025

Bug description

I am new to the imitation package. The issue I am facing is that I want to train an agent using GAIL, but I keep getting errors related to data shapes. I am using a custom environment that extends Gymnasium, and the first error was about a tuple; I believe env.reset() triggers it, since Gymnasium's reset() returns (observation, info) rather than just the observation. When I changed my reset function to return only the observation, that error went away, but then I ran into the next issue, which I could not work around.

I also noticed something while debugging: gail.GAIL does not like gymnasium, and sb3 PPO() does not like gym. The GAIL class throws an error due to the incompatibility with gym, and the other way around for PPO.
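For reference, here is a minimal sketch of how I understand the environment is supposed to be wrapped so that both PPO and GAIL see the same VecEnv API. The env id "MyImageEnv-v0" is a placeholder, not my actual environment:

# Hypothetical sketch: wrapping a Gymnasium env so that both
# stable-baselines3 (PPO) and imitation (GAIL) accept the same VecEnv.
import gymnasium as gym
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from imitation.data.wrappers import RolloutInfoWrapper

def make_env():
    env = gym.make("MyImageEnv-v0")   # placeholder id for my custom image env
    env = Monitor(env)                # episode statistics for SB3
    env = RolloutInfoWrapper(env)     # imitation needs complete rollouts in `info`
    return env

venv = DummyVecEnv([make_env])        # single VecEnv shared by PPO and GAIL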

round: 0%| | 0/4 [00:00<?, ?it/s]


RuntimeError Traceback (most recent call last)
Cell In[27], line 1
----> 1 gail_trainer.train(10000)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/imitation/algorithms/adversarial/common.py:454, in AdversarialTrainer.train(self, total_timesteps, callback)
448 assert n_rounds >= 1, (
449 "No updates (need at least "
450 f"{self.gen_train_timesteps} timesteps, have only "
451 f"total_timesteps={total_timesteps})!"
452 )
453 for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
--> 454 self.train_gen(self.gen_train_timesteps)
455 for _ in range(self.n_disc_updates_per_round):
456 with networks.training(self.reward_train):
457 # switch to training mode (affects dropout, normalization)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/imitation/algorithms/adversarial/common.py:414, in AdversarialTrainer.train_gen(self, total_timesteps, learn_kwargs)
411 learn_kwargs = {}
413 with self.logger.accumulate_means("gen"):
--> 414 self.gen_algo.learn(
415 total_timesteps=total_timesteps,
416 reset_num_timesteps=False,
417 callback=self.gen_callback,
418 **learn_kwargs,
419 )
420 self._global_step += 1
422 gen_trajs, ep_lens = self.venv_buffering.pop_trajectories()

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/ppo/ppo.py:315, in PPO.learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
306 def learn(
307 self: SelfPPO,
308 total_timesteps: int,
(...)
313 progress_bar: bool = False,
314 ) -> SelfPPO:
--> 315 return super().learn(
316 total_timesteps=total_timesteps,
317 callback=callback,
318 log_interval=log_interval,
319 tb_log_name=tb_log_name,
320 reset_num_timesteps=reset_num_timesteps,
321 progress_bar=progress_bar,
322 )

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py:300, in OnPolicyAlgorithm.learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
297 assert self.env is not None
299 while self.num_timesteps < total_timesteps:
--> 300 continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
302 if not continue_training:
303 break

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py:179, in OnPolicyAlgorithm.collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps)
176 with th.no_grad():
177 # Convert to pytorch tensor or to TensorDict
178 obs_tensor = obs_as_tensor(self._last_obs, self.device)
--> 179 actions, values, log_probs = self.policy(obs_tensor)
180 actions = actions.cpu().numpy()
182 # Rescale and perform action

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/policies.py:645, in ActorCriticPolicy.forward(self, obs, deterministic)
637 """
638 Forward pass in all the networks (actor and critic)
639
(...)
642 :return: action, value and log probability of the action
643 """
644 # Preprocess the observation if needed
--> 645 features = self.extract_features(obs)
646 if self.share_features_extractor:
647 latent_pi, latent_vf = self.mlp_extractor(features)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/policies.py:672, in ActorCriticPolicy.extract_features(self, obs, features_extractor)
663 """
664 Preprocess the observation if needed and extract features.
665
(...)
669 features for the actor and the features for the critic.
670 """
671 if self.share_features_extractor:
--> 672 return super().extract_features(obs, self.features_extractor if features_extractor is None else features_extractor)
673 else:
674 if features_extractor is not None:

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/policies.py:131, in BaseModel.extract_features(self, obs, features_extractor)
123 """
124 Preprocess the observation if needed and extract features.
125
(...)
128 :return: The extracted features
129 """
130 preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
--> 131 return features_extractor(preprocessed_obs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/stable_baselines3/common/torch_layers.py:106, in NatureCNN.forward(self, observations)
105 def forward(self, observations: th.Tensor) -> th.Tensor:
--> 106 return self.linear(self.cnn(observations))

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
--> 217 input = module(input)
218 return input

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/anaconda3/envs/py39ai2thor/lib/python3.9/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
115 def forward(self, input: Tensor) -> Tensor:
--> 116 return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x49 and 3136x512)

Steps to reproduce

This is the code that produces the issue:
# Imports added for completeness; `env` is my vectorized image-based
# environment and `transitions` are the expert demonstrations.
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import CnnRewardNet

SEED = 42

learner = PPO(
    env=env,
    policy=CnnPolicy,
    batch_size=32,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
    # device='cpu'
)

reward_net = CnnRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    use_state=True,
    use_action=True,
    use_next_state=False,
    use_done=False,
    hwc_format=False,
)

gail_trainer = GAIL(
    demonstrations=transitions,
    demo_batch_size=32,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

gail_trainer.train(10000)
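For what it is worth, the mismatch (64x49 vs 3136x512) happens inside SB3's NatureCNN: its linear layer expects 3136 = 64 * 7 * 7 features, which is what the CNN produces for 84x84 channel-first images, so the observations apparently reach the CNN with a different or untransposed layout. Below is a rough diagnostic sketch (not my actual code, and it reuses `env` from the snippet above) to check whether SB3 even recognizes the observation space as an image space, since otherwise it will not insert VecTransposeImage automatically:

from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.vec_env import VecTransposeImage

# What does PPO actually see? Image spaces should be uint8 with values in [0, 255].
print(env.observation_space)
# If this prints False, SB3 does not auto-wrap the env with VecTransposeImage.
print(is_image_space(env.observation_space))

# Explicitly transposing channel-last (H, W, C) observations to channel-first
# can help rule the transpose out as the cause of the shape error.
env_chan_first = VecTransposeImage(env)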

aha85b added the bug label Jan 8, 2025