From 944e459269b5873e37e85aa0fda790040e4eb3e8 Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Tue, 28 Jan 2025 21:55:47 -0800 Subject: [PATCH] script for running client sdk tests --- docs/source/concepts/index.md | 2 + docs/source/distributions/configuration.md | 4 +- llama_stack/scripts/run_client_sdk_tests.py | 69 +++++++++++++++++++++ tests/client-sdk/report.py | 2 +- 4 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 llama_stack/scripts/run_client_sdk_tests.py diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index 834b7d7cd1..7422799b29 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -13,6 +13,7 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s - **DatasetIO**: interface with datasets and data loaders - **Scoring**: evaluate outputs of the system - **Eval**: generate outputs (via Inference or Agents) and perform scoring +- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents - **Telemetry**: collect telemetry data from the system We are working on adding a few more APIs to complete the application lifecycle. These will include: @@ -41,6 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi - **Safety** is associated with `Shield` resources. - **Tool Runtime** is associated with `ToolGroup` resources. - **DatasetIO** is associated with `Dataset` resources. +- **VectorIO** is associated with `VectorDB` resources. - **Scoring** is associated with `ScoringFunction` resources. - **Eval** is associated with `Model` and `EvalTask` resources. diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index d12f584f7a..0f766dcd55 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -10,7 +10,7 @@ conda_env: ollama apis: - agents - inference -- memory +- vector_io - safety - telemetry providers: @@ -19,7 +19,7 @@ providers: provider_type: remote::ollama config: url: ${env.OLLAMA_URL:http://localhost:11434} - memory: + vector_io: - provider_id: faiss provider_type: inline::faiss config: diff --git a/llama_stack/scripts/run_client_sdk_tests.py b/llama_stack/scripts/run_client_sdk_tests.py new file mode 100644 index 0000000000..90ecb84267 --- /dev/null +++ b/llama_stack/scripts/run_client_sdk_tests.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import os +from pathlib import Path + +import pytest + + +""" +Script for running client-sdk on AsyncLlamaStackAsLibraryClient with templates + +Assuming directory structure: +- llama-stack + - llama_stack + - scripts + - tests + - client-sdk + +Example command: + +cd llama-stack +EXPORT TOGETHER_API_KEY=<..> +EXPORT FIREWORKS_API_KEY=<..> +python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report +""" + +REPO_ROOT = Path(__file__).parent.parent.parent +CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/client-sdk/" + + +def main(parser: argparse.ArgumentParser): + args = parser.parse_args() + templates_dir = REPO_ROOT / "llama_stack" / "templates" + user_specified_templates = ( + [templates_dir / t for t in args.templates] if args.templates else [] + ) + for d in templates_dir.iterdir(): + if d.is_dir() and d.name != "__pycache__": + template_configs = list(d.rglob("run.yaml")) + if len(template_configs) == 0: + continue + config = template_configs[0] + if user_specified_templates: + if not any(config.parent == t for t in user_specified_templates): + continue + os.environ["LLAMA_STACK_CONFIG"] = str(config) + pytest_args = "--report" if args.report else "" + pytest.main( + [ + pytest_args, + "-s", + "-v", + REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH, + ] + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="llama_test", + ) + parser.add_argument("--templates", nargs="+") + parser.add_argument("--report", action="store_true") + main(parser) diff --git a/tests/client-sdk/report.py b/tests/client-sdk/report.py index f39ea02fa9..6e6c8a98ab 100644 --- a/tests/client-sdk/report.py +++ b/tests/client-sdk/report.py @@ -198,7 +198,7 @@ def pytest_sessionfinish(self, session): "|:-----|:-----|:-----|:-----|:-----|", ] provider = [p for p in providers if p.api == str(api_group.name)] - provider_str = provider[0].provider_type if provider else "" + provider_str = ",".join(provider) if provider else "" for api, capa_map in API_MAPS[api_group].items(): for capa, tests in capa_map.items(): for test_name in tests: