diff --git a/kscale/store/cli.py b/kscale/store/cli.py index 174072f..708d2f3 100644 --- a/kscale/store/cli.py +++ b/kscale/store/cli.py @@ -1,12 +1,13 @@ """Defines the top-level KOL CLI.""" import argparse +import asyncio from typing import Sequence from kscale.store import pybullet, urdf -def main(args: Sequence[str] | None = None) -> None: +async def main(args: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser(description="K-Scale OnShape Library", add_help=False) parser.add_argument( "subcommand", @@ -20,11 +21,15 @@ def main(args: Sequence[str] | None = None) -> None: match parsed_args.subcommand: case "urdf": - urdf.main(remaining_args) + await urdf.main(remaining_args) case "pybullet": - pybullet.main(remaining_args) + await pybullet.main(remaining_args) + + +def sync_main(args: Sequence[str] | None = None) -> None: + asyncio.run(main(args)) if __name__ == "__main__": # python3 -m kscale.store.cli - main() + sync_main() diff --git a/kscale/store/pybullet.py b/kscale/store/pybullet.py index 86e77a9..142b90b 100644 --- a/kscale/store/pybullet.py +++ b/kscale/store/pybullet.py @@ -1,6 +1,7 @@ """Simple script to interact with a URDF in PyBullet.""" import argparse +import asyncio import itertools import logging import math @@ -8,12 +9,14 @@ from pathlib import Path from typing import Sequence +from kscale.store.urdf import download_urdf + logger = logging.getLogger(__name__) -def main(args: Sequence[str] | None = None) -> None: +async def main(args: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser(description="Show a URDF in PyBullet") - parser.add_argument("urdf", nargs="?", help="Path to the URDF file") + parser.add_argument("listing_id", help="Listing ID for the URDF") parser.add_argument("--dt", type=float, default=0.01, help="Time step") parser.add_argument("-n", "--hide-gui", action="store_true", help="Hide the GUI") parser.add_argument("--no-merge", action="store_true", help="Do not merge fixed links") @@ -23,6 +26,11 @@ def main(args: Sequence[str] | None = None) -> None: parser.add_argument("--show-collision", action="store_true", help="Show collision meshes") parsed_args = parser.parse_args(args) + # Gets the URDF path. + urdf_path = await download_urdf(parsed_args.listing_id) + + breakpoint() + try: import pybullet as p # type: ignore[import-not-found] except ImportError: @@ -46,13 +54,6 @@ def main(args: Sequence[str] | None = None) -> None: # Loads the floor plane. floor = p.loadURDF(str((Path(__file__).parent / "bullet" / "plane.urdf").resolve())) - urdf_path = Path("robot" if parsed_args.urdf is None else parsed_args.urdf) - if urdf_path.is_dir(): - try: - urdf_path = next(urdf_path.glob("*.urdf")) - except StopIteration: - raise FileNotFoundError(f"No URDF files found in {urdf_path}") - # Load the robot URDF. start_position = [0.0, 0.0, 1.0] start_orientation = p.getQuaternionFromEuler([0.0, 0.0, 0.0]) @@ -175,4 +176,4 @@ def draw_box(pt: list[list[float]], color: tuple[float, float, float], obj_id: i if __name__ == "__main__": # python -m kscale.store.pybullet - main() + asyncio.run(main()) diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index 84d537b..3972f52 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -4,10 +4,11 @@ import asyncio import logging import os +import shutil import sys import tarfile from pathlib import Path -from typing import Literal, Sequence +from typing import Literal, Sequence, get_args import httpx import requests @@ -34,6 +35,11 @@ def get_cache_dir() -> Path: return Path(Settings.load().store.cache_dir).expanduser().resolve() +def get_listing_dir(listing_id: str) -> Path: + (cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True) + return cache_dir + + def fetch_urdf_info(listing_id: str) -> UrdfResponse: url = f"https://api.kscale.store/urdf/info/{listing_id}" headers = { @@ -44,14 +50,14 @@ def fetch_urdf_info(listing_id: str) -> UrdfResponse: return UrdfResponse(**response.json()) -async def download_artifact(artifact_url: str, cache_dir: Path) -> str: +async def download_artifact(artifact_url: str, cache_dir: Path) -> Path: filename = os.path.join(cache_dir, artifact_url.split("/")[-1]) headers = { "Authorization": f"Bearer {get_api_key()}", } if not os.path.exists(filename): - logger.info("Downloading artifact from %s" % artifact_url) + logger.info("Downloading artifact from %s", artifact_url) async with httpx.AsyncClient() as client: response = await client.get(artifact_url, headers=headers) @@ -59,19 +65,18 @@ async def download_artifact(artifact_url: str, cache_dir: Path) -> str: with open(filename, "wb") as f: for chunk in response.iter_bytes(chunk_size=8192): f.write(chunk) - logger.info("Artifact downloaded to %s" % filename) + logger.info("Artifact downloaded to %s", filename) else: - logger.info("Artifact already cached at %s" % filename) + logger.info("Artifact already cached at %s", filename) # Extract the .tgz file - extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0]) - if not os.path.exists(extract_dir): - logger.info(f"Extracting {filename} to {extract_dir}") + extract_dir = cache_dir / os.path.splitext(os.path.basename(filename))[0] + if not extract_dir.exists(): + logger.info("Extracting %s to %s", filename, extract_dir) with tarfile.open(filename, "r:gz") as tar: tar.extractall(path=extract_dir) - logger.info("Extraction complete") else: - logger.info("Artifact already extracted at %s" % extract_dir) + logger.info("Artifact already extracted at %s", extract_dir) return extract_dir @@ -84,8 +89,8 @@ def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Pat file_path = os.path.join(root, file) arcname = os.path.relpath(file_path, start=folder_path) tar.add(file_path, arcname=arcname) - logger.info("Added %s as %s" % (file_path, arcname)) - logger.info("Created tarball %s" % tarball_path) + logger.info("Added %s as %s", file_path, arcname) + logger.info("Created tarball %s", tarball_path) return tarball_path @@ -102,68 +107,102 @@ async def upload_artifact(tarball_path: str, listing_id: str) -> None: response.raise_for_status() - logger.info("Uploaded artifact to %s" % url) + logger.info("Uploaded artifact to %s", url) + + +async def download_urdf(listing_id: str) -> Path: + try: + urdf_info = fetch_urdf_info(listing_id) + + if urdf_info.urdf is None: + breakpoint() + raise ValueError(f"No URDF found for listing {listing_id}") + + artifact_url = urdf_info.urdf.url + return await download_artifact(artifact_url, get_listing_dir(listing_id)) + + except requests.RequestException: + logger.exception("Failed to fetch URDF info") + raise + + +async def show_urdf_info(listing_id: str) -> None: + try: + urdf_info = fetch_urdf_info(listing_id) + + if urdf_info.urdf: + logger.info("URDF Artifact ID: %s", urdf_info.urdf.artifact_id) + logger.info("URDF URL: %s", urdf_info.urdf.url) + else: + logger.info("No URDF found for listing %s", listing_id) + except requests.RequestException: + logger.exception("Failed to fetch URDF info") + raise + +async def upload_urdf(listing_id: str, args: Sequence[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="Upload a URDF artifact to the K-Scale store") + parser.add_argument("folder_path", help="The path to the folder containing the URDF files") + parsed_args = parser.parse_args(args) + folder_path = Path(parsed_args.folder_path).expanduser().resolve() -def main(args: Sequence[str] | None = None) -> None: + output_filename = f"{listing_id}.tgz" + tarball_path = create_tarball(folder_path, output_filename, get_listing_dir(listing_id)) + + try: + fetch_urdf_info(listing_id) + await upload_artifact(tarball_path, listing_id) + except requests.RequestException: + logger.exception("Failed to upload artifact") + raise + + +async def remove_local_urdf(listing_id: str) -> None: + try: + if listing_id.lower() == "all": + cache_dir = get_cache_dir() + if cache_dir.exists(): + logger.info("Removing all local caches at %s", cache_dir) + shutil.rmtree(cache_dir) + else: + logger.error("No local caches found") + else: + listing_dir = get_listing_dir(listing_id) + if listing_dir.exists(): + logger.info("Removing local cache at %s", listing_dir) + shutil.rmtree(listing_dir) + else: + logger.error("No local cache found for listing %s", listing_id) + + except Exception: + logger.error("Failed to remove local cache") + raise + + +Command = Literal["download", "info", "upload", "remove-local"] + + +async def main(args: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False) - parser.add_argument( - "command", - choices=["get", "info", "upload"], - help="The command to run", - ) + parser.add_argument("command", choices=get_args(Command), help="The command to run") parser.add_argument("listing_id", help="The listing ID to operate on") parsed_args, remaining_args = parser.parse_known_args(args) - command: Literal["get", "info", "upload"] = parsed_args.command + command: Command = parsed_args.command listing_id: str = parsed_args.listing_id - def get_listing_dir() -> Path: - (cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True) - return cache_dir - match command: - case "get": - try: - urdf_info = fetch_urdf_info(listing_id) - - if urdf_info.urdf: - artifact_url = urdf_info.urdf.url - asyncio.run(download_artifact(artifact_url, get_listing_dir())) - else: - logger.info("No URDF found for listing %s" % listing_id) - except requests.RequestException as e: - logger.error("Failed to fetch URDF info: %s" % e) - sys.exit(1) + case "download": + await download_urdf(listing_id) case "info": - try: - urdf_info = fetch_urdf_info(listing_id) - - if urdf_info.urdf: - logger.info("URDF Artifact ID: %s" % urdf_info.urdf.artifact_id) - logger.info("URDF URL: %s" % urdf_info.urdf.url) - else: - logger.info("No URDF found for listing %s" % listing_id) - except requests.RequestException as e: - logger.error("Failed to fetch URDF info: %s" % e) - sys.exit(1) + await show_urdf_info(listing_id) case "upload": - parser = argparse.ArgumentParser(description="Upload a URDF artifact to the K-Scale store") - parser.add_argument("folder_path", help="The path to the folder containing the URDF files") - parsed_args = parser.parse_args(remaining_args) - folder_path = Path(parsed_args.folder_path).expanduser().resolve() - - output_filename = f"{listing_id}.tgz" - tarball_path = create_tarball(folder_path, output_filename, get_listing_dir()) + await upload_urdf(listing_id, remaining_args) - try: - urdf_info = fetch_urdf_info(listing_id) - asyncio.run(upload_artifact(tarball_path, listing_id)) - except requests.RequestException as e: - logger.error("Failed to upload artifact: %s" % e) - sys.exit(1) + case "remove-local": + await remove_local_urdf(listing_id) case _: logger.error("Invalid command") @@ -171,4 +210,4 @@ def get_listing_dir() -> Path: if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/setup.cfg b/setup.cfg index 13ecb16..3f34515 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,4 +10,4 @@ exclude = [options.entry_points] console_scripts = - kscale = kscale.store.cli:main \ No newline at end of file + kscale = kscale.store.cli:sync_main