diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index 22b0c9f..ab7faa7 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -29,7 +29,7 @@ def fetch_urdf_info(listing_id: str, api_key: str = "") -> UrdfResponse: return UrdfResponse(**response.json()) -def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> str: +async def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> str: os.makedirs(cache_dir, exist_ok=True) filename = os.path.join(cache_dir, artifact_url.split("/")[-1]) headers = { @@ -38,13 +38,14 @@ def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> if not os.path.exists(filename): logger.info("Downloading artifact from %s" % artifact_url) - response = requests.get(artifact_url, headers=headers, stream=True) - - response.raise_for_status() - with open(filename, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - logger.info("Artifact downloaded to %s" % filename) + + async with httpx.AsyncClient() as client: + response = await client.get(artifact_url, headers=headers) + response.raise_for_status() + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logger.info("Artifact downloaded to %s" % filename) else: logger.info("Artifact already cached at %s" % filename) @@ -95,12 +96,29 @@ def main(args: Sequence[str] | None = None) -> None: command = args[0] listing_id = args[1] - if command == "info": + if command == "get": try: api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None) urdf_info = fetch_urdf_info(listing_id, api_key) - artifact_url = urdf_info.urdf.url - download_artifact(artifact_url, CACHE_DIR, api_key) + + if urdf_info.urdf: + artifact_url = urdf_info.urdf.url + asyncio.run(download_artifact(artifact_url, CACHE_DIR, api_key)) + else: + logger.info("No URDF found for listing %s" % listing_id) + except requests.RequestException as e: + logger.error("Failed to fetch URDF info: %s" % e) + sys.exit(1) + elif command == "info": + try: + api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None) + urdf_info = fetch_urdf_info(listing_id, api_key) + + if urdf_info.urdf: + logger.info("URDF Artifact ID: %s" % urdf_info.urdf.artifact_id) + logger.info("URDF URL: %s" % urdf_info.urdf.url) + else: + logger.info("No URDF found for listing %s" % listing_id) except requests.RequestException as e: logger.error("Failed to fetch URDF info: %s" % e) sys.exit(1)