Pixels-Based Sim2Real Demo for Aloha Peg Insertion #76
Hey Andrew, this is really awesome. My first two high-level comments are: [1] move/rename s2r to follow the same pattern we have elsewhere in the repo (potentially nix the aloha single_peg that exists currently, since this is a superior version for the teacher policy), and [2] nix the param files, or convert the visionMLP to orbax (it's small enough that I think it's OK to include in git history).
Really excited to help get this checked in!
Thanks @Andrew-Luo1! LGTM modulo the use of pickles and small nits. I'm happy to make the pkl change once I get a good GPU again and can run this locally (for now I'm a bit bottlenecked on the hardware side). If you get to making the changes first, we can merge this in.
def apply_line_noise(img, line_noise):
  return _or_reduce(jp.stack([img, line_noise]), axis=0)
Could this just be jp.where(line_noise != 0, line_noise, img)?
@Andrew-Luo1 same question as before
Yes nice catch
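For readers following this thread, here is a minimal sketch of the suggested overlay. It assumes that `_or_reduce` behaves like an element-wise bitwise OR across the stacked axis (its definition is not shown in this excerpt), and the function names `apply_line_noise_where` and `apply_line_noise_or` are hypothetical, introduced only for comparison. The two forms agree when the noise pixels are saturated (255 for uint8), as in this example; they can differ for other noise values.

```python
try:
  import jax.numpy as jp
except ImportError:  # NumPy exposes the same where/bitwise_or calls used here.
  import numpy as jp


def apply_line_noise_where(img, line_noise):
  """Overlay noise: take the noise pixel wherever it is non-zero."""
  return jp.where(line_noise != 0, line_noise, img)


def apply_line_noise_or(img, line_noise):
  """Assumed stand-in for the _or_reduce version: element-wise bitwise OR."""
  return jp.bitwise_or(img, line_noise)


# A tiny 2x3 "image" with one vertical line of saturated noise.
img = jp.array([[10, 20, 30], [40, 50, 60]], dtype=jp.uint8)
line_noise = jp.array([[0, 255, 0], [0, 255, 0]], dtype=jp.uint8)

overlaid = apply_line_noise_where(img, line_noise)
print(overlaid)
```

The `jp.where` form states the intent directly (noise replaces the image where present) and does not depend on the bit pattern of the noise values.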
learning/train_jax_ppo.py
Outdated
@@ -361,6 +355,12 @@ def progress(num_steps, metrics):
  print(f"Time to JIT compile: {times[1] - times[0]}")
  print(f"Time to train: {times[-1] - times[1]}")

  if _SAVE_PARAMS_PATH.value is not None:
    model.save_params(epath.Path(_SAVE_PARAMS_PATH.value).resolve(), params)
@Andrew-Luo1 would really like to not use the pkl stuff, is this absolutely necessary?
@@ -37,6 +37,9 @@ def get_assets() -> Dict[str, bytes]:
  path = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls"
  mjx_env.update_assets(assets, path, "*.xml")
  mjx_env.update_assets(assets, path / "assets")
  path = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls" / "s2r"
no longer needed FWIU
f_pick_teacher = pathlib.Path(__file__).parent / 'params' / 'AlohaPick.prms'
f_insert_teacher = (
    pathlib.Path(__file__).parent / 'params' / 'AlohaPegInsertion.prms'
I guess these were removed and not converted to orbax? Nbd, I would still merge, but at least add a comment or more informative error
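One way the "more informative error" could look is sketched below. The helper name `require_params_file` and the wording of the message are hypothetical, and the sketch assumes the teacher parameter files are plain files that the user has to produce or supply themselves.

```python
import pathlib


def require_params_file(path: pathlib.Path, description: str) -> pathlib.Path:
  """Return `path` if it is an existing file, else raise a descriptive error."""
  if not path.is_file():
    raise FileNotFoundError(
        f"{description} parameters not found at '{path}'. "
        "This file may not be bundled with the repository; train the "
        "teacher policy first or point this path at your own checkpoint."
    )
  return path


# Hypothetical location mirroring the 'params' directory in the snippet above.
params_dir = pathlib.Path("params")
try:
  require_params_file(params_dir / "AlohaPick.prms", "Pick teacher")
except FileNotFoundError as err:
  print(err)
```

Failing early with the expected path and a hint about how to obtain the file is usually easier to act on than an error raised later from inside the loader.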
import pathlib
from typing import Any, Dict, Optional, Tuple, Union

from brax.io import model as brax_loader
would really like brax.io.model nixed from this PR if it can be redone using the existing checkpointing mechanism
Naming is getting a bit confusing. "single" is overloaded to mean both single peg and single arm.
Maybe name it "mjx_aloha_single_arm.xml" and similar.
maybe mjx_half_aloha and mjx_half_scene? Fine with the single_arm convention too
SGTM for half!
Apologies for the very late update - life's been hectic. It'd be great to get another review!
A demo of using Madrona MJX as a "real-world adapter". A standard playground-style state-based peg-insertion policy is first trained in sim. Madrona MJX is then used to distill this policy into a deployable pixel-based policy in 3 minutes.
Note that this PR relies on a concurrent Brax PR that implements Online Dagger for behaviour cloning.
Please see the technical report for more details on the Aloha Peg Insertion task.
Both phases of the teacher policy show stable training.

The distillation to the student is also stable (2.5e6 samples correspond to 3m30s of wall-clock time).

Looking forward to incorporating any feedback!