diff --git a/kscale/formats/mjcf.py b/kscale/formats/mjcf.py index c4bce4f..a163e44 100644 --- a/kscale/formats/mjcf.py +++ b/kscale/formats/mjcf.py @@ -92,6 +92,21 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: joint.set("frictionloss", str(self.frictionloss)) return joint +@dataclass +class Inertial: + mass: float | None = None + pos: tuple[float, float, float] | None = None + inertia: tuple[float, float, float] | None = None + + def to_xml(self, root: ET.Element | None = None) -> ET.Element: + inertial = ET.Element("inertial") if root is None else ET.SubElement(root, "inertial") + if self.mass is not None: + inertial.set("mass", str(self.mass)) + if self.pos is not None: + inertial.set("pos", " ".join(map(str, self.pos))) + if self.inertia is not None: + inertial.set("inertia", " ".join(map(str, self.inertia))) + return inertial @dataclass class Geom: @@ -145,7 +160,6 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: geom.set("density", str(self.density)) return geom - @dataclass class Body: name: str @@ -153,9 +167,7 @@ class Body: quat: tuple[float, float, float, float] | None = field(default=None) geom: Geom | None = field(default=None) joint: Joint | None = field(default=None) - - # TODO - Fix inertia, until then rely on Mujoco's engine - # inertial: Inertial = None + inertial: Inertial | None = field(default=None) # Add inertial property def to_xml(self, root: ET.Element | None = None) -> ET.Element: body = ET.Element("body") if root is None else ET.SubElement(root, "body") @@ -168,9 +180,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: self.joint.to_xml(body) if self.geom is not None: self.geom.to_xml(body) + if self.inertial is not None: + self.inertial.to_xml(body) # Add inertial to the XML return body - @dataclass class Flag: frictionloss: str | None = None @@ -408,7 +421,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: def _copy_stl_files(source_directory: str | Path, destination_directory: str | Path) -> None: # Ensure the destination directory exists, create if not - os.makedirs(destination_directory, exist_ok=True) + if not os.path.exists(destination_directory): + os.makedirs(destination_directory, exist_ok=True) + elif not os.path.isdir(destination_directory): + raise FileExistsError(f"Destination path exists and is not a directory: {destination_directory}") # Use glob to find all .stl files in the source directory pattern = os.path.join(source_directory, "*.stl") diff --git a/kscale/store/cli.py b/kscale/store/cli.py index 174072f..9af1b71 100644 --- a/kscale/store/cli.py +++ b/kscale/store/cli.py @@ -3,7 +3,7 @@ import argparse from typing import Sequence -from kscale.store import pybullet, urdf +from kscale.store import pybullet, urdf, mjcf def main(args: Sequence[str] | None = None) -> None: @@ -12,6 +12,7 @@ def main(args: Sequence[str] | None = None) -> None: "subcommand", choices=[ "urdf", + "mjcf", "pybullet", ], help="The subcommand to run", @@ -21,6 +22,8 @@ def main(args: Sequence[str] | None = None) -> None: match parsed_args.subcommand: case "urdf": urdf.main(remaining_args) + case "mjcf": + mjcf.main(remaining_args) case "pybullet": pybullet.main(remaining_args) diff --git a/kscale/store/gen/api.py b/kscale/store/gen/api.py index ad477cf..514ae39 100644 --- a/kscale/store/gen/api.py +++ b/kscale/store/gen/api.py @@ -265,3 +265,11 @@ class GetBatchListingsResponse(BaseModel): class HTTPValidationError(BaseModel): detail: Optional[List[ValidationError]] = Field(None, title="Detail") + +class MjcfInfo(BaseModel): + artifact_id: str = Field(..., title="Artifact Id") + url: str = Field(..., title="Url") + +class MjcfResponse(BaseModel): + mjcf: Optional[MjcfInfo] + listing_id: str = Field(..., title="Listing Id") \ No newline at end of file diff --git a/kscale/store/mjcf.py b/kscale/store/mjcf.py new file mode 100644 index 0000000..8ceb443 --- /dev/null +++ b/kscale/store/mjcf.py @@ -0,0 +1,193 @@ +"""Utility functions for managing artifacts in the K-Scale store with MJCF support.""" + +import argparse +import asyncio +import logging +import os +import sys +import tarfile +from pathlib import Path +from typing import Literal, Sequence + +import httpx +import requests + +from kscale.conf import Settings +from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf +from kscale.store.gen.api import MjcfResponse + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_api_key() -> str: + api_key = Settings.load().store.api_key + if not api_key: + raise ValueError( + "API key not found! Get one here and set it as the `KSCALE_API_KEY` environment variable or in your" + "config file: https://kscale.store/keys" + ) + return api_key + + +def get_cache_dir() -> Path: + return Path(Settings.load().store.cache_dir).expanduser().resolve() + + +def fetch_mjcf_info(listing_id: str) -> MjcfResponse: + url = f"https://api.kscale.store/mjcf/info/{listing_id}" + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + response = requests.get(url, headers=headers) + response.raise_for_status() + return MjcfResponse(**response.json()) + + +async def download_artifact(artifact_url: str, cache_dir: Path) -> str: + filename = os.path.join(cache_dir, artifact_url.split("/")[-1]) + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + + if not os.path.exists(filename): + logger.info("Downloading artifact from %s" % artifact_url) + + async with httpx.AsyncClient() as client: + response = await client.get(artifact_url, headers=headers) + response.raise_for_status() + with open(filename, "wb") as f: + for chunk in response.iter_bytes(chunk_size=8192): + f.write(chunk) + logger.info("Artifact downloaded to %s" % filename) + else: + logger.info("Artifact already cached at %s" % filename) + + # Extract the .tgz file + extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0]) + if not os.path.exists(extract_dir): + logger.info(f"Extracting {filename} to {extract_dir}") + with tarfile.open(filename, "r:gz") as tar: + tar.extractall(path=extract_dir) + logger.info("Extraction complete") + else: + logger.info("Artifact already extracted at %s" % extract_dir) + + return extract_dir + + +def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Path) -> str: + tarball_path = os.path.join(cache_dir, output_filename) + with tarfile.open(tarball_path, "w:gz") as tar: + for root, _, files in os.walk(folder_path): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, start=folder_path) + tar.add(file_path, arcname=arcname) + logger.info("Added %s as %s" % (file_path, arcname)) + logger.info("Created tarball %s" % tarball_path) + return tarball_path + + +async def upload_artifact(tarball_path: str, listing_id: str) -> None: + url = f"https://api.kscale.store/mjcf/upload/{listing_id}" + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + + async with httpx.AsyncClient() as client: + with open(tarball_path, "rb") as f: + files = {"file": (f.name, f, "application/gzip")} + response = await client.post(url, headers=headers, files=files) + + response.raise_for_status() + + logger.info("Uploaded artifact to %s" % url) + + +def main(args: Sequence[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="K-Scale MJCF Store", add_help=False) + parser.add_argument( + "command", + choices=["get", "info", "upload", "convert"], + help="The command to run", + ) + parser.add_argument("listing_id", help="The listing ID to operate on") + parsed_args, remaining_args = parser.parse_known_args(args) + + command: Literal["get", "info", "upload", "convert"] = parsed_args.command + listing_id: str = parsed_args.listing_id + + def get_listing_dir() -> Path: + (cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True) + return cache_dir + + match command: + case "get": + try: + mjcf_info = fetch_mjcf_info(listing_id) + + if mjcf_info.mjcf: + artifact_url = mjcf_info.mjcf.url + asyncio.run(download_artifact(artifact_url, get_listing_dir())) + else: + logger.info("No MJCF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch MJCF info: %s" % e) + sys.exit(1) + + case "info": + try: + mjcf_info = fetch_mjcf_info(listing_id) + + if mjcf_info.mjcf: + logger.info("MJCF Artifact ID: %s" % mjcf_info.mjcf.artifact_id) + logger.info("MJCF URL: %s" % mjcf_info.mjcf.url) + else: + logger.info("No MJCF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch MJCF info: %s" % e) + sys.exit(1) + + case "upload": + parser = argparse.ArgumentParser(description="Upload an MJCF artifact to the K-Scale store") + parser.add_argument("folder_path", help="The path to the folder containing the MJCF files") + parsed_args = parser.parse_args(remaining_args) + folder_path = Path(parsed_args.folder_path).expanduser().resolve() + + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) + if urdf_or_mjcf == 'mjcf': + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + + try: + mjcf_info = fetch_mjcf_info(listing_id) + asyncio.run(upload_artifact(tarball_path, listing_id)) + except requests.RequestException as e: + logger.error("Failed to upload artifact: %s" % e) + sys.exit(1) + else: + logger.error("No MJCF files found in %s" % folder_path) + sys.exit(1) + + case "convert": + parser = argparse.ArgumentParser(description="Convert an MJCF to a URDF file") + parser.add_argument("file_path", help="The path of the MJCF file") + parsed_args = parser.parse_args(remaining_args) + file_path = Path(parsed_args.file_path).expanduser().resolve() + + if file_path.suffix == ".xml": + urdf_path = mjcf_to_urdf(file_path) + logger.info(f"Converted MJCF to URDF: {urdf_path}") + else: + logger.error("No MJCF files found in %s" % file_path) + sys.exit(1) + + case _: + logger.error("Invalid command") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index dd5c71b..40b313b 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -13,6 +13,7 @@ import requests from kscale.conf import Settings +from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf from kscale.store.gen.api import UrdfResponse # Set up logging @@ -109,13 +110,13 @@ def main(args: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False) parser.add_argument( "command", - choices=["get", "info", "upload"], + choices=["get", "info", "upload", "convert"], help="The command to run", ) parser.add_argument("listing_id", help="The listing ID to operate on") parsed_args, remaining_args = parser.parse_known_args(args) - command: Literal["get", "info", "upload"] = parsed_args.command + command: Literal["get", "info", "upload", "convert"] = parsed_args.command listing_id: str = parsed_args.listing_id def get_listing_dir() -> Path: @@ -155,16 +156,35 @@ def get_listing_dir() -> Path: parsed_args = parser.parse_args(remaining_args) folder_path = Path(parsed_args.folder_path).expanduser().resolve() - output_filename = f"{listing_id}.tgz" - tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) - - try: - urdf_info = fetch_urdf_info(listing_id) - asyncio.run(upload_artifact(tarball_path, listing_id)) - except requests.RequestException as e: - logger.error("Failed to upload artifact: %s" % e) + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) + if urdf_or_mjcf == "urdf": + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + + try: + urdf_info = fetch_urdf_info(listing_id) + asyncio.run(upload_artifact(tarball_path, listing_id)) + except requests.RequestException as e: + logger.error("Failed to upload artifact: %s" % e) + sys.exit(1) + else: + logger.error("No URDF files found in %s" % folder_path) sys.exit(1) + case "convert": + parser = argparse.ArgumentParser(description="Convert a URDF to an MJCF file") + parser.add_argument("file_path", help="The path of the URDF file") + parsed_args = parser.parse_args(remaining_args) + + # NOTE: folder path actually + file_path = Path(parsed_args.file_path).expanduser().resolve() + + if file_path.suffix == ".urdf": + urdf_to_mjcf(file_path.parent, file_path.stem) + logger.info("Converted URDF to MJCF") + else: + logger.error("%s is not a URDF file" % file_path) + sys.exit(1) case _: logger.error("Invalid command") sys.exit(1) diff --git a/kscale/utils.py b/kscale/utils.py new file mode 100644 index 0000000..7f87010 --- /dev/null +++ b/kscale/utils.py @@ -0,0 +1,61 @@ +import os +from pathlib import Path + +import argparse +from pathlib import Path +from typing import Any, Dict, Sequence + + +import os.path as osp + +import pybullet_utils.bullet_client as bullet_client +import pybullet_utils.urdfEditor as urdfEditor + +from kscale.formats import mjcf + + +def contains_urdf_or_mjcf(folder_path: Path) -> str: + urdf_found = False + xml_found = False + + for file in folder_path.iterdir(): + if file.suffix == ".urdf": + urdf_found = True + elif file.suffix == ".xml": + xml_found = True + + if urdf_found: + return "urdf" + elif xml_found: + return "mjcf" + else: + return None + + +def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: + # Extract the base name from the URDF file path + + # Loading the URDF file and adapting it to the MJCF format + mjcf_robot = mjcf.Robot(robot_name, urdf_path, mjcf.Compiler(angle="radian", meshdir="meshes")) + mjcf_robot.adapt_world() + + # Save the MJCF file with the base name + mjcf_robot.save(urdf_path.parent / f"{robot_name}.xml") + + +def mjcf_to_urdf(input_mjcf: Path) -> None: + # Set output_path to the directory of the input_mjcf file + output_path = input_mjcf.parent + + client = bullet_client.BulletClient() + objs: Dict[int, Any] = client.loadMJCF(str(input_mjcf), flags=client.URDF_USE_IMPLICIT_CYLINDER) + + for obj in objs: + humanoid = objs[obj] + ue = urdfEditor.UrdfEditor() + ue.initializeFromBulletBody(humanoid, client._client) + robot_name: str = str(client.getBodyInfo(obj)[1], "utf-8") + part_name: str = str(client.getBodyInfo(obj)[0], "utf-8") + save_visuals: bool = False + outpath: str = osp.join(output_path, "{}_{}.urdf".format(robot_name, part_name)) + ue.saveUrdf(outpath, save_visuals)