Skip to content

Commit

Permalink
urdf changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanjzhao committed Aug 20, 2024
1 parent 2e245d9 commit c31c883
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions kscale/store/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def fetch_urdf_info(listing_id: str, api_key: str = "") -> UrdfResponse:
return UrdfResponse(**response.json())


def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> str:
async def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> str:
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.join(cache_dir, artifact_url.split("/")[-1])
headers = {
Expand All @@ -38,13 +38,14 @@ def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) ->

if not os.path.exists(filename):
logger.info("Downloading artifact from %s" % artifact_url)
response = requests.get(artifact_url, headers=headers, stream=True)

response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info("Artifact downloaded to %s" % filename)

async with httpx.AsyncClient() as client:
response = await client.get(artifact_url, headers=headers)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info("Artifact downloaded to %s" % filename)
else:
logger.info("Artifact already cached at %s" % filename)

Expand Down Expand Up @@ -95,12 +96,29 @@ def main(args: Sequence[str] | None = None) -> None:
command = args[0]
listing_id = args[1]

if command == "info":
if command == "get":
try:
api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None)
urdf_info = fetch_urdf_info(listing_id, api_key)
artifact_url = urdf_info.urdf.url
download_artifact(artifact_url, CACHE_DIR, api_key)

if urdf_info.urdf:
artifact_url = urdf_info.urdf.url
asyncio.run(download_artifact(artifact_url, CACHE_DIR, api_key))
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)
elif command == "info":
try:
api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None)
urdf_info = fetch_urdf_info(listing_id, api_key)

if urdf_info.urdf:
logger.info("URDF Artifact ID: %s" % urdf_info.urdf.artifact_id)
logger.info("URDF URL: %s" % urdf_info.urdf.url)
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)
Expand Down

0 comments on commit c31c883

Please sign in to comment.