Skip to content

Commit

Permalink
fix nathan's code
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Aug 20, 2024
1 parent baaa44f commit 3bc924c
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 81 deletions.
14 changes: 3 additions & 11 deletions kscale/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,13 @@ def get_path() -> Path:

@dataclass
class StoreSettings:
api_key: str = field(default=II("oc.env:KSCALE_API_KEY"))

def get_api_key(self) -> str:
try:
return self.api_key
except AttributeError:
raise ValueError(
"API key not found! Get one here and set it as the `KSCALE_API_KEY` "
"environment variable: https://kscale.store/keys"
)
api_key: str = field(default=II("oc.env:KSCALE_API_KEY,"))
cache_dir: str = field(default=II("oc.env:KSCALE_CACHE_DIR,'~/.kscale/cache/'"))


@dataclass
class Settings:
store: StoreSettings = StoreSettings()
store: StoreSettings = field(default_factory=StoreSettings)

def save(self) -> None:
(dir_path := get_path()).mkdir(parents=True, exist_ok=True)
Expand Down
5 changes: 5 additions & 0 deletions kscale/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# requirements.txt

# Configuration
omegaconf

# HTTP requests
httpx
requests
13 changes: 0 additions & 13 deletions kscale/store/auth.py

This file was deleted.

145 changes: 88 additions & 57 deletions kscale/store/urdf.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,53 @@
"""Utility functions for managing artifacts in the K-Scale store."""

import argparse
import asyncio
import logging
import os
import sys
import tarfile
from typing import Sequence
from pathlib import Path
from typing import Literal, Sequence

import httpx
import requests

from kscale.conf import Settings
from kscale.store.gen.api import UrdfResponse

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CACHE_DIR = os.path.expanduser("~/.cache/kscale")

def get_api_key() -> str:
api_key = Settings.load().store.api_key
if not api_key:
raise ValueError(
"API key not found! Get one here and set it as the `KSCALE_API_KEY` environment variable or in your"
"config file: https://kscale.store/keys"
)
return api_key

def fetch_urdf_info(listing_id: str, api_key: str = "") -> UrdfResponse:

def get_cache_dir() -> Path:
return Path(Settings.load().store.cache_dir).expanduser().resolve()


def fetch_urdf_info(listing_id: str) -> UrdfResponse:
url = f"https://api.kscale.store/urdf/info/{listing_id}"
headers = {
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {get_api_key()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return UrdfResponse(**response.json())


async def download_artifact(artifact_url: str, cache_dir: str, api_key: str = None) -> str:
os.makedirs(cache_dir, exist_ok=True)
async def download_artifact(artifact_url: str, cache_dir: Path) -> str:
filename = os.path.join(cache_dir, artifact_url.split("/")[-1])
headers = {
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {get_api_key()}",
}

if not os.path.exists(filename):
Expand All @@ -43,7 +57,7 @@ async def download_artifact(artifact_url: str, cache_dir: str, api_key: str = No
response = await client.get(artifact_url, headers=headers)
response.raise_for_status()
with open(filename, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
for chunk in response.iter_bytes(chunk_size=8192):
f.write(chunk)
logger.info("Artifact downloaded to %s" % filename)
else:
Expand All @@ -62,10 +76,10 @@ async def download_artifact(artifact_url: str, cache_dir: str, api_key: str = No
return extract_dir


def create_tarball(folder_path: str, output_filename: str) -> str:
tarball_path = os.path.join(CACHE_DIR, output_filename)
def create_tarball(folder_path: str | Path, output_filename: str, cache_dir: Path) -> str:
tarball_path = os.path.join(cache_dir, output_filename)
with tarfile.open(tarball_path, "w:gz") as tar:
for root, dirs, files in os.walk(folder_path):
for root, _, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, start=folder_path)
Expand All @@ -75,10 +89,10 @@ def create_tarball(folder_path: str, output_filename: str) -> str:
return tarball_path


async def upload_artifact(tarball_path: str, listing_id: str, api_key: str) -> None:
async def upload_artifact(tarball_path: str, listing_id: str) -> None:
url = f"https://api.kscale.store/urdf/upload/{listing_id}"
headers = {
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {get_api_key()}",
}

async with httpx.AsyncClient() as client:
Expand All @@ -92,51 +106,68 @@ async def upload_artifact(tarball_path: str, listing_id: str, api_key: str) -> N


def main(args: Sequence[str] | None = None) -> None:
command = args[0]
listing_id = args[1]

if command == "get":
try:
api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None)
urdf_info = fetch_urdf_info(listing_id, api_key)

if urdf_info.urdf:
artifact_url = urdf_info.urdf.url
asyncio.run(download_artifact(artifact_url, CACHE_DIR, api_key))
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False)
parser.add_argument(
"command",
choices=["get", "info", "upload"],
help="The command to run",
)
parser.add_argument("listing_id", help="The listing ID to operate on")
parsed_args, remaining_args = parser.parse_known_args(args)

command: Literal["get", "info", "upload"] = parsed_args.command
listing_id: str = parsed_args.listing_id

def get_listing_dir() -> Path:
(cache_dir := get_cache_dir() / listing_id).mkdir(parents=True, exist_ok=True)
return cache_dir

match command:
case "get":
try:
urdf_info = fetch_urdf_info(listing_id)

if urdf_info.urdf:
artifact_url = urdf_info.urdf.url
asyncio.run(download_artifact(artifact_url, get_listing_dir()))
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)

case "info":
try:
urdf_info = fetch_urdf_info(listing_id)

if urdf_info.urdf:
logger.info("URDF Artifact ID: %s" % urdf_info.urdf.artifact_id)
logger.info("URDF URL: %s" % urdf_info.urdf.url)
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)

case "upload":
parser = argparse.ArgumentParser(description="Upload a URDF artifact to the K-Scale store")
parser.add_argument("folder_path", help="The path to the folder containing the URDF files")
parsed_args = parser.parse_args(remaining_args)
folder_path = Path(parsed_args.folder_path).expanduser().resolve()

output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename, get_listing_dir())

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)

case _:
logger.error("Invalid command")
sys.exit(1)
elif command == "info":
try:
api_key = os.getenv("KSCALE_API_KEY") or (args[2] if len(args) >= 3 else None)
urdf_info = fetch_urdf_info(listing_id, api_key)

if urdf_info.urdf:
logger.info("URDF Artifact ID: %s" % urdf_info.urdf.artifact_id)
logger.info("URDF URL: %s" % urdf_info.urdf.url)
else:
logger.info("No URDF found for listing %s" % listing_id)
except requests.RequestException as e:
logger.error("Failed to fetch URDF info: %s" % e)
sys.exit(1)
elif command == "upload":
folder_path = args[2]
api_key = os.getenv("KSCALE_API_KEY") or args[3] # Use the environment variable if available

output_filename = f"{listing_id}.tgz"
tarball_path = create_tarball(folder_path, output_filename)

try:
urdf_info = fetch_urdf_info(listing_id)
asyncio.run(upload_artifact(tarball_path, listing_id, api_key))
except requests.RequestException as e:
logger.error("Failed to upload artifact: %s" % e)
sys.exit(1)
else:
logger.error("Invalid command")
sys.exit(1)


if __name__ == "__main__":
Expand Down

0 comments on commit 3bc924c

Please sign in to comment.