Skip to content

Commit

Permalink
make cli more async
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Aug 23, 2024
1 parent 1f6eb8a commit 8f4a848
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 76 deletions.
13 changes: 9 additions & 4 deletions kscale/store/cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Defines the top-level KOL CLI."""

import argparse
import asyncio
from typing import Sequence

from kscale.store import pybullet, urdf


def main(args: Sequence[str] | None = None) -> None:
async def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="K-Scale OnShape Library", add_help=False)
parser.add_argument(
"subcommand",
Expand All @@ -20,11 +21,15 @@ def main(args: Sequence[str] | None = None) -> None:

match parsed_args.subcommand:
case "urdf":
urdf.main(remaining_args)
await urdf.main(remaining_args)
case "pybullet":
pybullet.main(remaining_args)
await pybullet.main(remaining_args)


def sync_main(args: Sequence[str] | None = None) -> None:
asyncio.run(main(args))


if __name__ == "__main__":
# python3 -m kscale.store.cli
main()
sync_main()
21 changes: 11 additions & 10 deletions kscale/store/pybullet.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
"""Simple script to interact with a URDF in PyBullet."""

import argparse
import asyncio
import itertools
import logging
import math
import time
from pathlib import Path
from typing import Sequence

from kscale.store.urdf import download_urdf

logger = logging.getLogger(__name__)


def main(args: Sequence[str] | None = None) -> None:
async def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Show a URDF in PyBullet")
parser.add_argument("urdf", nargs="?", help="Path to the URDF file")
parser.add_argument("listing_id", help="Listing ID for the URDF")
parser.add_argument("--dt", type=float, default=0.01, help="Time step")
parser.add_argument("-n", "--hide-gui", action="store_true", help="Hide the GUI")
parser.add_argument("--no-merge", action="store_true", help="Do not merge fixed links")
Expand All @@ -23,6 +26,11 @@ def main(args: Sequence[str] | None = None) -> None:
parser.add_argument("--show-collision", action="store_true", help="Show collision meshes")
parsed_args = parser.parse_args(args)

# Gets the URDF path.
urdf_path = await download_urdf(parsed_args.listing_id)

breakpoint()

try:
import pybullet as p # type: ignore[import-not-found]
except ImportError:
Expand All @@ -46,13 +54,6 @@ def main(args: Sequence[str] | None = None) -> None:
# Loads the floor plane.
floor = p.loadURDF(str((Path(__file__).parent / "bullet" / "plane.urdf").resolve()))

urdf_path = Path("robot" if parsed_args.urdf is None else parsed_args.urdf)
if urdf_path.is_dir():
try:
urdf_path = next(urdf_path.glob("*.urdf"))
except StopIteration:
raise FileNotFoundError(f"No URDF files found in {urdf_path}")

# Load the robot URDF.
start_position = [0.0, 0.0, 1.0]
start_orientation = p.getQuaternionFromEuler([0.0, 0.0, 0.0])
Expand Down Expand Up @@ -175,4 +176,4 @@ def draw_box(pt: list[list[float]], color: tuple[float, float, float], obj_id: i

if __name__ == "__main__":
# python -m kscale.store.pybullet
main()
asyncio.run(main())
161 changes: 100 additions & 61 deletions kscale/store/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import asyncio
import logging
import os
import shutil
import sys
import tarfile
from pathlib import Path
from typing import Literal, Sequence
from typing import Literal, Sequence, get_args

import httpx
import requests
Expand All @@ -34,6 +35,11 @@ def get_cache_dir() -> Path:
return Path(Settings.load().store.cache_dir).expanduser().resolve()


def get_listing_dir(listing_id: str) -> Path:
(cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True)
return cache_dir


def fetch_urdf_info(listing_id: str) -> UrdfResponse:
url = f"https://api.kscale.store/urdf/info/{listing_id}"
headers = {
Expand All @@ -44,34 +50,33 @@ def fetch_urdf_info(listing_id: str) -> UrdfResponse:
return UrdfResponse(**response.json())


async def download_artifact(artifact_url: str, cache_dir: Path) -> str:
async def download_artifact(artifact_url: str, cache_dir: Path) -> Path:
filename = os.path.join(cache_dir, artifact_url.split("/")[-1])
headers = {
"Authorization": f"Bearer {get_api_key()}",
}

if not os.path.exists(filename):
logger.info("Downloading artifact from %s" % artifact_url)
logger.info("Downloading artifact from %s", artifact_url)

async with httpx.AsyncClient() as client:
response = await client.get(artifact_url, headers=headers)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_bytes(chunk_size=8192):
f.write(chunk)
logger.info("Artifact downloaded to %s" % filename)
logger.info("Artifact downloaded to %s", filename)
else:
logger.info("Artifact already cached at %s" % filename)
logger.info("Artifact already cached at %s", filename)

# Extract the .tgz file
extract_dir = os.path.join(cache_dir, os.path.splitext(os.path.basename(filename))[0])
if not os.path.exists(extract_dir):
logger.info(f"Extracting {filename} to {extract_dir}")
extract_dir = cache_dir / os.path.splitext(os.path.basename(filename))[0]
if not extract_dir.exists():
logger.info("Extracting %s to %s", filename, extract_dir)
with tarfile.open(filename, "r:gz") as tar:
tar.extractall(path=extract_dir)
logger.info("Extraction complete")
else:
logger.info("Artifact already extracted at %s" % extract_dir)
logger.info("Artifact already extracted at %s", extract_dir)

return extract_dir

Expand All @@ -84,8 +89,8 @@ def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Pat
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, start=folder_path)
tar.add(file_path, arcname=arcname)
logger.info("Added %s as %s" % (file_path, arcname))
logger.info("Created tarball %s" % tarball_path)
logger.info("Added %s as %s", file_path, arcname)
logger.info("Created tarball %s", tarball_path)
return tarball_path


Expand All @@ -102,73 +107,107 @@ async def upload_artifact(tarball_path: str, listing_id: str) -> None:

response.raise_for_status()

logger.info("Uploaded artifact to %s" % url)
logger.info("Uploaded artifact to %s", url)


async def download_urdf(listing_id: str) -> Path:
try:
urdf_info = fetch_urdf_info(listing_id)

if urdf_info.urdf is None:
breakpoint()
raise ValueError(f"No URDF found for listing {listing_id}")

artifact_url = urdf_info.urdf.url
return await download_artifact(artifact_url, get_listing_dir(listing_id))

except requests.RequestException:
logger.exception("Failed to fetch URDF info")
raise


async def show_urdf_info(listing_id: str) -> None:
try:
urdf_info = fetch_urdf_info(listing_id)

if urdf_info.urdf:
logger.info("URDF Artifact ID: %s", urdf_info.urdf.artifact_id)
logger.info("URDF URL: %s", urdf_info.urdf.url)
else:
logger.info("No URDF found for listing %s", listing_id)
except requests.RequestException:
logger.exception("Failed to fetch URDF info")
raise


async def upload_urdf(listing_id: str, args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Upload a URDF artifact to the K-Scale store")
parser.add_argument("folder_path", help="The path to the folder containing the URDF files")
parsed_args = parser.parse_args(args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

def main(args: Sequence[str] | None = None) -> None:
output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir(listing_id))

try:
fetch_urdf_info(listing_id)
await upload_artifact(tarball_path, listing_id)
except requests.RequestException:
logger.exception("Failed to upload artifact")
raise


async def remove_local_urdf(listing_id: str) -> None:
try:
if listing_id.lower() == "all":
cache_dir = get_cache_dir()
if cache_dir.exists():
logger.info("Removing all local caches at %s", cache_dir)
shutil.rmtree(cache_dir)
else:
logger.error("No local caches found")
else:
listing_dir = get_listing_dir(listing_id)
if listing_dir.exists():
logger.info("Removing local cache at %s", listing_dir)
shutil.rmtree(listing_dir)
else:
logger.error("No local cache found for listing %s", listing_id)

except Exception:
logger.error("Failed to remove local cache")
raise


Command = Literal["download", "info", "upload", "remove-local"]


async def main(args: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False)
parser.add_argument(
"command",
choices=["get", "info", "upload"],
help="The command to run",
)
parser.add_argument("command", choices=get_args(Command), help="The command to run")
parser.add_argument("listing_id", help="The listing ID to operate on")
parsed_args, remaining_args = parser.parse_known_args(args)

command: Literal["get", "info", "upload"] = parsed_args.command
command: Command = parsed_args.command
listing_id: str = parsed_args.listing_id

def get_listing_dir() -> Path:
(cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True)
return cache_dir

match command:
case "get":
try:
urdf_info = fetch_urdf_info(listing_id)

if urdf_info.urdf:
artifact_url = urdf_info.urdf.url
asyncio.run(download_artifact(artifact_url, get_listing_dir()))
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)
case "download":
await download_urdf(listing_id)

case "info":
try:
urdf_info = fetch_urdf_info(listing_id)

if urdf_info.urdf:
logger.info("URDF Artifact ID: %s" % urdf_info.urdf.artifact_id)
logger.info("URDF URL: %s" % urdf_info.urdf.url)
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)
await show_urdf_info(listing_id)

case "upload":
parser = argparse.ArgumentParser(description="Upload a URDF artifact to the K-Scale store")
parser.add_argument("folder_path", help="The path to the folder containing the URDF files")
parsed_args = parser.parse_args(remaining_args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())
await upload_urdf(listing_id, remaining_args)

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)
case "remove-local":
await remove_local_urdf(listing_id)

case _:
logger.error("Invalid command")
sys.exit(1)


if __name__ == "__main__":
main()
asyncio.run(main())
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ exclude =
[options.entry_points]

console_scripts =
kscale = kscale.store.cli:main
kscale = kscale.store.cli:sync_main

0 comments on commit 8f4a848

Please sign in to comment.