Skip to content

Commit

Permalink
Merge pull request #12 from uktrade/feature/api
Browse files Browse the repository at this point in the history
Implement API stub
  • Loading branch information
lmazz1-dbt authored Dec 2, 2024
2 parents 4a587f1 + b0b51de commit 22ea503
Show file tree
Hide file tree
Showing 6 changed files with 657 additions and 1 deletion.
4 changes: 4 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ format:
test:
docker compose up -d --wait
uv run pytest

# Run development version of API
api:
uv run fastapi dev src/matchbox/server/api.py
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"click>=8.1.7",
"connectorx>=0.3.3",
"duckdb>=1.1.1",
"fastapi[standard]>=0.115.0,<0.116.0",
"matplotlib>=3.9.2",
"pandas>=2.2.3",
"pg-bulk-ingest>=0.0.54",
Expand Down
3 changes: 2 additions & 1 deletion src/matchbox/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from matchbox.server.api import app
from matchbox.server.base import (
MatchboxDBAdapter,
MatchboxSettings,
initialise_matchbox,
inject_backend,
)

__all__ = ["MatchboxDBAdapter", "MatchboxSettings", "inject_backend"]
__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "inject_backend"]

initialise_matchbox()
159 changes: 159 additions & 0 deletions src/matchbox/server/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from enum import StrEnum
from typing import Annotated

from dotenv import find_dotenv, load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel

from matchbox.server.base import BackendManager, MatchboxDBAdapter

dotenv_path = find_dotenv(usecwd=True)
load_dotenv(dotenv_path)


app = FastAPI(
title="matchbox API",
version="0.2.0",
)


class BackendEntityType(StrEnum):
DATASETS = "datasets"
MODELS = "models"
DATA = "data"
CLUSTERS = "clusters"
CREATES = "creates"
MERGES = "merges"
PROPOSES = "proposes"


class ModelResultsType(StrEnum):
PROBABILITIES = "probabilities"
CLUSTERS = "clusters"


class HealthCheck(BaseModel):
"""Response model to validate and return when performing a health check."""

status: str = "OK"


class CountResult(BaseModel):
"""Response model for count results"""

entities: dict[BackendEntityType, int]


def get_backend() -> MatchboxDBAdapter:
return BackendManager.get_backend()


@app.get("/health")
async def healthcheck() -> HealthCheck:
""" """
return HealthCheck(status="OK")


@app.get("/testing/count")
async def count_backend_items(
backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
entity: BackendEntityType | None = None,
) -> CountResult:
def get_count(e: BackendEntityType) -> int:
return getattr(backend, str(e)).count()

if entity is not None:
return CountResult(entities={str(entity): get_count(entity)})
else:
res = {str(e): get_count(e) for e in BackendEntityType}
return CountResult(entities=res)


@app.post("/testing/clear")
async def clear_backend():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/sources")
async def list_sources():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/sources/{hash}")
async def get_source(hash: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/sources/{hash}")
async def add_source(hash: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models")
async def list_models():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}")
async def get_model(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}")
async def add_model(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.delete("/models/{name}")
async def delete_model(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/results")
async def get_results(name: str, result_type: ModelResultsType | None):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}/results")
async def set_results(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/truth")
async def get_truth(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}/truth")
async def set_truth(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/ancestors")
async def get_ancestors(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/models/{name}/ancestors_cache")
async def get_ancestors_cache(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.post("/models/{name}/ancestors_cache")
async def set_ancestors_cache(name: str):
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/query")
async def query():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/validate/hash")
async def validate_hashes():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/report/models")
async def get_model_subgraph():
raise HTTPException(status_code=501, detail="Not implemented")
117 changes: 117 additions & 0 deletions test/server/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from unittest.mock import Mock, patch

from fastapi.testclient import TestClient
from matchbox.server import app

client = TestClient(app)


class TestMatchboxAPI:
def test_healthcheck(self):
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "OK"}

@patch("matchbox.server.base.BackendManager.get_backend")
def test_count_all_backend_items(self, get_backend):
entity_counts = {
"datasets": 1,
"models": 2,
"data": 3,
"clusters": 4,
"creates": 5,
"merges": 6,
"proposes": 7,
}
mock_backend = Mock()
for e, c in entity_counts.items():
mock_e = Mock()
mock_e.count = Mock(return_value=c)
setattr(mock_backend, e, mock_e)
get_backend.return_value = mock_backend

response = client.get("/testing/count")
assert response.status_code == 200
assert response.json() == {"entities": entity_counts}

@patch("matchbox.server.base.BackendManager.get_backend")
def test_count_backend_item(self, get_backend):
mock_backend = Mock()
mock_backend.models.count = Mock(return_value=20)
get_backend.return_value = mock_backend

response = client.get("/testing/count", params={"entity": "models"})
assert response.status_code == 200
assert response.json() == {"entities": {"models": 20}}

# def test_clear_backend():
# response = client.post("/testing/clear")
# assert response.status_code == 200

# def test_list_sources():
# response = client.get("/sources")
# assert response.status_code == 200

# def test_get_source():
# response = client.get("/sources/test_source")
# assert response.status_code == 200

# def test_add_source():
# response = client.post("/sources")
# assert response.status_code == 200

# def test_list_models():
# response = client.get("/models")
# assert response.status_code == 200

# def test_get_model():
# response = client.get("/models/test_model")
# assert response.status_code == 200

# def test_add_model():
# response = client.post("/models")
# assert response.status_code == 200

# def test_delete_model():
# response = client.delete("/models/test_model")
# assert response.status_code == 200

# def test_get_results():
# response = client.get("/models/test_model/results")
# assert response.status_code == 200

# def test_set_results():
# response = client.post("/models/test_model/results")
# assert response.status_code == 200

# def test_get_truth():
# response = client.get("/models/test_model/truth")
# assert response.status_code == 200

# def test_set_truth():
# response = client.post("/models/test_model/truth")
# assert response.status_code == 200

# def test_get_ancestors():
# response = client.get("/models/test_model/ancestors")
# assert response.status_code == 200

# def test_get_ancestors_cache():
# response = client.get("/models/test_model/ancestors_cache")
# assert response.status_code == 200

# def test_set_ancestors_cache():
# response = client.post("/models/test_model/ancestors_cache")
# assert response.status_code == 200

# def test_query():
# response = client.get("/query")
# assert response.status_code == 200

# def test_validate_hashes():
# response = client.get("/validate/hash")
# assert response.status_code == 200

# def test_get_model_subgraph():
# response = client.get("/report/models")
# assert response.status_code == 200
Loading

0 comments on commit 22ea503

Please sign in to comment.