diff --git a/omnigibson/object_states/__init__.py b/omnigibson/object_states/__init__.py index f36cd7b61..1a3384c1a 100644 --- a/omnigibson/object_states/__init__.py +++ b/omnigibson/object_states/__init__.py @@ -21,6 +21,7 @@ from omnigibson.object_states.particle_modifier import ParticleRemover, ParticleApplier from omnigibson.object_states.particle_source_or_sink import ParticleSource, ParticleSink from omnigibson.object_states.pose import Pose +from omnigibson.object_states.robot_related_states import InHandOfRobot from omnigibson.object_states.saturated import Saturated from omnigibson.object_states.temperature import Temperature from omnigibson.object_states.toggle import ToggledOn diff --git a/omnigibson/object_states/factory.py b/omnigibson/object_states/factory.py index 75b33aadf..ca0fc6119 100644 --- a/omnigibson/object_states/factory.py +++ b/omnigibson/object_states/factory.py @@ -36,6 +36,7 @@ Touching, Under, Covered, + InHandOfRobot, ] ) diff --git a/omnigibson/object_states/robot_related_states.py b/omnigibson/object_states/robot_related_states.py new file mode 100644 index 000000000..d7fc6b0c4 --- /dev/null +++ b/omnigibson/object_states/robot_related_states.py @@ -0,0 +1,69 @@ +import numpy as np +import omnigibson as og +from omnigibson.object_states.object_state_base import AbsoluteObjectState, BooleanStateMixin + + +_IN_REACH_DISTANCE_THRESHOLD = 2.0 + +_IN_FOV_PIXEL_FRACTION_THRESHOLD = 0.05 + + +def _get_robot(): + from omnigibson.robots import ManipulationRobot + valid_robots = [robot for robot in og.sim.scene.robots if isinstance(robot, ManipulationRobot)] + if not valid_robots: + return None + + if len(valid_robots) > 1: + raise ValueError("Multiple robots found.") + + return valid_robots[0] + + +# class InReachOfRobot(AbsoluteObjectState, BooleanStateMixin): +# def _compute_value(self): +# robot = _get_robot(self.simulator) +# if not robot: +# return False + +# robot_pos = robot.get_position() +# object_pos = self.obj.get_position() +# return np.linalg.norm(object_pos - np.array(robot_pos)) < _IN_REACH_DISTANCE_THRESHOLD + + +class InHandOfRobot(AbsoluteObjectState, BooleanStateMixin): + def _get_value(self): + robot = _get_robot() + if not robot: + return False + + return any( + robot._ag_obj_in_hand[arm] == self.obj + for arm in robot.arm_names + ) + + +# class InFOVOfRobot(AbsoluteObjectState, BooleanStateMixin): +# @staticmethod +# def get_optional_dependencies(): +# return AbsoluteObjectState.get_optional_dependencies() + [ObjectsInFOVOfRobot] + +# def _get_value(self): +# robot = _get_robot(self.simulator) +# if not robot: +# return False + +# body_ids = set(self.obj.get_body_ids()) +# return not body_ids.isdisjoint(robot.states[ObjectsInFOVOfRobot].get_value()) + + +# class ObjectsInFOVOfRobot(AbsoluteObjectState): +# def _get_value(self): +# # Pass the FOV through the instance-to-body ID mapping. +# seg = self.simulator.renderer.render_single_robot_camera(self.obj, modes="ins_seg")[0][:, :, 0] +# seg = np.round(seg * MAX_INSTANCE_COUNT).astype(int) +# body_ids = self.simulator.renderer.get_pb_ids_for_instance_ids(seg) + +# # Pixels that don't contain an object are marked -1 but we don't want to include that +# # as a body ID. +# return set(np.unique(body_ids)) - {-1} \ No newline at end of file diff --git a/tests/test_symbolic_primitives.py b/tests/test_symbolic_primitives.py index 1d0341eb5..4b53cab61 100644 --- a/tests/test_symbolic_primitives.py +++ b/tests/test_symbolic_primitives.py @@ -159,6 +159,12 @@ def sponge(env): def knife(env): return next(iter(env.scene.object_registry("category", "carving_knife"))) +def test_in_hand_state(env, prim_gen, steak): + assert not steak.states[object_states.InHandOfRobot].get_value() + for action in prim_gen.apply_ref(SymbolicSemanticActionPrimitiveSet.GRASP, steak): + env.step(action) + assert steak.states[object_states.InHandOfRobot].get_value() + # def test_navigate(): # pass