diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 04aa24d..81134bf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,7 +44,7 @@ jobs: - name: Install package run: | - pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[dev]' + pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[all]' - name: Run static checks run: | diff --git a/kscale/formats/__init__.py b/kscale/formats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kscale/formats/mjcf.py b/kscale/formats/mjcf.py index c4bce4f..e847da9 100644 --- a/kscale/formats/mjcf.py +++ b/kscale/formats/mjcf.py @@ -93,6 +93,23 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: return joint +@dataclass +class Inertial: + mass: float | None = None + pos: tuple[float, float, float] | None = None + inertia: tuple[float, float, float] | None = None + + def to_xml(self, root: ET.Element | None = None) -> ET.Element: + inertial = ET.Element("inertial") if root is None else ET.SubElement(root, "inertial") + if self.mass is not None: + inertial.set("mass", str(self.mass)) + if self.pos is not None: + inertial.set("pos", " ".join(map(str, self.pos))) + if self.inertia is not None: + inertial.set("inertia", " ".join(map(str, self.inertia))) + return inertial + + @dataclass class Geom: name: str | None = None @@ -153,9 +170,7 @@ class Body: quat: tuple[float, float, float, float] | None = field(default=None) geom: Geom | None = field(default=None) joint: Joint | None = field(default=None) - - # TODO - Fix inertia, until then rely on Mujoco's engine - # inertial: Inertial = None + inertial: Inertial | None = field(default=None) # Add inertial property def to_xml(self, root: ET.Element | None = None) -> ET.Element: body = ET.Element("body") if root is None else ET.SubElement(root, "body") @@ -168,6 +183,8 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: self.joint.to_xml(body) if self.geom is not None: self.geom.to_xml(body) + if self.inertial is not None: + self.inertial.to_xml(body) # Add inertial to the XML return body @@ -408,7 +425,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: def _copy_stl_files(source_directory: str | Path, destination_directory: str | Path) -> None: # Ensure the destination directory exists, create if not - os.makedirs(destination_directory, exist_ok=True) + if not os.path.exists(destination_directory): + os.makedirs(destination_directory, exist_ok=True) + elif not os.path.isdir(destination_directory): + raise FileExistsError(f"Destination path exists and is not a directory: {destination_directory}") # Use glob to find all .stl files in the source directory pattern = os.path.join(source_directory, "*.stl") diff --git a/kscale/store/cli.py b/kscale/store/cli.py index 174072f..ac4e3d0 100644 --- a/kscale/store/cli.py +++ b/kscale/store/cli.py @@ -3,7 +3,7 @@ import argparse from typing import Sequence -from kscale.store import pybullet, urdf +from kscale.store import mjcf, pybullet, urdf def main(args: Sequence[str] | None = None) -> None: @@ -12,6 +12,7 @@ def main(args: Sequence[str] | None = None) -> None: "subcommand", choices=[ "urdf", + "mjcf", "pybullet", ], help="The subcommand to run", @@ -21,6 +22,8 @@ def main(args: Sequence[str] | None = None) -> None: match parsed_args.subcommand: case "urdf": urdf.main(remaining_args) + case "mjcf": + mjcf.main(remaining_args) case "pybullet": pybullet.main(remaining_args) diff --git a/kscale/store/gen/api.py b/kscale/store/gen/api.py index ad477cf..79ad9d5 100644 --- a/kscale/store/gen/api.py +++ b/kscale/store/gen/api.py @@ -2,7 +2,7 @@ # generated by datamodel-codegen: # filename: openapi.json -# timestamp: 2024-08-19T06:07:36+00:00 +# timestamp: 2024-08-22T02:19:26+00:00 from __future__ import annotations diff --git a/kscale/store/mjcf.py b/kscale/store/mjcf.py new file mode 100644 index 0000000..949cae1 --- /dev/null +++ b/kscale/store/mjcf.py @@ -0,0 +1,193 @@ +"""Utility functions for managing artifacts in the K-Scale store with MJCF support.""" + +import argparse +import asyncio +import logging +import os +import sys +import tarfile +from pathlib import Path +from typing import Literal, Sequence + +import httpx +import requests + +from kscale.conf import Settings +from kscale.store.gen.api import MjcfResponse +from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_api_key() -> str: + api_key = Settings.load().store.api_key + if not api_key: + raise ValueError( + "API key not found! Get one here and set it as the `KSCALE_API_KEY` environment variable or in your" + "config file: https://kscale.store/keys" + ) + return api_key + + +def get_cache_dir() -> Path: + return Path(Settings.load().store.cache_dir).expanduser().resolve() + + +def fetch_mjcf_info(listing_id: str) -> MjcfResponse: + url = f"https://api.kscale.store/mjcf/info/{listing_id}" + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + response = requests.get(url, headers=headers) + response.raise_for_status() + return MjcfResponse(**response.json()) + + +async def download_artifact(artifact_url: str, cache_dir: Path) -> str: + filename = os.path.join(cache_dir, artifact_url.split("/")[-1]) + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + + if not os.path.exists(filename): + logger.info("Downloading artifact from %s" % artifact_url) + + async with httpx.AsyncClient() as client: + response = await client.get(artifact_url, headers=headers) + response.raise_for_status() + with open(filename, "wb") as f: + for chunk in response.iter_bytes(chunk_size=8192): + f.write(chunk) + logger.info("Artifact downloaded to %s" % filename) + else: + logger.info("Artifact already cached at %s" % filename) + + # Extract the .tgz file + extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0]) + if not os.path.exists(extract_dir): + logger.info(f"Extracting {filename} to {extract_dir}") + with tarfile.open(filename, "r:gz") as tar: + tar.extractall(path=extract_dir) + logger.info("Extraction complete") + else: + logger.info("Artifact already extracted at %s" % extract_dir) + + return extract_dir + + +def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Path) -> str: + tarball_path = os.path.join(cache_dir, output_filename) + with tarfile.open(tarball_path, "w:gz") as tar: + for root, _, files in os.walk(folder_path): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, start=folder_path) + tar.add(file_path, arcname=arcname) + logger.info("Added %s as %s" % (file_path, arcname)) + logger.info("Created tarball %s" % tarball_path) + return tarball_path + + +async def upload_artifact(tarball_path: str, listing_id: str) -> None: + url = f"https://api.kscale.store/mjcf/upload/{listing_id}" + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + + async with httpx.AsyncClient() as client: + with open(tarball_path, "rb") as f: + files = {"file": (f.name, f, "application/gzip")} + response = await client.post(url, headers=headers, files=files) + + response.raise_for_status() + + logger.info("Uploaded artifact to %s" % url) + + +def main(args: Sequence[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="K-Scale MJCF Store", add_help=False) + parser.add_argument( + "command", + choices=["get", "info", "upload", "convert"], + help="The command to run", + ) + parser.add_argument("listing_id", help="The listing ID to operate on") + parsed_args, remaining_args = parser.parse_known_args(args) + + command: Literal["get", "info", "upload", "convert"] = parsed_args.command + listing_id: str = parsed_args.listing_id + + def get_listing_dir() -> Path: + (cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True) + return cache_dir + + match command: + case "get": + try: + mjcf_info = fetch_mjcf_info(listing_id) + + if mjcf_info.mjcf: + artifact_url = mjcf_info.mjcf.url + asyncio.run(download_artifact(artifact_url, get_listing_dir())) + else: + logger.info("No MJCF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch MJCF info: %s" % e) + sys.exit(1) + + case "info": + try: + mjcf_info = fetch_mjcf_info(listing_id) + + if mjcf_info.mjcf: + logger.info("MJCF Artifact ID: %s" % mjcf_info.mjcf.artifact_id) + logger.info("MJCF URL: %s" % mjcf_info.mjcf.url) + else: + logger.info("No MJCF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch MJCF info: %s" % e) + sys.exit(1) + + case "upload": + parser = argparse.ArgumentParser(description="Upload an MJCF artifact to the K-Scale store") + parser.add_argument("folder_path", help="The path to the folder containing the MJCF files") + parsed_args = parser.parse_args(remaining_args) + folder_path = Path(parsed_args.folder_path).expanduser().resolve() + + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) + if urdf_or_mjcf == "mjcf": + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + + try: + mjcf_info = fetch_mjcf_info(listing_id) + asyncio.run(upload_artifact(tarball_path, listing_id)) + except requests.RequestException as e: + logger.error("Failed to upload artifact: %s" % e) + sys.exit(1) + else: + logger.error("No MJCF files found in %s" % folder_path) + sys.exit(1) + + case "convert": + parser = argparse.ArgumentParser(description="Convert an MJCF to a URDF file") + parser.add_argument("file_path", help="The path of the MJCF file") + parsed_args = parser.parse_args(remaining_args) + file_path = Path(parsed_args.file_path).expanduser().resolve() + + if file_path.suffix == ".xml": + urdf_path = mjcf_to_urdf(file_path) + logger.info(f"Converted MJCF to URDF: {urdf_path}") + else: + logger.error("%s is not an MJCF file" % file_path) + sys.exit(1) + + case _: + logger.error("Invalid command") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index dd5c71b..0762733 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -14,6 +14,7 @@ from kscale.conf import Settings from kscale.store.gen.api import UrdfResponse +from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf # Set up logging logging.basicConfig(level=logging.INFO) @@ -109,13 +110,13 @@ def main(args: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False) parser.add_argument( "command", - choices=["get", "info", "upload"], + choices=["get", "info", "upload", "convert"], help="The command to run", ) parser.add_argument("listing_id", help="The listing ID to operate on") parsed_args, remaining_args = parser.parse_known_args(args) - command: Literal["get", "info", "upload"] = parsed_args.command + command: Literal["get", "info", "upload", "convert"] = parsed_args.command listing_id: str = parsed_args.listing_id def get_listing_dir() -> Path: @@ -155,16 +156,35 @@ def get_listing_dir() -> Path: parsed_args = parser.parse_args(remaining_args) folder_path = Path(parsed_args.folder_path).expanduser().resolve() - output_filename = f"{listing_id}.tgz" - tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) - - try: - urdf_info = fetch_urdf_info(listing_id) - asyncio.run(upload_artifact(tarball_path, listing_id)) - except requests.RequestException as e: - logger.error("Failed to upload artifact: %s" % e) + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) + if urdf_or_mjcf == "urdf": + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + + try: + urdf_info = fetch_urdf_info(listing_id) + asyncio.run(upload_artifact(tarball_path, listing_id)) + except requests.RequestException as e: + logger.error("Failed to upload artifact: %s" % e) + sys.exit(1) + else: + logger.error("No URDF files found in %s" % folder_path) sys.exit(1) + case "convert": + parser = argparse.ArgumentParser(description="Convert a URDF to an MJCF file") + parser.add_argument("file_path", help="The path of the URDF file") + parsed_args = parser.parse_args(remaining_args) + + # NOTE: folder path actually + file_path = Path(parsed_args.file_path).expanduser().resolve() + + if file_path.suffix == ".urdf": + urdf_to_mjcf(file_path.parent, file_path.stem) + logger.info("Converted URDF to MJCF") + else: + logger.error("%s is not a URDF file" % file_path) + sys.exit(1) case _: logger.error("Invalid command") sys.exit(1) diff --git a/kscale/utils.py b/kscale/utils.py new file mode 100644 index 0000000..bcb6b2a --- /dev/null +++ b/kscale/utils.py @@ -0,0 +1,87 @@ +"""Utility functions for the kscale package.""" + +import os +from pathlib import Path +from typing import Any, Dict + +from kscale.formats import mjcf + + +def contains_urdf_or_mjcf(folder_path: Path) -> str: + urdf_found = False + xml_found = False + + for file in folder_path.iterdir(): + if file.suffix == ".urdf": + urdf_found = True + elif file.suffix == ".xml": + xml_found = True + + if urdf_found: + return "urdf" + if xml_found: + return "mjcf" + raise ValueError("No URDF or MJCF files found in the folder.") + + +def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: + # Extract the base name from the URDF file path + + # Loading the URDF file and adapting it to the MJCF format + mjcf_robot = mjcf.Robot(robot_name, urdf_path, mjcf.Compiler(angle="radian", meshdir="meshes")) + mjcf_robot.adapt_world() + + # Save the MJCF file with the base name + mjcf_robot.save(urdf_path.parent / f"{robot_name}.xml") + + +def mjcf_to_urdf(input_mjcf: Path, name: str = "robot.urdf") -> Path: + """Convert an MJCF file to a single URDF file with all parts combined. + + Args: + input_mjcf: The path to the input MJCF file. + name: The name of the output URDF file. + + Returns: + The path to the output URDF file. + """ + try: + from pybullet_utils import bullet_client, urdfEditor # type: ignore[import-not-found] + except ImportError: + raise ImportError("To use PyBullet, do `pip install 'kscale[pybullet]'`.") + + # Set output_path to the directory of the input MJCF file + output_path = input_mjcf.parent + + # Initialize the Bullet client + client = bullet_client.BulletClient() + + # Load the MJCF model + objs: Dict[int, Any] = client.loadMJCF(str(input_mjcf), flags=client.URDF_USE_IMPLICIT_CYLINDER) + + # Initialize a single URDF editor to store all parts + combined_urdf_editor = urdfEditor.UrdfEditor() + + # Iterate over all objects in the MJCF model + for obj in objs: + humanoid = obj # Get the current object + part_urdf_editor = urdfEditor.UrdfEditor() + part_urdf_editor.initializeFromBulletBody(humanoid, client._client) + + # Add all links from the part URDF editor to the combined editor + for link in part_urdf_editor.urdfLinks: + if link not in combined_urdf_editor.urdfLinks: + combined_urdf_editor.urdfLinks.append(link) + + # Add all joints from the part URDF editor to the combined editor + for joint in part_urdf_editor.urdfJoints: + if joint not in combined_urdf_editor.urdfJoints: + combined_urdf_editor.urdfJoints.append(joint) + + # Set the output path for the combined URDF file + combined_urdf_path = os.path.join(output_path, name) + + # Save the combined URDF + combined_urdf_editor.saveUrdf(combined_urdf_path) + + return Path(combined_urdf_path) diff --git a/setup.py b/setup.py index 03d0121..2bead08 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,9 @@ assert version_re is not None, "Could not find version in kscale/__init__.py" version: str = version_re.group(1) +# Additional packages. +requirements_pybullet = ["pybullet"] +requirements_all = requirements + requirements_dev + requirements_pybullet setup( name="kscale", @@ -35,5 +38,9 @@ python_requires=">=3.11", install_requires=requirements, tests_require=requirements_dev, - extras_require={"dev": requirements_dev}, + extras_require={ + "dev": requirements_dev, + "pybullet": requirements_pybullet, + "all": requirements_all, + }, )