Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Behavior Task Flexibility #323

Merged
merged 1 commit into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions omnigibson/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from omnigibson.envs.env_base import Environment
from omnigibson.envs.env_wrapper import EnvironmentWrapper, create_wrapper, REGISTERED_ENV_WRAPPERS
72 changes: 63 additions & 9 deletions omnigibson/envs/env_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gym
import numpy as np
from copy import deepcopy

import omnigibson as og
from omnigibson.objects import REGISTERED_OBJECTS
Expand Down Expand Up @@ -50,7 +51,9 @@ def __init__(
self._automatic_reset = automatic_reset
self._flatten_action_space = flatten_action_space
self._flatten_obs_space = flatten_obs_space
self.physics_timestep = physics_timestep
self.action_timestep = action_timestep
self.device = device

# Initialize other placeholders that will be filled in later
self._initial_pos_z_offset = None # how high to offset object placement to account for one action step of dropping
Expand All @@ -71,12 +74,6 @@ def __init__(
for config in configs:
merge_nested_dicts(base_dict=self.config, extra_dict=parse_config(config), inplace=True)

# Set the simulator settings
og.sim.set_simulation_dt(physics_dt=physics_timestep, rendering_dt=action_timestep)
og.sim.viewer_width = self.render_config["viewer_width"]
og.sim.viewer_height = self.render_config["viewer_height"]
og.sim.device = device

# Load this environment
self.load()

Expand Down Expand Up @@ -153,10 +150,22 @@ def _load_variables(self):
drop_distance = 0.5 * 9.8 * (self.action_timestep ** 2)
assert drop_distance < self._initial_pos_z_offset, "initial_pos_z_offset is too small for collision checking"

def _load_task(self):
def _load_task(self, task_config=None):
"""
Load task

Args:
task_confg (None or dict): If specified, custom task configuration to use. Otherwise, will use
self.task_config. Note that if a custom task configuration is specified, the internal task config
will be updated as well
"""
# Update internal config if specified
if task_config is not None:
# Copy task config, in case self.task_config and task_config are the same!
task_config = deepcopy(task_config)
self.task_config.clear()
self.task_config.update(task_config)

# Sanity check task to make sure it's valid
task_type = self.task_config["type"]
assert_valid_key(key=task_type, valid_keys=REGISTERED_TASKS, name="task type")
Expand Down Expand Up @@ -188,6 +197,13 @@ def _load_scene(self):
cls_type_descriptor="scene",
)
og.sim.import_scene(scene)

# Set the simulator settings
og.sim.set_simulation_dt(physics_dt=self.physics_timestep, rendering_dt=self.action_timestep)
og.sim.viewer_width = self.render_config["viewer_width"]
og.sim.viewer_height = self.render_config["viewer_height"]
og.sim.device = self.device

assert og.sim.is_stopped(), "Simulator must be stopped after loading scene!"

def _load_robots(self):
Expand Down Expand Up @@ -324,6 +340,31 @@ def load(self):
# Denote that the scene is loaded
self._loaded = True

def update_task(self, task_config):
"""
Updates the internal task using @task_config. NOTE: This will internally reset the environment as well!

Args:
task_config (dict): Task configuration for updating the new task
"""
# Make sure sim is playing
assert og.sim.is_playing(), "Update task should occur while sim is playing!"

# Denote scene as not loaded yet
self._loaded = False
og.sim.stop()
self._load_task(task_config=task_config)
og.sim.play()
self.reset()

# Load obs / action spaces
self.load_observation_space()
self._load_action_space()

# Scene is now loaded again
self._loaded = True


def close(self):
"""
Clean up the environment and shut down the simulation.
Expand Down Expand Up @@ -443,7 +484,7 @@ def reset(self):
# Grab and return observations
obs = self.get_obs()

if self.observation_space is not None and not self.observation_space.contains(obs):
if self._loaded and not self.observation_space.contains(obs):
# Flatten obs, and print out all keys and values
log.error("OBSERVATION SPACE:")
for key, value in recursively_generate_flat_dict(dic=self.observation_space).items():
Expand Down Expand Up @@ -543,6 +584,14 @@ def task_config(self):
"""
return self.config["task"]

@property
def wrapper_config(self):
"""
Returns:
dict: Wrapper-specific configuration kwargs
"""
return self.config["wrapper"]

@property
def default_config(self):
"""
Expand Down Expand Up @@ -584,5 +633,10 @@ def default_config(self):
# Task kwargs
"task": {
"type": "DummyTask",
}
},

# Wrapper kwargs
"wrapper": {
"type": None,
},
}
39 changes: 38 additions & 1 deletion omnigibson/envs/env_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
from omnigibson.utils.python_utils import Wrapper
from omnigibson.utils.python_utils import Registerable, classproperty, create_class_from_registry_and_config
from omnigibson.utils.ui_utils import create_module_logger
from copy import deepcopy

# Global dicts that will contain mappings
REGISTERED_ENV_WRAPPERS = dict()

class EnvironmentWrapper(Wrapper):
# Create module logger
log = create_module_logger(module_name=__name__)


def create_wrapper(env):
"""
Wraps environment @env with wrapper defined by env.wrapper_config
"""
wrapper_cfg = deepcopy(env.wrapper_config)
wrapper_type = wrapper_cfg.pop("type")
wrapper_cfg["env"] = env

return create_class_from_registry_and_config(
cls_name=wrapper_type,
cls_registry=REGISTERED_ENV_WRAPPERS,
cfg=wrapper_cfg,
cls_type_descriptor="wrapper",
)


class EnvironmentWrapper(Wrapper, Registerable):
"""
Base class for all environment wrappers in OmniGibson. In general, reset(), step(), and observation_spec() should
be overwritten
Expand Down Expand Up @@ -50,3 +75,15 @@ def observation_spec(self):
"""
return self.env.observation_spec()

@classproperty
def _do_not_register_classes(cls):
# Don't register this class since it's an abstract template
classes = super()._do_not_register_classes
classes.add("EnvironmentWrapper")
return classes

@classproperty
def _cls_registry(cls):
# Global robot registry
global REGISTERED_ENV_WRAPPERS
return REGISTERED_ENV_WRAPPERS
7 changes: 4 additions & 3 deletions omnigibson/tasks/behavior_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from omnigibson.robots.robot_base import BaseRobot
from omnigibson.systems.system_base import get_system, add_callback_on_system_init, add_callback_on_system_clear, \
REGISTERED_SYSTEMS
from omnigibson.scenes.scene_base import Scene
from omnigibson.scenes.interactive_traversable_scene import InteractiveTraversableScene
from omnigibson.utils.bddl_utils import OmniGibsonBDDLBackend, BDDLEntity, BEHAVIOR_ACTIVITIES, BDDLSampler
from omnigibson.tasks.task_base import BaseTask
Expand Down Expand Up @@ -142,7 +143,7 @@ def verify_scene_and_task_config(cls, scene_cfg, task_cfg):
task_cfg.get("predefined_problem", None) is not None else task_cfg["activity_name"]
if scene_file is None and scene_instance is None and not task_cfg["online_object_sampling"]:
scene_instance = cls.get_cached_activity_scene_filename(
scene_model=scene_cfg["scene_model"],
scene_model=scene_cfg.get("scene_model", "Scene"),
activity_name=activity_name,
activity_definition_id=task_cfg.get("activity_definition_id", 0),
activity_instance_id=task_cfg.get("activity_instance_id", 0),
Expand Down Expand Up @@ -522,8 +523,8 @@ def name(self):

@classproperty
def valid_scene_types(cls):
# Must be an interactive traversable scene
return {InteractiveTraversableScene}
# Any scene can be used
return {Scene}

@classproperty
def default_termination_config(cls):
Expand Down
14 changes: 13 additions & 1 deletion omnigibson/tasks/task_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, termination_config=None, reward_config=None):
self._loaded = False
self._reward = None
self._done = None
self._success = None
self._info = None
self._low_dim_obs_dim = None

Expand Down Expand Up @@ -160,7 +161,8 @@ def _reset_variables(self, env):
"""
# By default, reset reward, done, and info
self._reward = None
self._done = None
self._done = False
self._success = False
self._info = None

def reset(self, env):
Expand Down Expand Up @@ -311,6 +313,7 @@ def step(self, env, action):
# Update the internal state of this task
self._reward = reward
self._done = done
self._success = done_info["success"]
self._info = {
"reward": reward_info,
"done": done_info,
Expand Down Expand Up @@ -344,6 +347,15 @@ def done(self):
assert self._done is not None, "At least one step() must occur before done can be calculated!"
return self._done

@property
def success(self):
"""
Returns:
bool: Whether this task has succeeded or not
"""
assert self._success is not None, "At least one step() must occur before success can be calculated!"
return self._success

@property
def info(self):
"""
Expand Down
28 changes: 20 additions & 8 deletions omnigibson/utils/bddl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from omnigibson import object_states
from omnigibson.object_states.factory import _KINEMATIC_STATE_SET
from omnigibson.systems.system_base import is_system_active, get_system
from omnigibson.scenes.interactive_traversable_scene import InteractiveTraversableScene

# Create module logger
log = create_module_logger(module_name=__name__)
Expand Down Expand Up @@ -276,7 +277,7 @@ def __init__(
):
# Store internal variables from inputs
self._env = env
self._scene_model = self._env.scene.scene_model
self._scene_model = self._env.scene.scene_model if isinstance(self._env.scene, InteractiveTraversableScene) else None
self._agent = self._env.robots[0]
if debug:
gm.DEBUG = True
Expand Down Expand Up @@ -411,9 +412,9 @@ def _parse_inroom_object_room_assignment(self):
# Invalid room assignment
return f"You have assigned room type for [{obj_synset}], but [{obj_synset}] is sampleable. " \
f"Only non-sampleable (scene) objects can have room assignment."
if room_type not in og.sim.scene.seg_map.room_sem_name_to_ins_name:
if self._scene_model is not None and room_type not in og.sim.scene.seg_map.room_sem_name_to_ins_name:
# Missing room type
return f"Room type [{room_type}] missing in scene [{og.sim.scene.scene_model}]."
return f"Room type [{room_type}] missing in scene [{self._scene_model}]."
if room_type not in self._room_type_to_object_instance:
self._room_type_to_object_instance[room_type] = []
self._room_type_to_object_instance[room_type].append(obj_inst)
Expand Down Expand Up @@ -521,6 +522,12 @@ def _build_sampling_order(self):
# Sanity check kinematic objects -- any non-system must be kinematically sampled
remaining_kinematic_entities = nonparticle_entities - unsampleable_obj_instances - \
self._inroom_object_instances - set.union(*(self._object_sampling_orders["kinematic"] + [set()]))

# Possibly remove the agent entity if we're in an empty scene -- i.e.: no kinematic sampling needed for the
# agent
if self._scene_model is None:
remaining_kinematic_entities -= {"agent.n.01_1"}
cremebrule marked this conversation as resolved.
Show resolved Hide resolved

if len(remaining_kinematic_entities) != 0:
return f"Some objects do not have any kinematic condition defined for them in the initial conditions: " \
f"{', '.join(remaining_kinematic_entities)}"
Expand Down Expand Up @@ -566,7 +573,8 @@ def _build_inroom_object_scope(self):
valid_models = {cat: set(get_all_object_category_models_with_abilities(cat, abilities))
for cat in categories}

for room_inst in og.sim.scene.seg_map.room_sem_name_to_ins_name[room_type]:
room_insts = [None] if self._scene_model is None else og.sim.scene.seg_map.room_sem_name_to_ins_name[room_type]
for room_inst in room_insts:
# A list of scene objects that satisfy the requested categories
room_objs = og.sim.scene.object_registry("in_rooms", room_inst, default_val=[])
scene_objs = [obj for obj in room_objs if obj.category in categories and obj.model in valid_models[obj.category]]
Expand Down Expand Up @@ -660,11 +668,15 @@ def _filter_object_scope(self, input_object_scope, conditions, condition_type):
filtered_object_scope[room_type][scene_obj][room_inst].append(obj)

# Compute most problematic objects
problematic_objs_by_proportion = defaultdict(list)
for child_scope_name, parent_obj_names in problematic_objs.items():
problematic_objs_by_proportion[np.mean(list(parent_obj_names.values()))].append(child_scope_name)
if len(problematic_objs) == 0:
max_problematic_objs = []
else:
problematic_objs_by_proportion = defaultdict(list)
for child_scope_name, parent_obj_names in problematic_objs.items():
problematic_objs_by_proportion[np.mean(list(parent_obj_names.values()))].append(child_scope_name)
max_problematic_objs = problematic_objs_by_proportion[min(problematic_objs_by_proportion.keys())]

return filtered_object_scope, problematic_objs_by_proportion[min(problematic_objs_by_proportion.keys())]
return filtered_object_scope, max_problematic_objs

def _consolidate_room_instance(self, filtered_object_scope, condition_type):
"""
Expand Down
Loading