diff --git a/kscale/store/api.py b/kscale/store/api.py index f518426..30b4ed4 100644 --- a/kscale/store/api.py +++ b/kscale/store/api.py @@ -2,7 +2,8 @@ from pathlib import Path -from kscale.store.urdf import download_urdf +from kscale.store.gen.api import UploadArtifactResponse +from kscale.store.urdf import download_urdf, upload_urdf from kscale.utils.api_base import APIBase @@ -18,3 +19,6 @@ def __init__( async def urdf(self, artifact_id: str) -> Path: return await download_urdf(artifact_id) + + async def upload_urdf(self, listing_id: str, root_dir: Path) -> UploadArtifactResponse: + return await upload_urdf(listing_id, root_dir) diff --git a/kscale/store/client.py b/kscale/store/client.py index dceed7f..aef18a7 100644 --- a/kscale/store/client.py +++ b/kscale/store/client.py @@ -1,6 +1,7 @@ """Defines a typed client for the K-Scale Store API.""" import logging +from pathlib import Path from types import TracebackType from typing import Any, Dict, Type from urllib.parse import urljoin @@ -14,13 +15,13 @@ SingleArtifactResponse, UploadArtifactResponse, ) -from kscale.store.utils import API_ROOT, get_api_key +from kscale.store.utils import get_api_key, get_api_root logger = logging.getLogger(__name__) class KScaleStoreClient: - def __init__(self, base_url: str = API_ROOT) -> None: + def __init__(self, base_url: str = get_api_root()) -> None: self.base_url = base_url self.client = httpx.AsyncClient( base_url=self.base_url, @@ -55,8 +56,9 @@ async def get_artifact_info(self, artifact_id: str) -> SingleArtifactResponse: return SingleArtifactResponse(**data) async def upload_artifact(self, listing_id: str, file_path: str) -> UploadArtifactResponse: + file_name = Path(file_path).name with open(file_path, "rb") as f: - files = {"files": (f.name, f, "application/gzip")} + files = {"files": (file_name, f, "application/gzip")} data = await self._request("POST", f"/artifacts/upload/{listing_id}", files=files) return UploadArtifactResponse(**data) diff --git a/kscale/store/urdf.py b/kscale/store/urdf.py index 5b24a84..940d87b 100644 --- a/kscale/store/urdf.py +++ b/kscale/store/urdf.py @@ -14,7 +14,7 @@ from kscale.conf import Settings from kscale.store.client import KScaleStoreClient -from kscale.store.gen.api import SingleArtifactResponse +from kscale.store.gen.api import SingleArtifactResponse, UploadArtifactResponse from kscale.store.utils import get_api_key # Set up logging @@ -128,18 +128,24 @@ async def remove_local_urdf(artifact_id: str) -> None: raise -async def upload_urdf(listing_id: str, args: Sequence[str]) -> None: - parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False) - parser.add_argument("root_dir", type=Path, help="The path to the root directory to upload") - parsed_args = parser.parse_args(args) - - root_dir = parsed_args.root_dir +async def upload_urdf(listing_id: str, root_dir: Path) -> UploadArtifactResponse: tarball_path = create_tarball(root_dir, "robot.tgz", get_artifact_dir(listing_id)) async with KScaleStoreClient() as client: response = await client.upload_artifact(listing_id, str(tarball_path)) logger.info("Uploaded artifacts: %s", [artifact.artifact_id for artifact in response.artifacts]) + return response + + +async def upload_urdf_cli(listing_id: str, args: Sequence[str]) -> UploadArtifactResponse: + parser = argparse.ArgumentParser(description="K-Scale URDF Store", add_help=False) + parser.add_argument("root_dir", type=Path, help="The path to the root directory to upload") + parsed_args = parser.parse_args(args) + + root_dir = parsed_args.root_dir + response = await upload_urdf(listing_id, root_dir) + return response Command = Literal["download", "info", "upload", "remove-local"] @@ -165,7 +171,7 @@ async def main(args: Sequence[str] | None = None) -> None: await remove_local_urdf(id) case "upload": - await upload_urdf(id, remaining_args) + await upload_urdf_cli(id, remaining_args) case _: logger.error("Invalid command") diff --git a/kscale/store/utils.py b/kscale/store/utils.py index 390f272..a9a09a6 100644 --- a/kscale/store/utils.py +++ b/kscale/store/utils.py @@ -4,10 +4,24 @@ from kscale.conf import Settings -API_ROOT = "https://api.kscale.store" + +def get_api_root() -> str: + """Returns the base URL for the K-Scale Store API. + + This can be overridden when targetting a different server. + + Returns: + The base URL for the K-Scale Store API. + """ + return os.getenv("KSCALE_API_ROOT", "https://api.kscale.store") def get_api_key() -> str: + """Returns the API key for the K-Scale Store API. + + Returns: + The API key for the K-Scale Store API. + """ api_key = Settings.load().store.api_key if api_key is None: api_key = os.getenv("KSCALE_API_KEY")