Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API stub #12

Merged
merged 11 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ format:
test:
docker compose up -d --wait
uv run pytest

# Run development version of API
api:
uv run fastapi dev src/matchbox/server/api.py
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"click>=8.1.7",
"connectorx>=0.3.3",
"duckdb>=1.1.1",
"fastapi[standard]>=0.115.0,<0.116.0",
lmazz1-dbt marked this conversation as resolved.
Show resolved Hide resolved
"matplotlib>=3.9.2",
"pandas>=2.2.3",
"pg-bulk-ingest>=0.0.54",
Expand Down
3 changes: 2 additions & 1 deletion src/matchbox/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from matchbox.server.api import app
from matchbox.server.base import (
MatchboxDBAdapter,
MatchboxSettings,
initialise_matchbox,
inject_backend,
)

__all__ = ["MatchboxDBAdapter", "MatchboxSettings", "inject_backend"]
__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "inject_backend"]

initialise_matchbox()
159 changes: 159 additions & 0 deletions src/matchbox/server/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from enum import StrEnum
from typing import Annotated

from dotenv import find_dotenv, load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel

from matchbox.server.base import BackendManager, MatchboxDBAdapter

dotenv_path = find_dotenv(usecwd=True)
load_dotenv(dotenv_path)


app = FastAPI(
title="matchbox API",
version="0.2.0",
)


class BackendEntityType(StrEnum):
DATASETS = "datasets"
MODELS = "models"
DATA = "data"
CLUSTERS = "clusters"
CREATES = "creates"
MERGES = "merges"
PROPOSES = "proposes"


class ModelResultsType(StrEnum):
PROBABILITIES = "probabilities"
CLUSTERS = "clusters"


class HealthCheck(BaseModel):
lmazz1-dbt marked this conversation as resolved.
Show resolved Hide resolved
"""Response model to validate and return when performing a health check."""

status: str = "OK"


class CountResult(BaseModel):
"""Response model for count results"""

entities: dict[BackendEntityType, int]


def get_backend() -> MatchboxDBAdapter:
return BackendManager.get_backend()


@app.get("/health")
async def healthcheck() -> HealthCheck:
""" """
return HealthCheck(status="OK")


@app.get("/testing/count")
async def count_backend_items(
backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
entity: BackendEntityType | None = None,
) -> CountResult:
def get_count(e: BackendEntityType) -> int:
return getattr(backend, str(e)).count()

if entity is not None:
return CountResult(entities={str(entity): get_count(entity)})
else:
res = {str(e): get_count(e) for e in BackendEntityType}
return CountResult(entities=res)


@app.post("/testing/clear")
async def clear_backend():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/sources")
async def list_sources():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/sources/{hash}")
async def get_source(hash: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/sources/{hash}")
lmazz1-dbt marked this conversation as resolved.
Show resolved Hide resolved
async def add_source(hash: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models")
async def list_models():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}")
async def get_model(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}")
async def add_model(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.delete("/models/{name}")
async def delete_model(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/results")
async def get_results(name: str, result_type: ModelResultsType | None):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}/results")
async def set_results(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/truth")
async def get_truth(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}/truth")
async def set_truth(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/ancestors")
async def get_ancestors(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/ancestors_cache")
async def get_ancestors_cache(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}/ancestors_cache")
async def set_ancestors_cache(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/query")
async def query():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/validate/hash")
async def validate_hashes():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/report/models")
async def get_model_subgraph():
raise HTTPException(status_code=501, detail="Not implemented")
117 changes: 117 additions & 0 deletions test/server/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from unittest.mock import Mock, patch

from fastapi.testclient import TestClient
from matchbox.server import app

client = TestClient(app)


class TestMatchboxAPI:
def test_healthcheck(self):
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "OK"}

@patch("matchbox.server.base.BackendManager.get_backend")
def test_count_all_backend_items(self, get_backend):
entity_counts = {
"datasets": 1,
"models": 2,
"data": 3,
"clusters": 4,
"creates": 5,
"merges": 6,
"proposes": 7,
}
mock_backend = Mock()
for e, c in entity_counts.items():
mock_e = Mock()
mock_e.count = Mock(return_value=c)
setattr(mock_backend, e, mock_e)
get_backend.return_value = mock_backend

response = client.get("/testing/count")
assert response.status_code == 200
assert response.json() == {"entities": entity_counts}

@patch("matchbox.server.base.BackendManager.get_backend")
def test_count_backend_item(self, get_backend):
mock_backend = Mock()
mock_backend.models.count = Mock(return_value=20)
get_backend.return_value = mock_backend

response = client.get("/testing/count", params={"entity": "models"})
assert response.status_code == 200
assert response.json() == {"entities": {"models": 20}}

# def test_clear_backend():
# response = client.post("/testing/clear")
# assert response.status_code == 200

# def test_list_sources():
# response = client.get("/sources")
# assert response.status_code == 200

# def test_get_source():
# response = client.get("/sources/test_source")
# assert response.status_code == 200

# def test_add_source():
# response = client.post("/sources")
# assert response.status_code == 200

# def test_list_models():
# response = client.get("/models")
# assert response.status_code == 200

# def test_get_model():
# response = client.get("/models/test_model")
# assert response.status_code == 200

# def test_add_model():
# response = client.post("/models")
# assert response.status_code == 200

# def test_delete_model():
# response = client.delete("/models/test_model")
# assert response.status_code == 200

# def test_get_results():
# response = client.get("/models/test_model/results")
# assert response.status_code == 200

# def test_set_results():
# response = client.post("/models/test_model/results")
# assert response.status_code == 200

# def test_get_truth():
# response = client.get("/models/test_model/truth")
# assert response.status_code == 200

# def test_set_truth():
# response = client.post("/models/test_model/truth")
# assert response.status_code == 200

# def test_get_ancestors():
# response = client.get("/models/test_model/ancestors")
# assert response.status_code == 200

# def test_get_ancestors_cache():
# response = client.get("/models/test_model/ancestors_cache")
# assert response.status_code == 200

# def test_set_ancestors_cache():
# response = client.post("/models/test_model/ancestors_cache")
# assert response.status_code == 200

# def test_query():
# response = client.get("/query")
# assert response.status_code == 200

# def test_validate_hashes():
# response = client.get("/validate/hash")
# assert response.status_code == 200

# def test_get_model_subgraph():
# response = client.get("/report/models")
# assert response.status_code == 200
Loading