Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

File conversion #4

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:

- name: Install package
run: |
pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[dev]'
pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[all]'

- name: Run static checks
run: |
Expand Down
Empty file added kscale/formats/__init__.py
Empty file.
28 changes: 24 additions & 4 deletions kscale/formats/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,23 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:
return joint


@dataclass
class Inertial:
mass: float | None = None
pos: tuple[float, float, float] | None = None
inertia: tuple[float, float, float] | None = None

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
inertial = ET.Element("inertial") if root is None else ET.SubElement(root, "inertial")
if self.mass is not None:
inertial.set("mass", str(self.mass))
if self.pos is not None:
inertial.set("pos", " ".join(map(str, self.pos)))
if self.inertia is not None:
inertial.set("inertia", " ".join(map(str, self.inertia)))
return inertial


@dataclass
class Geom:
name: str | None = None
Expand Down Expand Up @@ -153,9 +170,7 @@ class Body:
quat: tuple[float, float, float, float] | None = field(default=None)
geom: Geom | None = field(default=None)
joint: Joint | None = field(default=None)

# TODO - Fix inertia, until then rely on Mujoco's engine
# inertial: Inertial = None
inertial: Inertial | None = field(default=None) # Add inertial property

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
body = ET.Element("body") if root is None else ET.SubElement(root, "body")
Expand All @@ -168,6 +183,8 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:
self.joint.to_xml(body)
if self.geom is not None:
self.geom.to_xml(body)
if self.inertial is not None:
self.inertial.to_xml(body) # Add inertial to the XML
return body


Expand Down Expand Up @@ -408,7 +425,10 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element:

def _copy_stl_files(source_directory: str | Path, destination_directory: str | Path) -> None:
# Ensure the destination directory exists, create if not
os.makedirs(destination_directory, exist_ok=True)
if not os.path.exists(destination_directory):
os.makedirs(destination_directory, exist_ok=True)
elif not os.path.isdir(destination_directory):
raise FileExistsError(f"Destination path exists and is not a directory: {destination_directory}")

# Use glob to find all .stl files in the source directory
pattern = os.path.join(source_directory, "*.stl")
Expand Down
5 changes: 4 additions & 1 deletion kscale/store/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
from typing import Sequence

from kscale.store import pybullet, urdf
from kscale.store import mjcf, pybullet, urdf


def main(args: Sequence[str] | None = None) -> None:
Expand All @@ -12,6 +12,7 @@ def main(args: Sequence[str] | None = None) -> None:
"subcommand",
choices=[
"urdf",
"mjcf",
"pybullet",
],
help="The subcommand to run",
Expand All @@ -21,6 +22,8 @@ def main(args: Sequence[str] | None = None) -> None:
match parsed_args.subcommand:
case "urdf":
urdf.main(remaining_args)
case "mjcf":
mjcf.main(remaining_args)
case "pybullet":
pybullet.main(remaining_args)

Expand Down
2 changes: 1 addition & 1 deletion kscale/store/gen/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# generated by datamodel-codegen:
# filename: openapi.json
# timestamp: 2024-08-19T06:07:36+00:00
# timestamp: 2024-08-22T02:19:26+00:00

from __future__ import annotations

Expand Down
193 changes: 193 additions & 0 deletions kscale/store/mjcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Utility functions for managing artifacts in the K-Scale store with MJCF support."""

import argparse
import asyncio
import logging
import os
import sys
import tarfile
from pathlib import Path
from typing import Literal, Sequence

import httpx
import requests

from kscale.conf import Settings
from kscale.store.gen.api import MjcfResponse
from kscale.utils import contains_urdf_or_mjcf, mjcf_to_urdf

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_api_key() -> str:
api_key = Settings.load().store.api_key
if not api_key:
raise ValueError(
"API key not found! Get one here and set it as the `KSCALE_API_KEY` environment variable or in your"
"config file: https://kscale.store/keys"
)
return api_key


def get_cache_dir() -> Path:
return Path(Settings.load().store.cache_dir).expanduser().resolve()


def fetch_mjcf_info(listing_id: str) -> MjcfResponse:
url = f"https://api.kscale.store/mjcf/info/{listing_id}"
headers = {
"Authorization": f"Bearer {get_api_key()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return MjcfResponse(**response.json())


async def download_artifact(artifact_url: str, cache_dir: Path) -> str:
filename = os.path.join(cache_dir, artifact_url.split("/")[-1])
headers = {
"Authorization": f"Bearer {get_api_key()}",
}

if not os.path.exists(filename):
logger.info("Downloading artifact from %s" % artifact_url)

async with httpx.AsyncClient() as client:
response = await client.get(artifact_url, headers=headers)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_bytes(chunk_size=8192):
f.write(chunk)
logger.info("Artifact downloaded to %s" % filename)
else:
logger.info("Artifact already cached at %s" % filename)

# Extract the .tgz file
extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0])
if not os.path.exists(extract_dir):
logger.info(f"Extracting {filename} to {extract_dir}")
with tarfile.open(filename, "r:gz") as tar:
tar.extractall(path=extract_dir)
logger.info("Extraction complete")
else:
logger.info("Artifact already extracted at %s" % extract_dir)

return extract_dir


def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Path) -> str:
tarball_path = os.path.join(cache_dir, output_filename)
with tarfile.open(tarball_path, "w:gz") as tar:
for root, _, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, start=folder_path)
tar.add(file_path, arcname=arcname)
logger.info("Added %s as %s" % (file_path, arcname))
logger.info("Created tarball %s" % tarball_path)
return tarball_path


async def upload_artifact(tarball_path: str, listing_id: str) -> None:
url = f"https://api.kscale.store/mjcf/upload/{listing_id}"
headers = {
"Authorization": f"Bearer {get_api_key()}",
}

async with httpx.AsyncClient() as client:
with open(tarball_path, "rb") as f:
files = {"file": (f.name, f, "application/gzip")}
response = await client.post(url, headers=headers, files=files)

response.raise_for_status()

logger.info("Uploaded artifact to %s" % url)


def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="K-Scale MJCF Store", add_help=False)
parser.add_argument(
"command",
choices=["get", "info", "upload", "convert"],
help="The command to run",
)
parser.add_argument("listing_id", help="The listing ID to operate on")
parsed_args, remaining_args = parser.parse_known_args(args)

command: Literal["get", "info", "upload", "convert"] = parsed_args.command
listing_id: str = parsed_args.listing_id

def get_listing_dir() -> Path:
(cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True)
return cache_dir

match command:
case "get":
try:
mjcf_info = fetch_mjcf_info(listing_id)

if mjcf_info.mjcf:
artifact_url = mjcf_info.mjcf.url
asyncio.run(download_artifact(artifact_url, get_listing_dir()))
else:
logger.info("No MJCF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch MJCF info: %s" % e)
sys.exit(1)

case "info":
try:
mjcf_info = fetch_mjcf_info(listing_id)

if mjcf_info.mjcf:
logger.info("MJCF Artifact ID: %s" % mjcf_info.mjcf.artifact_id)
logger.info("MJCF URL: %s" % mjcf_info.mjcf.url)
else:
logger.info("No MJCF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch MJCF info: %s" % e)
sys.exit(1)

case "upload":
parser = argparse.ArgumentParser(description="Upload an MJCF artifact to the K-Scale store")
parser.add_argument("folder_path", help="The path to the folder containing the MJCF files")
parsed_args = parser.parse_args(remaining_args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

urdf_or_mjcf = contains_urdf_or_mjcf(folder_path)
if urdf_or_mjcf == "mjcf":
output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
mjcf_info = fetch_mjcf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)
else:
logger.error("No MJCF files found in %s" % folder_path)
sys.exit(1)

case "convert":
parser = argparse.ArgumentParser(description="Convert an MJCF to a URDF file")
parser.add_argument("file_path", help="The path of the MJCF file")
parsed_args = parser.parse_args(remaining_args)
file_path = Path(parsed_args.file_path).expanduser().resolve()

if file_path.suffix == ".xml":
urdf_path = mjcf_to_urdf(file_path)
logger.info(f"Converted MJCF to URDF: {urdf_path}")
else:
logger.error("%s is not an MJCF file" % file_path)
sys.exit(1)

case _:
logger.error("Invalid command")
sys.exit(1)


if __name__ == "__main__":
main()
40 changes: 30 additions & 10 deletions kscale/store/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from kscale.conf import Settings
from kscale.store.gen.api import UrdfResponse
from kscale.utils import contains_urdf_or_mjcf, urdf_to_mjcf

# Set up logging
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -109,13 +110,13 @@ def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False)
parser.add_argument(
"command",
choices=["get", "info", "upload"],
choices=["get", "info", "upload", "convert"],
help="The command to run",
)
parser.add_argument("listing_id", help="The listing ID to operate on")
parsed_args, remaining_args = parser.parse_known_args(args)

command: Literal["get", "info", "upload"] = parsed_args.command
command: Literal["get", "info", "upload", "convert"] = parsed_args.command
listing_id: str = parsed_args.listing_id

def get_listing_dir() -> Path:
Expand Down Expand Up @@ -155,16 +156,35 @@ def get_listing_dir() -> Path:
parsed_args = parser.parse_args(remaining_args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
urdf_or_mjcf = contains_urdf_or_mjcf(folder_path)
if urdf_or_mjcf == "urdf":
output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)
else:
logger.error("No URDF files found in %s" % folder_path)
sys.exit(1)

case "convert":
parser = argparse.ArgumentParser(description="Convert a URDF to an MJCF file")
parser.add_argument("file_path", help="The path of the URDF file")
parsed_args = parser.parse_args(remaining_args)

# NOTE: folder path actually
file_path = Path(parsed_args.file_path).expanduser().resolve()

if file_path.suffix == ".urdf":
urdf_to_mjcf(file_path.parent, file_path.stem)
logger.info("Converted URDF to MJCF")
else:
logger.error("%s is not a URDF file" % file_path)
sys.exit(1)
case _:
logger.error("Invalid command")
sys.exit(1)
Expand Down
Loading
Loading