
Enhance pi0 model inference #872

Open · wants to merge 1 commit into base: main
Conversation

xuaner233
This is a simple change to pi0 model inference, along with a minor fix to loss_dict in the training code.

  1. Pass the task into the observation for the VLA model (pi0).
  2. Update the loss_dict stats data format.

What this does

Add inference support for the pi0 model.
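The core of this change is forwarding the natural-language task into the observation so the pi0 VLA model receives its text prompt at inference time. A minimal sketch of the idea is below; the key name "task" and the helper name `build_observation` are illustrative assumptions, not the PR's exact identifiers:

```python
# Sketch: attach the control.single_task string to the observation dict
# before it is handed to the policy. Names are illustrative, not the
# exact lerobot code touched by this PR.

def build_observation(raw_obs: dict, single_task: str) -> dict:
    """Copy the raw observation and add the task prompt for VLA policies."""
    obs = dict(raw_obs)
    obs["task"] = single_task  # pi0 tokenizes this string as its language input
    return obs

obs = build_observation(
    {"state": [0.0, 0.1]},
    "Grasp a white cube and put it in the bin.",
)
assert obs["task"].startswith("Grasp")
```

With an observation shaped like this, a state-only policy can ignore the extra key while pi0 consumes it as its prompt.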

How it was tested

  • first, train pi0 on a dataset, e.g.
python lerobot/scripts/train.py \
  --steps=40000 \
  --policy.type=pi0 \
  --dataset.repo_id=xuaner233/so100_grasp_place_20250313 \
  --wandb.enable=true \
  --wandb.disable_artifact=true
  • then run inference with the trained pi0 model, setting control.single_task as pi0's text prompt for the task:
HF_USER=xuaner233
REPO_ID="${HF_USER}/eval_pi0_so100_test"

python lerobot/scripts/control_robot.py \
  --robot.type=so100 \
  --control.type=record \
  --control.fps=30 \
  --control.single_task="Grasp a white cube and put it in the bin." \
  --control.repo_id=${REPO_ID} \
  --control.tags='["pi0"]' \
  --control.warmup_time_s=5 \
  --control.episode_time_s=300 \
  --control.reset_time_s=10 \
  --control.num_episodes=1 \
  --control.push_to_hub=false \
  --control.policy.device=cuda \
  --control.policy.path=outputs/train/2025-03-14_pi0/checkpoints/last/pretrained_model

@imstevenpmwork imstevenpmwork self-requested a review March 18, 2025 15:01
@imstevenpmwork imstevenpmwork added bug Something isn’t working correctly enhancement Suggestions for new features or improvements policies Items related to robot policies labels Mar 18, 2025
imstevenpmwork (Collaborator) left a comment


I left a comment in code

@@ -317,16 +317,16 @@ def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tens

     loss_dict = {}
     losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
-    loss_dict["losses_after_forward"] = losses.clone()
+    loss_dict["losses_after_forward"] = losses.mean().item()
Collaborator

What is the reasoning of getting the mean here?

Wouldn't it be better to use .detach() here instead of clone()?

xuaner233 (Author), Mar 26, 2025

Hi Steven, thanks for the review!

mean() is there to align these values with loss_dict["l2_loss"], i.e. a single mean loss value. Otherwise wandb complains as below:

WARNING 2025-03-26 09:52:43 db_utils.py:116 WandB logging of key "losses_after_forward" was ignored as its type is not handled by this wrapper.
WARNING 2025-03-26 09:52:43 db_utils.py:116 WandB logging of key "losses_after_rm_padding" was ignored as its type is not handled by this wrapper.

As for the removal of clone(): since we only compute mean().item() instead of storing the actual loss tensor, neither clone() nor detach() seems necessary here.
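The warnings quoted above say the wandb wrapper skips values whose type it does not handle, which is why the tensor-valued entries were dropped while scalar floats are logged. The sketch below mimics that type filter to illustrate the point; `is_loggable` is a stand-in for the wrapper's internal check, not the actual lerobot code:

```python
# Sketch of why a tensor-valued loss_dict entry is ignored while a
# float (the result of losses.mean().item()) is logged. This mimics
# the type filter implied by the warnings; it is not the real wrapper.

def is_loggable(value) -> bool:
    # plain Python scalars/strings are handled; raw tensors/arrays are not
    return isinstance(value, (int, float, str))

losses = [0.31, 0.27, 0.25]  # stand-in for the per-element loss tensor
loss_dict = {
    # mean().item() collapses the tensor to a plain Python float
    "losses_after_forward": sum(losses) / len(losses),
}

assert is_loggable(loss_dict["losses_after_forward"])  # float -> logged
assert not is_loggable(losses)  # tensor-like value -> skipped with a warning
```

Since .item() already copies the value out of the autograd graph as a Python float, no clone() or detach() is needed for logging purposes.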
