Skip to content

Commit

Permalink
add openvla simple inference
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanlinli17 committed Aug 18, 2024
1 parent 3667e65 commit 12e3747
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions simpler_env/simple_inference_visual_matching_prepackaged_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
--ckpt-path ./checkpoints/rt_1_tf_trained_for_000400120 --task google_robot_pick_coke_can --logging-root ./results_simple_eval/ --n-trajs 10
python simpler_env/simple_inference_visual_matching_prepackaged_envs.py --policy octo-small \
--ckpt-path None --task widowx_spoon_on_towel --logging-root ./results_simple_eval/ --n-trajs 10
python simpler_env/simple_inference_visual_matching_prepackaged_envs.py --policy openvla/openvla-7b \
--ckpt-path None --task google_robot_move_near_v1 --logging-root ./results_simple_eval/ --n-trajs 10
"""

import argparse
Expand All @@ -21,7 +23,7 @@

parser = argparse.ArgumentParser()

parser.add_argument("--policy", default="rt1", choices=["rt1", "octo-base", "octo-small"])
parser.add_argument("--policy", default="rt1", choices=["rt1", "octo-base", "octo-small", "openvla/openvla-7b"])
parser.add_argument(
"--ckpt-path",
type=str,
Expand All @@ -37,7 +39,7 @@
parser.add_argument("--n-trajs", type=int, default=10)

args = parser.parse_args()
if args.policy in ["octo-base", "octo-small"]:
if args.policy in ["octo-base", "octo-small", "openvla/openvla-7b"]:
if args.ckpt_path in [None, "None"] or "rt_1_x" in args.ckpt_path:
args.ckpt_path = args.policy
if args.ckpt_path[-1] == "/":
Expand Down Expand Up @@ -75,6 +77,10 @@
from simpler_env.policies.octo.octo_model import OctoInference

model = OctoInference(model_type=args.ckpt_path, policy_setup=policy_setup, init_rng=0)
elif "openvla" in args.policy:
from simpler_env.policies.openvla.openvla_model import OpenVLAInference

model = OpenVLAInference(saved_model_path=args.ckpt_path, policy_setup=policy_setup)
else:
raise NotImplementedError()

Expand Down

0 comments on commit 12e3747

Please sign in to comment.