Skip to content

Commit

Permalink
added pydantic and auth headers to try to make my urdf downloadable -…
Browse files Browse the repository at this point in the history
…-- i think it's probably a problem on the fastapi end though
  • Loading branch information
nathanjzhao committed Aug 20, 2024
1 parent 9d0a6a5 commit 4079417
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions kscale/store/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,36 @@
import httpx
import requests

from kscale.store.gen.api import UrdfResponse

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CACHE_DIR = os.path.expanduser("~/.cache/kscale")


def fetch_urdf_info(listing_id: str) -> dict:
def fetch_urdf_info(listing_id: str, api_key: str = "") -> UrdfResponse:
url = f"https://api.kscale.store/urdf/info/{listing_id}"
response = requests.get(url)
headers = {
"Authorization": f"Bearer {api_key}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return response.json()
return UrdfResponse(**response.json())


def download_artifact(artifact_url: str, cache_dir: str) -> str:
def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> str:
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.join(cache_dir, artifact_url.split("/")[-1])
headers = {
"Authorization": f"Bearer {api_key}",
}

if not os.path.exists(filename):
logger.info("Downloading artifact from %s" % artifact_url)
response = requests.get(artifact_url, stream=True)
response = requests.get(artifact_url, headers=headers, stream=True)

response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
Expand Down Expand Up @@ -88,9 +97,10 @@ def main(args: Sequence[str] | None = None) -> None:

if command == "info":
try:
urdf_info = fetch_urdf_info(listing_id)
artifact_url = urdf_info["urdf"]["url"]
download_artifact(artifact_url, CACHE_DIR)
api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None)
urdf_info = fetch_urdf_info(listing_id, api_key)
artifact_url = urdf_info.urdf.url
download_artifact(artifact_url, CACHE_DIR, api_key)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)
Expand Down

0 comments on commit 4079417

Please sign in to comment.