Skip to content

Commit

Permalink
Merge pull request #840 from StanfordVL/fix/data-collection-wrapper
Browse files Browse the repository at this point in the history
fix demo collection bug
  • Loading branch information
wensi-ai authored Aug 20, 2024
2 parents b734935 + 3107e2c commit f448668
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions omnigibson/envs/data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def step(self, action):
- bool: truncated, i.e. whether this episode ended due to a time limit etc.
- dict: info, i.e. dictionary with any useful information
"""
# Make sure actions are always flattened numpy arrays
if isinstance(action, dict):
action = np.concatenate([act for act in action.values()])

next_obs, reward, terminated, truncated, info = self.env.step(action)
self.step_count += 1

Expand Down Expand Up @@ -133,21 +137,22 @@ def observation_spec(self):
"""
return self.env.observation_spec()

def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
def process_traj_to_hdf5(self, traj_data, traj_grp_name, nested_keys=("obs",)):
"""
Processes trajectory data @traj_data and stores them as a new group under @traj_grp_name.
Args:
traj_data (list of dict): Trajectory data, where each entry is a keyword-mapped set of data for a single
sim step
traj_grp_name (str): Name of the trajectory group to store
obs_key (str): Name of key corresponding to observation data in @traj_data. This specific data is
assumed to be its own keyword-mapped dictionary of observations, and will be parsed differently from
the rest of the data
nested_keys (list of str): Name of key(s) corresponding to nested data in @traj_data. This specific data
is assumed to be its own keyword-mapped dictionary of numpy array values, and will be parsed
differently from the rest of the data
Returns:
hdf5.Group: Generated hdf5 group storing the recorded trajectory data
"""
nested_keys = set(nested_keys)
data_grp = self.hdf5_file.require_group("data")
traj_grp = data_grp.create_group(traj_grp_name)
traj_grp.attrs["num_samples"] = len(traj_data)
Expand All @@ -156,11 +161,12 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
# We need to do this because we're not guaranteed to have a full set of keys at every trajectory step; e.g.
# if the first step only has state or observations but no actions
data = defaultdict(list)
data[obs_key] = defaultdict(list)
for key in nested_keys:
data[key] = defaultdict(list)

for step_data in traj_data:
for k, v in step_data.items():
if k == obs_key:
if k in nested_keys:
for mod, step_mod_data in v.items():
data[k][mod].append(step_mod_data)
else:
Expand All @@ -172,7 +178,7 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
continue

# Create datasets for all keys with valid data
if k == obs_key:
if k in nested_keys:
obs_grp = traj_grp.create_group(k)
for mod, traj_mod_data in dat.items():
obs_grp.create_dataset(mod, data=np.stack(traj_mod_data, axis=0))
Expand All @@ -189,7 +195,7 @@ def flush_current_traj(self):
success = self.env.task.success or not self.only_successes
if success and self.hdf5_file is not None:
traj_grp_name = f"demo_{self.traj_count}"
traj_grp = self.process_traj_to_hdf5(self.current_traj_history, traj_grp_name, obs_key="obs")
traj_grp = self.process_traj_to_hdf5(self.current_traj_history, traj_grp_name, nested_keys=["obs"])
self.traj_count += 1
else:
# Remove this demo
Expand Down Expand Up @@ -345,7 +351,7 @@ def _parse_step_data(self, action, obs, reward, terminated, truncated, info):

return step_data

def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
def process_traj_to_hdf5(self, traj_data, traj_grp_name, nested_keys=("obs",)):
# First pad all state values to be the same max (uniform) size
for step_data in traj_data:
state = step_data["state"]
Expand All @@ -354,7 +360,7 @@ def process_traj_to_hdf5(self, traj_data, traj_grp_name, obs_key="obs"):
step_data["state"] = padded_state

# Call super
traj_grp = super().process_traj_to_hdf5(traj_data, traj_grp_name, obs_key)
traj_grp = super().process_traj_to_hdf5(traj_data, traj_grp_name, nested_keys)

# Add in transition info
self.add_metadata(group=traj_grp, name="transitions", data=self.current_transitions)
Expand Down

0 comments on commit f448668

Please sign in to comment.