From 64b637f42f5c0cd421df91ee0ce3352a7504a7b7 Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Tue, 20 Aug 2024 23:26:41 +0000 Subject: [PATCH 1/6] mjcf2urdf urdf2mjcf --- kscale/formats/mjcf.py | 5 +- kscale/store/gen/api.py | 8 ++ kscale/store/mjcf.py | 193 ++++++++++++++++++++++++++++++++++++++++ kscale/store/urdf.py | 40 ++++++--- kscale/utils.py | 61 +++++++++++++ 5 files changed, 296 insertions(+), 11 deletions(-) create mode 100644 kscale/store/mjcf.py create mode 100644 kscale/utils.py diff --git a/kscale/formats/mjcf.py b/kscale/formats/mjcf.py index c4bce4f..3dace2c 100644 --- a/kscale/formats/mjcf.py +++ b/kscale/formats/mjcf.py @@ -408,7 +408,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: def _copy_stl_files(source_directory: str | Path, destination_directory: str | Path) -> None: # Ensure the destination directory exists, create if not - os.makedirs(destination_directory, exist_ok=True) + if not os.path.exists(destination_directory): + os.makedirs(destination_directory, exist_ok=True) + elif not os.path.isdir(destination_directory): + raise FileExistsError(f"Destination path exists and is not a directory: {destination_directory}") # Use glob to find all .stl files in the source directory pattern = os.path.join(source_directory, "*.stl") diff --git a/kscale/store/gen/api.py b/kscale/store/gen/api.py index ad477cf..514ae39 100644 --- a/kscale/store/gen/api.py +++ b/kscale/store/gen/api.py @@ -265,3 +265,11 @@ class GetBatchListingsResponse(BaseModel): class HTTPValidationError(BaseModel): detail: Optional[List[ValidationError]] = Field(None, title="Detail") + +class MjcfInfo(BaseModel): + artifact_id: str = Field(..., title="Artifact Id") + url: str = Field(..., title="Url") + +class MjcfResponse(BaseModel): + mjcf: Optional[MjcfInfo] + listing_id: str = Field(..., title="Listing Id") \ No newline at end of file diff --git a/kscale/store/mjcf.py b/kscale/store/mjcf.py new file mode 100644 index 0000000..8ceb443 --- /dev/null +++ b/kscale/store/mjcf.py @@ -0,0 +1,193 @@ +"""Utility functions for managing artifacts in the K-Scale store with MJCF support.""" + +import argparse +import asyncio +import logging +import os +import sys +import tarfile +from pathlib import Path +from typing import Literal, Sequence + +import httpx +import requests + +from kscale.conf import Settings +from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf +from kscale.store.gen.api import MjcfResponse + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_api_key() -> str: + api_key = Settings.load().store.api_key + if not api_key: + raise ValueError( + "API key not found! Get one here and set it as the `KSCALE_API_KEY` environment variable or in your" + "config file: https://kscale.store/keys" + ) + return api_key + + +def get_cache_dir() -> Path: + return Path(Settings.load().store.cache_dir).expanduser().resolve() + + +def fetch_mjcf_info(listing_id: str) -> MjcfResponse: + url = f"https://api.kscale.store/mjcf/info/{listing_id}" + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + response = requests.get(url, headers=headers) + response.raise_for_status() + return MjcfResponse(**response.json()) + + +async def download_artifact(artifact_url: str, cache_dir: Path) -> str: + filename = os.path.join(cache_dir, artifact_url.split("/")[-1]) + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + + if not os.path.exists(filename): + logger.info("Downloading artifact from %s" % artifact_url) + + async with httpx.AsyncClient() as client: + response = await client.get(artifact_url, headers=headers) + response.raise_for_status() + with open(filename, "wb") as f: + for chunk in response.iter_bytes(chunk_size=8192): + f.write(chunk) + logger.info("Artifact downloaded to %s" % filename) + else: + logger.info("Artifact already cached at %s" % filename) + + # Extract the .tgz file + extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0]) + if not os.path.exists(extract_dir): + logger.info(f"Extracting {filename} to {extract_dir}") + with tarfile.open(filename, "r:gz") as tar: + tar.extractall(path=extract_dir) + logger.info("Extraction complete") + else: + logger.info("Artifact already extracted at %s" % extract_dir) + + return extract_dir + + +def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Path) -> str: + tarball_path = os.path.join(cache_dir, output_filename) + with tarfile.open(tarball_path, "w:gz") as tar: + for root, _, files in os.walk(folder_path): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, start=folder_path) + tar.add(file_path, arcname=arcname) + logger.info("Added %s as %s" % (file_path, arcname)) + logger.info("Created tarball %s" % tarball_path) + return tarball_path + + +async def upload_artifact(tarball_path: str, listing_id: str) -> None: + url = f"https://api.kscale.store/mjcf/upload/{listing_id}" + headers = { + "Authorization": f"Bearer {get_api_key()}", + } + + async with httpx.AsyncClient() as client: + with open(tarball_path, "rb") as f: + files = {"file": (f.name, f, "application/gzip")} + response = await client.post(url, headers=headers, files=files) + + response.raise_for_status() + + logger.info("Uploaded artifact to %s" % url) + + +def main(args: Sequence[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="K-Scale MJCF Store", add_help=False) + parser.add_argument( + "command", + choices=["get", "info", "upload", "convert"], + help="The command to run", + ) + parser.add_argument("listing_id", help="The listing ID to operate on") + parsed_args, remaining_args = parser.parse_known_args(args) + + command: Literal["get", "info", "upload", "convert"] = parsed_args.command + listing_id: str = parsed_args.listing_id + + def get_listing_dir() -> Path: + (cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True) + return cache_dir + + match command: + case "get": + try: + mjcf_info = fetch_mjcf_info(listing_id) + + if mjcf_info.mjcf: + artifact_url = mjcf_info.mjcf.url + asyncio.run(download_artifact(artifact_url, get_listing_dir())) + else: + logger.info("No MJCF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch MJCF info: %s" % e) + sys.exit(1) + + case "info": + try: + mjcf_info = fetch_mjcf_info(listing_id) + + if mjcf_info.mjcf: + logger.info("MJCF Artifact ID: %s" % mjcf_info.mjcf.artifact_id) + logger.info("MJCF URL: %s" % mjcf_info.mjcf.url) + else: + logger.info("No MJCF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch MJCF info: %s" % e) + sys.exit(1) + + case "upload": + parser = argparse.ArgumentParser(description="Upload an MJCF artifact to the K-Scale store") + parser.add_argument("folder_path", help="The path to the folder containing the MJCF files") + parsed_args = parser.parse_args(remaining_args) + folder_path = Path(parsed_args.folder_path).expanduser().resolve() + + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) + if urdf_or_mjcf == 'mjcf': + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + + try: + mjcf_info = fetch_mjcf_info(listing_id) + asyncio.run(upload_artifact(tarball_path, listing_id)) + except requests.RequestException as e: + logger.error("Failed to upload artifact: %s" % e) + sys.exit(1) + else: + logger.error("No MJCF files found in %s" % folder_path) + sys.exit(1) + + case "convert": + parser = argparse.ArgumentParser(description="Convert an MJCF to a URDF file") + parser.add_argument("file_path", help="The path of the MJCF file") + parsed_args = parser.parse_args(remaining_args) + file_path = Path(parsed_args.file_path).expanduser().resolve() + + if file_path.suffix == ".xml": + urdf_path = mjcf_to_urdf(file_path) + logger.info(f"Converted MJCF to URDF: {urdf_path}") + else: + logger.error("No MJCF files found in %s" % file_path) + sys.exit(1) + + case _: + logger.error("Invalid command") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index dd5c71b..40b313b 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -13,6 +13,7 @@ import requests from kscale.conf import Settings +from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf from kscale.store.gen.api import UrdfResponse # Set up logging @@ -109,13 +110,13 @@ def main(args: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False) parser.add_argument( "command", - choices=["get", "info", "upload"], + choices=["get", "info", "upload", "convert"], help="The command to run", ) parser.add_argument("listing_id", help="The listing ID to operate on") parsed_args, remaining_args = parser.parse_known_args(args) - command: Literal["get", "info", "upload"] = parsed_args.command + command: Literal["get", "info", "upload", "convert"] = parsed_args.command listing_id: str = parsed_args.listing_id def get_listing_dir() -> Path: @@ -155,16 +156,35 @@ def get_listing_dir() -> Path: parsed_args = parser.parse_args(remaining_args) folder_path = Path(parsed_args.folder_path).expanduser().resolve() - output_filename = f"{listing_id}.tgz" - tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) - - try: - urdf_info = fetch_urdf_info(listing_id) - asyncio.run(upload_artifact(tarball_path, listing_id)) - except requests.RequestException as e: - logger.error("Failed to upload artifact: %s" % e) + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) + if urdf_or_mjcf == "urdf": + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + + try: + urdf_info = fetch_urdf_info(listing_id) + asyncio.run(upload_artifact(tarball_path, listing_id)) + except requests.RequestException as e: + logger.error("Failed to upload artifact: %s" % e) + sys.exit(1) + else: + logger.error("No URDF files found in %s" % folder_path) sys.exit(1) + case "convert": + parser = argparse.ArgumentParser(description="Convert a URDF to an MJCF file") + parser.add_argument("file_path", help="The path of the URDF file") + parsed_args = parser.parse_args(remaining_args) + + # NOTE: folder path actually + file_path = Path(parsed_args.file_path).expanduser().resolve() + + if file_path.suffix == ".urdf": + urdf_to_mjcf(file_path.parent, file_path.stem) + logger.info("Converted URDF to MJCF") + else: + logger.error("%s is not a URDF file" % file_path) + sys.exit(1) case _: logger.error("Invalid command") sys.exit(1) diff --git a/kscale/utils.py b/kscale/utils.py new file mode 100644 index 0000000..e36c1b3 --- /dev/null +++ b/kscale/utils.py @@ -0,0 +1,61 @@ +import os +from pathlib import Path + +import argparse +from pathlib import Path +from typing import Any, Dict, Sequence + + +import os.path as osp + +import pybullet_utils.bullet_client as bullet_client +import pybullet_utils.urdfEditor as urdfEditor + +from kscale.formats import mjcf + + +def contains_urdf_or_mjcf(folder_path: Path) -> str: + urdf_found = False + xml_found = False + + for file in folder_path.iterdir(): + if file.suffix == ".urdf": + urdf_found = True + elif file.suffix == ".xml": + xml_found = True + + if urdf_found: + return "urdf" + elif xml_found: + return "mjcf" + else: + return None + + +def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: + # Extract the base name from the URDF file path + + # Loading the URDF file and adapting it to the MJCF format + mjcf_robot = mjcf.Robot(robot_name, urdf_path, mjcf.Compiler(angle="radian", meshdir="meshes")) + mjcf_robot.adapt_world() + + # Save the MJCF file with the base name + mjcf_robot.save(urdf_path.parent / f"{robot_name}.xml") + + +def mjcf_to_urdf(input_mjcf: str, ) -> None: + # Set output_path to the directory of the input_mjcf file + output_path = Path(input_mjcf).parent + + client = bullet_client.BulletClient() + objs: Dict[int, Any] = client.loadMJCF(input_mjcf, flags=client.URDF_USE_IMPLICIT_CYLINDER) + + for obj in objs: + humanoid = objs[obj] + ue = urdfEditor.UrdfEditor() + ue.initializeFromBulletBody(humanoid, client._client) + robot_name: str = str(client.getBodyInfo(obj)[1], "utf-8") + part_name: str = str(client.getBodyInfo(obj)[0], "utf-8") + save_visuals: bool = False + outpath: str = osp.join(output_path, "{}_{}.urdf".format(robot_name, part_name)) + ue.saveUrdf(outpath, save_visuals) From fefb7a5b76f9ee7898a3234377773f00b5ea0591 Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Tue, 20 Aug 2024 23:47:08 +0000 Subject: [PATCH 2/6] works --- kscale/formats/mjcf.py | 23 ++++++++++++++++++----- kscale/store/cli.py | 5 ++++- kscale/utils.py | 6 +++--- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/kscale/formats/mjcf.py b/kscale/formats/mjcf.py index 3dace2c..a163e44 100644 --- a/kscale/formats/mjcf.py +++ b/kscale/formats/mjcf.py @@ -92,6 +92,21 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: joint.set("frictionloss", str(self.frictionloss)) return joint +@dataclass +class Inertial: + mass: float | None = None + pos: tuple[float, float, float] | None = None + inertia: tuple[float, float, float] | None = None + + def to_xml(self, root: ET.Element | None = None) -> ET.Element: + inertial = ET.Element("inertial") if root is None else ET.SubElement(root, "inertial") + if self.mass is not None: + inertial.set("mass", str(self.mass)) + if self.pos is not None: + inertial.set("pos", " ".join(map(str, self.pos))) + if self.inertia is not None: + inertial.set("inertia", " ".join(map(str, self.inertia))) + return inertial @dataclass class Geom: @@ -145,7 +160,6 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: geom.set("density", str(self.density)) return geom - @dataclass class Body: name: str @@ -153,9 +167,7 @@ class Body: quat: tuple[float, float, float, float] | None = field(default=None) geom: Geom | None = field(default=None) joint: Joint | None = field(default=None) - - # TODO - Fix inertia, until then rely on Mujoco's engine - # inertial: Inertial = None + inertial: Inertial | None = field(default=None) # Add inertial property def to_xml(self, root: ET.Element | None = None) -> ET.Element: body = ET.Element("body") if root is None else ET.SubElement(root, "body") @@ -168,9 +180,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: self.joint.to_xml(body) if self.geom is not None: self.geom.to_xml(body) + if self.inertial is not None: + self.inertial.to_xml(body) # Add inertial to the XML return body - @dataclass class Flag: frictionloss: str | None = None diff --git a/kscale/store/cli.py b/kscale/store/cli.py index 174072f..9af1b71 100644 --- a/kscale/store/cli.py +++ b/kscale/store/cli.py @@ -3,7 +3,7 @@ import argparse from typing import Sequence -from kscale.store import pybullet, urdf +from kscale.store import pybullet, urdf, mjcf def main(args: Sequence[str] | None = None) -> None: @@ -12,6 +12,7 @@ def main(args: Sequence[str] | None = None) -> None: "subcommand", choices=[ "urdf", + "mjcf", "pybullet", ], help="The subcommand to run", @@ -21,6 +22,8 @@ def main(args: Sequence[str] | None = None) -> None: match parsed_args.subcommand: case "urdf": urdf.main(remaining_args) + case "mjcf": + mjcf.main(remaining_args) case "pybullet": pybullet.main(remaining_args) diff --git a/kscale/utils.py b/kscale/utils.py index e36c1b3..7f87010 100644 --- a/kscale/utils.py +++ b/kscale/utils.py @@ -43,12 +43,12 @@ def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: mjcf_robot.save(urdf_path.parent / f"{robot_name}.xml") -def mjcf_to_urdf(input_mjcf: str, ) -> None: +def mjcf_to_urdf(input_mjcf: Path) -> None: # Set output_path to the directory of the input_mjcf file - output_path = Path(input_mjcf).parent + output_path = input_mjcf.parent client = bullet_client.BulletClient() - objs: Dict[int, Any] = client.loadMJCF(input_mjcf, flags=client.URDF_USE_IMPLICIT_CYLINDER) + objs: Dict[int, Any] = client.loadMJCF(str(input_mjcf), flags=client.URDF_USE_IMPLICIT_CYLINDER) for obj in objs: humanoid = objs[obj] From 3f5de9ffd9b1097e4e3accc999b00c752ae89812 Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Wed, 21 Aug 2024 20:20:33 +0000 Subject: [PATCH 3/6] mjcf -> urdf fixes --- kscale/store/mjcf.py | 2 +- kscale/utils.py | 40 +++++++++++++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/kscale/store/mjcf.py b/kscale/store/mjcf.py index 8ceb443..e23e348 100644 --- a/kscale/store/mjcf.py +++ b/kscale/store/mjcf.py @@ -181,7 +181,7 @@ def get_listing_dir() -> Path: urdf_path = mjcf_to_urdf(file_path) logger.info(f"Converted MJCF to URDF: {urdf_path}") else: - logger.error("No MJCF files found in %s" % file_path) + logger.error("%s is not an MJCF file" % file_path) sys.exit(1) case _: diff --git a/kscale/utils.py b/kscale/utils.py index 7f87010..622d07d 100644 --- a/kscale/utils.py +++ b/kscale/utils.py @@ -44,18 +44,40 @@ def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: def mjcf_to_urdf(input_mjcf: Path) -> None: - # Set output_path to the directory of the input_mjcf file + """Convert an MJCF file to a single URDF file with all parts combined.""" + + # Set output_path to the directory of the input MJCF file output_path = input_mjcf.parent + # Initialize the Bullet client client = bullet_client.BulletClient() + + # Load the MJCF model objs: Dict[int, Any] = client.loadMJCF(str(input_mjcf), flags=client.URDF_USE_IMPLICIT_CYLINDER) + # Initialize a single URDF editor to store all parts + combined_urdf_editor = urdfEditor.UrdfEditor() + + # Iterate over all objects in the MJCF model for obj in objs: - humanoid = objs[obj] - ue = urdfEditor.UrdfEditor() - ue.initializeFromBulletBody(humanoid, client._client) - robot_name: str = str(client.getBodyInfo(obj)[1], "utf-8") - part_name: str = str(client.getBodyInfo(obj)[0], "utf-8") - save_visuals: bool = False - outpath: str = osp.join(output_path, "{}_{}.urdf".format(robot_name, part_name)) - ue.saveUrdf(outpath, save_visuals) + humanoid = obj # Get the current object + part_urdf_editor = urdfEditor.UrdfEditor() + part_urdf_editor.initializeFromBulletBody(humanoid, client._client) + + # Add all links from the part URDF editor to the combined editor + for link in part_urdf_editor.urdfLinks: + if link not in combined_urdf_editor.urdfLinks: + combined_urdf_editor.urdfLinks.append(link) + + # Add all joints from the part URDF editor to the combined editor + for joint in part_urdf_editor.urdfJoints: + if joint not in combined_urdf_editor.urdfJoints: + combined_urdf_editor.urdfJoints.append(joint) + + # Set the output path for the combined URDF file + combined_urdf_path = osp.join(output_path, "combined_robot.urdf") + + # Save the combined URDF + combined_urdf_editor.saveUrdf(combined_urdf_path) + + print(f"Combined URDF saved to: {combined_urdf_path}") \ No newline at end of file From e49a11c0c480017d63f38f0290a4654391932bd9 Mon Sep 17 00:00:00 2001 From: Benjamin Bolte Date: Wed, 21 Aug 2024 19:12:48 -0700 Subject: [PATCH 4/6] format --- kscale/formats/mjcf.py | 4 ++++ kscale/store/gen/api.py | 4 +++- kscale/store/mjcf.py | 4 ++-- kscale/utils.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/kscale/formats/mjcf.py b/kscale/formats/mjcf.py index a163e44..e847da9 100644 --- a/kscale/formats/mjcf.py +++ b/kscale/formats/mjcf.py @@ -92,6 +92,7 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: joint.set("frictionloss", str(self.frictionloss)) return joint + @dataclass class Inertial: mass: float | None = None @@ -108,6 +109,7 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: inertial.set("inertia", " ".join(map(str, self.inertia))) return inertial + @dataclass class Geom: name: str | None = None @@ -160,6 +162,7 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: geom.set("density", str(self.density)) return geom + @dataclass class Body: name: str @@ -184,6 +187,7 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: self.inertial.to_xml(body) # Add inertial to the XML return body + @dataclass class Flag: frictionloss: str | None = None diff --git a/kscale/store/gen/api.py b/kscale/store/gen/api.py index 514ae39..c013321 100644 --- a/kscale/store/gen/api.py +++ b/kscale/store/gen/api.py @@ -266,10 +266,12 @@ class GetBatchListingsResponse(BaseModel): class HTTPValidationError(BaseModel): detail: Optional[List[ValidationError]] = Field(None, title="Detail") + class MjcfInfo(BaseModel): artifact_id: str = Field(..., title="Artifact Id") url: str = Field(..., title="Url") + class MjcfResponse(BaseModel): mjcf: Optional[MjcfInfo] - listing_id: str = Field(..., title="Listing Id") \ No newline at end of file + listing_id: str = Field(..., title="Listing Id") diff --git a/kscale/store/mjcf.py b/kscale/store/mjcf.py index e23e348..44862bf 100644 --- a/kscale/store/mjcf.py +++ b/kscale/store/mjcf.py @@ -155,9 +155,9 @@ def get_listing_dir() -> Path: parser.add_argument("folder_path", help="The path to the folder containing the MJCF files") parsed_args = parser.parse_args(remaining_args) folder_path = Path(parsed_args.folder_path).expanduser().resolve() - + urdf_or_mjcf = contains_urdf_or_mjcf(folder_path) - if urdf_or_mjcf == 'mjcf': + if urdf_or_mjcf == "mjcf": output_filename = f"{listing_id}.tgz" tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) diff --git a/kscale/utils.py b/kscale/utils.py index 622d07d..9c82d6f 100644 --- a/kscale/utils.py +++ b/kscale/utils.py @@ -80,4 +80,4 @@ def mjcf_to_urdf(input_mjcf: Path) -> None: # Save the combined URDF combined_urdf_editor.saveUrdf(combined_urdf_path) - print(f"Combined URDF saved to: {combined_urdf_path}") \ No newline at end of file + print(f"Combined URDF saved to: {combined_urdf_path}") From 5f44a6afca5f78ec6e8c2b1df56f300003b9065e Mon Sep 17 00:00:00 2001 From: Benjamin Bolte Date: Wed, 21 Aug 2024 19:18:34 -0700 Subject: [PATCH 5/6] fix lint --- .github/workflows/test.yml | 2 +- kscale/formats/__init__.py | 0 kscale/store/cli.py | 2 +- kscale/store/mjcf.py | 2 +- kscale/store/urdf.py | 2 +- kscale/utils.py | 38 +++++++++++++++++++++----------------- setup.py | 9 ++++++++- 7 files changed, 33 insertions(+), 22 deletions(-) create mode 100644 kscale/formats/__init__.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 04aa24d..81134bf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,7 +44,7 @@ jobs: - name: Install package run: | - pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[dev]' + pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[all]' - name: Run static checks run: | diff --git a/kscale/formats/__init__.py b/kscale/formats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kscale/store/cli.py b/kscale/store/cli.py index 9af1b71..ac4e3d0 100644 --- a/kscale/store/cli.py +++ b/kscale/store/cli.py @@ -3,7 +3,7 @@ import argparse from typing import Sequence -from kscale.store import pybullet, urdf, mjcf +from kscale.store import mjcf, pybullet, urdf def main(args: Sequence[str] | None = None) -> None: diff --git a/kscale/store/mjcf.py b/kscale/store/mjcf.py index 44862bf..949cae1 100644 --- a/kscale/store/mjcf.py +++ b/kscale/store/mjcf.py @@ -13,8 +13,8 @@ import requests from kscale.conf import Settings -from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf from kscale.store.gen.api import MjcfResponse +from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf # Set up logging logging.basicConfig(level=logging.INFO) diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index 40b313b..0762733 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -13,8 +13,8 @@ import requests from kscale.conf import Settings -from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf from kscale.store.gen.api import UrdfResponse +from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf # Set up logging logging.basicConfig(level=logging.INFO) diff --git a/kscale/utils.py b/kscale/utils.py index 9c82d6f..bcb6b2a 100644 --- a/kscale/utils.py +++ b/kscale/utils.py @@ -1,15 +1,8 @@ -import os -from pathlib import Path +"""Utility functions for the kscale package.""" -import argparse +import os from pathlib import Path -from typing import Any, Dict, Sequence - - -import os.path as osp - -import pybullet_utils.bullet_client as bullet_client -import pybullet_utils.urdfEditor as urdfEditor +from typing import Any, Dict from kscale.formats import mjcf @@ -26,10 +19,9 @@ def contains_urdf_or_mjcf(folder_path: Path) -> str: if urdf_found: return "urdf" - elif xml_found: + if xml_found: return "mjcf" - else: - return None + raise ValueError("No URDF or MJCF files found in the folder.") def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: @@ -43,8 +35,20 @@ def urdf_to_mjcf(urdf_path: Path, robot_name: str) -> None: mjcf_robot.save(urdf_path.parent / f"{robot_name}.xml") -def mjcf_to_urdf(input_mjcf: Path) -> None: - """Convert an MJCF file to a single URDF file with all parts combined.""" +def mjcf_to_urdf(input_mjcf: Path, name: str = "robot.urdf") -> Path: + """Convert an MJCF file to a single URDF file with all parts combined. + + Args: + input_mjcf: The path to the input MJCF file. + name: The name of the output URDF file. + + Returns: + The path to the output URDF file. + """ + try: + from pybullet_utils import bullet_client, urdfEditor # type: ignore[import-not-found] + except ImportError: + raise ImportError("To use PyBullet, do `pip install 'kscale[pybullet]'`.") # Set output_path to the directory of the input MJCF file output_path = input_mjcf.parent @@ -75,9 +79,9 @@ def mjcf_to_urdf(input_mjcf: Path) -> None: combined_urdf_editor.urdfJoints.append(joint) # Set the output path for the combined URDF file - combined_urdf_path = osp.join(output_path, "combined_robot.urdf") + combined_urdf_path = os.path.join(output_path, name) # Save the combined URDF combined_urdf_editor.saveUrdf(combined_urdf_path) - print(f"Combined URDF saved to: {combined_urdf_path}") + return Path(combined_urdf_path) diff --git a/setup.py b/setup.py index 03d0121..2bead08 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,9 @@ assert version_re is not None, "Could not find version in kscale/__init__.py" version: str = version_re.group(1) +# Additional packages. +requirements_pybullet = ["pybullet"] +requirements_all = requirements + requirements_dev + requirements_pybullet setup( name="kscale", @@ -35,5 +38,9 @@ python_requires=">=3.11", install_requires=requirements, tests_require=requirements_dev, - extras_require={"dev": requirements_dev}, + extras_require={ + "dev": requirements_dev, + "pybullet": requirements_pybullet, + "all": requirements_all, + }, ) From 95d89be83dd6ab61c94fd47bce987ff4a364e8e7 Mon Sep 17 00:00:00 2001 From: Benjamin Bolte Date: Wed, 21 Aug 2024 19:19:34 -0700 Subject: [PATCH 6/6] update generate --- kscale/store/gen/api.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/kscale/store/gen/api.py b/kscale/store/gen/api.py index c013321..79ad9d5 100644 --- a/kscale/store/gen/api.py +++ b/kscale/store/gen/api.py @@ -2,7 +2,7 @@ # generated by datamodel-codegen: # filename: openapi.json -# timestamp: 2024-08-19T06:07:36+00:00 +# timestamp: 2024-08-22T02:19:26+00:00 from __future__ import annotations @@ -265,13 +265,3 @@ class GetBatchListingsResponse(BaseModel): class HTTPValidationError(BaseModel): detail: Optional[List[ValidationError]] = Field(None, title="Detail") - - -class MjcfInfo(BaseModel): - artifact_id: str = Field(..., title="Artifact Id") - url: str = Field(..., title="Url") - - -class MjcfResponse(BaseModel): - mjcf: Optional[MjcfInfo] - listing_id: str = Field(..., title="Listing Id")