Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mjcf2urdf urdf2mjcf #3

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions kscale/formats/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint.set("frictionloss", str(self.frictionloss))
return joint

@dataclass
class Inertial:
mass: float | None = None
pos: tuple[float, float, float] | None = None
inertia: tuple[float, float, float] | None = None

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
inertial = ET.Element("inertial") if root is None else ET.SubElement(root, "inertial")
if self.mass is not None:
inertial.set("mass", str(self.mass))
if self.pos is not None:
inertial.set("pos", " ".join(map(str, self.pos)))
if self.inertia is not None:
inertial.set("inertia", " ".join(map(str, self.inertia)))
return inertial

@dataclass
class Geom:
Expand Down Expand Up @@ -145,17 +160,14 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:
geom.set("density", str(self.density))
return geom


@dataclass
class Body:
name: str
pos: tuple[float, float, float] | None = field(default=None)
quat: tuple[float, float, float, float] | None = field(default=None)
geom: Geom | None = field(default=None)
joint: Joint | None = field(default=None)

# TODO - Fix inertia, until then rely on Mujoco's engine
# inertial: Inertial = None
inertial: Inertial | None = field(default=None) # Add inertial property

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
body = ET.Element("body") if root is None else ET.SubElement(root, "body")
Expand All @@ -168,9 +180,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:
self.joint.to_xml(body)
if self.geom is not None:
self.geom.to_xml(body)
if self.inertial is not None:
self.inertial.to_xml(body) # Add inertial to the XML
return body


@dataclass
class Flag:
frictionloss: str | None = None
Expand Down Expand Up @@ -408,7 +421,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:

def _copy_stl_files(source_directory: str | Path, destination_directory: str | Path) -> None:
# Ensure the destination directory exists, create if not
os.makedirs(destination_directory, exist_ok=True)
if not os.path.exists(destination_directory):
os.makedirs(destination_directory, exist_ok=True)
elif not os.path.isdir(destination_directory):
raise FileExistsError(f"Destination path exists and is not a directory: {destination_directory}")

# Use glob to find all .stl files in the source directory
pattern = os.path.join(source_directory, "*.stl")
Expand Down
5 changes: 4 additions & 1 deletion kscale/store/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
from typing import Sequence

from kscale.store import pybullet, urdf
from kscale.store import pybullet, urdf, mjcf


def main(args: Sequence[str] | None = None) -> None:
Expand All @@ -12,6 +12,7 @@ def main(args: Sequence[str] | None = None) -> None:
"subcommand",
choices=[
"urdf",
"mjcf",
"pybullet",
],
help="The subcommand to run",
Expand All @@ -21,6 +22,8 @@ def main(args: Sequence[str] | None = None) -> None:
match parsed_args.subcommand:
case "urdf":
urdf.main(remaining_args)
case "mjcf":
mjcf.main(remaining_args)
case "pybullet":
pybullet.main(remaining_args)

Expand Down
8 changes: 8 additions & 0 deletions kscale/store/gen/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,11 @@ class GetBatchListingsResponse(BaseModel):

class HTTPValidationError(BaseModel):
detail: Optional[List[ValidationError]] = Field(None, title="Detail")

class MjcfInfo(BaseModel):
artifact_id: str = Field(..., title="Artifact Id")
url: str = Field(..., title="Url")

class MjcfResponse(BaseModel):
mjcf: Optional[MjcfInfo]
listing_id: str = Field(..., title="Listing Id")
193 changes: 193 additions & 0 deletions kscale/store/mjcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Utility functions for managing artifacts in the K-Scale store with MJCF support."""

import argparse
import asyncio
import logging
import os
import sys
import tarfile
from pathlib import Path
from typing import Literal, Sequence

import httpx
import requests

from kscale.conf import Settings
from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf
from kscale.store.gen.api import MjcfResponse

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_api_key() -> str:
api_key = Settings.load().store.api_key
if not api_key:
raise ValueError(
"API key not found! Get one here and set it as the `KSCALE_API_KEY` environment variable or in your"
"config file: https://kscale.store/keys"
)
return api_key


def get_cache_dir() -> Path:
return Path(Settings.load().store.cache_dir).expanduser().resolve()


def fetch_mjcf_info(listing_id: str) -> MjcfResponse:
url = f"https://api.kscale.store/mjcf/info/{listing_id}"
headers = {
"Authorization": f"Bearer {get_api_key()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return MjcfResponse(**response.json())


async def download_artifact(artifact_url: str, cache_dir: Path) -> str:
filename = os.path.join(cache_dir, artifact_url.split("/")[-1])
headers = {
"Authorization": f"Bearer {get_api_key()}",
}

if not os.path.exists(filename):
logger.info("Downloading artifact from %s" % artifact_url)

async with httpx.AsyncClient() as client:
response = await client.get(artifact_url, headers=headers)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_bytes(chunk_size=8192):
f.write(chunk)
logger.info("Artifact downloaded to %s" % filename)
else:
logger.info("Artifact already cached at %s" % filename)

# Extract the .tgz file
extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0])
if not os.path.exists(extract_dir):
logger.info(f"Extracting {filename} to {extract_dir}")
with tarfile.open(filename, "r:gz") as tar:
tar.extractall(path=extract_dir)
logger.info("Extraction complete")
else:
logger.info("Artifact already extracted at %s" % extract_dir)

return extract_dir


def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Path) -> str:
tarball_path = os.path.join(cache_dir, output_filename)
with tarfile.open(tarball_path, "w:gz") as tar:
for root, _, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, start=folder_path)
tar.add(file_path, arcname=arcname)
logger.info("Added %s as %s" % (file_path, arcname))
logger.info("Created tarball %s" % tarball_path)
return tarball_path


async def upload_artifact(tarball_path: str, listing_id: str) -> None:
url = f"https://api.kscale.store/mjcf/upload/{listing_id}"
headers = {
"Authorization": f"Bearer {get_api_key()}",
}

async with httpx.AsyncClient() as client:
with open(tarball_path, "rb") as f:
files = {"file": (f.name, f, "application/gzip")}
response = await client.post(url, headers=headers, files=files)

response.raise_for_status()

logger.info("Uploaded artifact to %s" % url)


def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="K-Scale MJCF Store", add_help=False)
parser.add_argument(
"command",
choices=["get", "info", "upload", "convert"],
help="The command to run",
)
parser.add_argument("listing_id", help="The listing ID to operate on")
parsed_args, remaining_args = parser.parse_known_args(args)

command: Literal["get", "info", "upload", "convert"] = parsed_args.command
listing_id: str = parsed_args.listing_id

def get_listing_dir() -> Path:
(cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True)
return cache_dir

match command:
case "get":
try:
mjcf_info = fetch_mjcf_info(listing_id)

if mjcf_info.mjcf:
artifact_url = mjcf_info.mjcf.url
asyncio.run(download_artifact(artifact_url, get_listing_dir()))
else:
logger.info("No MJCF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch MJCF info: %s" % e)
sys.exit(1)

case "info":
try:
mjcf_info = fetch_mjcf_info(listing_id)

if mjcf_info.mjcf:
logger.info("MJCF Artifact ID: %s" % mjcf_info.mjcf.artifact_id)
logger.info("MJCF URL: %s" % mjcf_info.mjcf.url)
else:
logger.info("No MJCF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch MJCF info: %s" % e)
sys.exit(1)

case "upload":
parser = argparse.ArgumentParser(description="Upload an MJCF artifact to the K-Scale store")
parser.add_argument("folder_path", help="The path to the folder containing the MJCF files")
parsed_args = parser.parse_args(remaining_args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

urdf_or_mjcf = contains_urdf_or_mjcf(folder_path)
if urdf_or_mjcf == 'mjcf':
output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
mjcf_info = fetch_mjcf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)
else:
logger.error("No MJCF files found in %s" % folder_path)
sys.exit(1)

case "convert":
parser = argparse.ArgumentParser(description="Convert an MJCF to a URDF file")
parser.add_argument("file_path", help="The path of the MJCF file")
parsed_args = parser.parse_args(remaining_args)
file_path = Path(parsed_args.file_path).expanduser().resolve()

if file_path.suffix == ".xml":
urdf_path = mjcf_to_urdf(file_path)
logger.info(f"Converted MJCF to URDF: {urdf_path}")
else:
logger.error("No MJCF files found in %s" % file_path)
sys.exit(1)

case _:
logger.error("Invalid command")
sys.exit(1)


if __name__ == "__main__":
main()
40 changes: 30 additions & 10 deletions kscale/store/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import requests

from kscale.conf import Settings
from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf
from kscale.store.gen.api import UrdfResponse

# Set up logging
Expand Down Expand Up @@ -109,13 +110,13 @@ def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False)
parser.add_argument(
"command",
choices=["get", "info", "upload"],
choices=["get", "info", "upload", "convert"],
help="The command to run",
)
parser.add_argument("listing_id", help="The listing ID to operate on")
parsed_args, remaining_args = parser.parse_known_args(args)

command: Literal["get", "info", "upload"] = parsed_args.command
command: Literal["get", "info", "upload", "convert"] = parsed_args.command
listing_id: str = parsed_args.listing_id

def get_listing_dir() -> Path:
Expand Down Expand Up @@ -155,16 +156,35 @@ def get_listing_dir() -> Path:
parsed_args = parser.parse_args(remaining_args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
urdf_or_mjcf = contains_urdf_or_mjcf(folder_path)
if urdf_or_mjcf == "urdf":
output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)
else:
logger.error("No URDF files found in %s" % folder_path)
sys.exit(1)

case "convert":
parser = argparse.ArgumentParser(description="Convert a URDF to an MJCF file")
parser.add_argument("file_path", help="The path of the URDF file")
parsed_args = parser.parse_args(remaining_args)

# NOTE: folder path actually
file_path = Path(parsed_args.file_path).expanduser().resolve()

if file_path.suffix == ".urdf":
urdf_to_mjcf(file_path.parent, file_path.stem)
logger.info("Converted URDF to MJCF")
else:
logger.error("%s is not a URDF file" % file_path)
sys.exit(1)
case _:
logger.error("Invalid command")
sys.exit(1)
Expand Down
Loading
Loading