Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-zhng committed Jan 21, 2025
1 parent 2075ed4 commit df031e1
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
40 changes: 40 additions & 0 deletions stac_mjx/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

@dataclass
class ModelConfig:
"""Configuration for body model."""

MJCF_PATH: str
FTOL: float
ROOT_FTOL: float
Expand All @@ -44,13 +46,17 @@ class ModelConfig:

@dataclass
class MujocoConfig:
"""Configuration for Mujoco."""

solver: str
iterations: int
ls_iterations: int


@dataclass
class StacConfig:
"""Configuration for STAC."""

fit_offsets_path: str
ik_only_path: str
data_path: str
Expand All @@ -65,12 +71,16 @@ class StacConfig:

@dataclass
class Config:
"""Combined configuration for the model and STAC."""

model: ModelConfig
stac: StacConfig


@dataclass
class StacData:
"""Data structure for STAC output."""

qpos: np.ndarray
xpos: np.ndarray
xquat: np.ndarray
Expand Down Expand Up @@ -253,6 +263,12 @@ def _load_params(param_path):


def save_dict_to_hdf5(group, dictionary):
"""Save a dictionary to an HDF5 group.
Args:
group (h5py.Group): HDF5 group to save the dictionary to.
dictionary (dict): Dictionary to save.
"""
for key, value in dictionary.items():
if isinstance(value, dict):
subgroup = group.create_group(key)
Expand All @@ -275,6 +291,22 @@ def save_data_to_h5(
qvel: np.ndarray,
file_path: str,
):
"""Save configuration and STAC data to an HDF5 file.
Args:
config (Config): Configuration dataclass.
kp_names (list): List of keypoint names.
names_qpos (list): List of qpos names.
names_xpos (list): List of xpos names.
kp_data (np.ndarray): Keypoint data array.
marker_sites (np.ndarray): Marker sites array.
offsets (np.ndarray): Offsets array.
qpos (np.ndarray): Qpos array.
xpos (np.ndarray): Xpos array.
xquat (np.ndarray): Xquat array.
qvel (np.ndarray): Qvel array.
file_path (str): Path to the HDF5 file.
"""
with h5py.File(file_path, "w") as f:

Check warning on line 310 in stac_mjx/io.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/io.py#L310

Added line #L310 was not covered by tests
# Save config as a YAML string
config_yaml = OmegaConf.to_yaml(OmegaConf.structured(config))
Expand All @@ -294,6 +326,14 @@ def save_data_to_h5(


def load_stac_data(file_path) -> tuple[Config, StacData]:
"""Load configuration and STAC data from an HDF5 file.
Args:
file_path (str): Path to the HDF5 file.
Returns:
tuple: A tuple containing the Config and StacData dataclasses.
"""
with h5py.File(file_path, "r") as f:

Check warning on line 337 in stac_mjx/io.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/io.py#L337

Added line #L337 was not covered by tests
# Load config from YAML string
config_yaml = f["config"][()].decode("utf-8")
Expand Down
2 changes: 1 addition & 1 deletion stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def mjx_setup(kp_data, mj_model):
def _package_data(
self, mjx_model, qposes, xposes, xquats, marker_sites, kp_data, batched=False
):
"""Extract pose, offsets, data, and all parameters. Infer qvel
"""Extract pose, offsets, data, and all parameters.
marker_sites is the marker positions for each frame--the rodent model's kp_data equivalent
"""
Expand Down
1 change: 0 additions & 1 deletion stac_mjx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def quat_mul(quat1, quat2):

def _clip_within_precision(number, low, high, precision=_TOL):
"""Clips input to provided range, checking precision.
Args:
number: (float) number to be clipped.
low: (float) lower bound.
Expand Down
1 change: 0 additions & 1 deletion stac_mjx/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def viz_stac(
show_marker_error=False,
):
"""Render forward kinematics from keypoint positions.
Args:
data_path (Union[Path, str]): Path to stac output pickle file
cfg (DictConfig): configs
Expand Down

0 comments on commit df031e1

Please sign in to comment.