-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from uktrade/feature/api
Implement API stub
- Loading branch information
Showing
6 changed files
with
657 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
from matchbox.server.api import app | ||
from matchbox.server.base import ( | ||
MatchboxDBAdapter, | ||
MatchboxSettings, | ||
initialise_matchbox, | ||
inject_backend, | ||
) | ||
|
||
__all__ = ["MatchboxDBAdapter", "MatchboxSettings", "inject_backend"] | ||
__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "inject_backend"] | ||
|
||
initialise_matchbox() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from enum import StrEnum | ||
from typing import Annotated | ||
|
||
from dotenv import find_dotenv, load_dotenv | ||
from fastapi import Depends, FastAPI, HTTPException | ||
from pydantic import BaseModel | ||
|
||
from matchbox.server.base import BackendManager, MatchboxDBAdapter | ||
|
||
dotenv_path = find_dotenv(usecwd=True) | ||
load_dotenv(dotenv_path) | ||
|
||
|
||
app = FastAPI( | ||
title="matchbox API", | ||
version="0.2.0", | ||
) | ||
|
||
|
||
class BackendEntityType(StrEnum): | ||
DATASETS = "datasets" | ||
MODELS = "models" | ||
DATA = "data" | ||
CLUSTERS = "clusters" | ||
CREATES = "creates" | ||
MERGES = "merges" | ||
PROPOSES = "proposes" | ||
|
||
|
||
class ModelResultsType(StrEnum): | ||
PROBABILITIES = "probabilities" | ||
CLUSTERS = "clusters" | ||
|
||
|
||
class HealthCheck(BaseModel): | ||
"""Response model to validate and return when performing a health check.""" | ||
|
||
status: str = "OK" | ||
|
||
|
||
class CountResult(BaseModel): | ||
"""Response model for count results""" | ||
|
||
entities: dict[BackendEntityType, int] | ||
|
||
|
||
def get_backend() -> MatchboxDBAdapter: | ||
return BackendManager.get_backend() | ||
|
||
|
||
@app.get("/health") | ||
async def healthcheck() -> HealthCheck: | ||
""" """ | ||
return HealthCheck(status="OK") | ||
|
||
|
||
@app.get("/testing/count") | ||
async def count_backend_items( | ||
backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], | ||
entity: BackendEntityType | None = None, | ||
) -> CountResult: | ||
def get_count(e: BackendEntityType) -> int: | ||
return getattr(backend, str(e)).count() | ||
|
||
if entity is not None: | ||
return CountResult(entities={str(entity): get_count(entity)}) | ||
else: | ||
res = {str(e): get_count(e) for e in BackendEntityType} | ||
return CountResult(entities=res) | ||
|
||
|
||
@app.post("/testing/clear") | ||
async def clear_backend(): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/sources") | ||
async def list_sources(): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/sources/{hash}") | ||
async def get_source(hash: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.post("/sources/{hash}") | ||
async def add_source(hash: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/models") | ||
async def list_models(): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/models/{name}") | ||
async def get_model(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.post("/models/{name}") | ||
async def add_model(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.delete("/models/{name}") | ||
async def delete_model(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/models/{name}/results") | ||
async def get_results(name: str, result_type: ModelResultsType | None): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.post("/models/{name}/results") | ||
async def set_results(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/models/{name}/truth") | ||
async def get_truth(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.post("/models/{name}/truth") | ||
async def set_truth(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/models/{name}/ancestors") | ||
async def get_ancestors(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/models/{name}/ancestors_cache") | ||
async def get_ancestors_cache(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.post("/models/{name}/ancestors_cache") | ||
async def set_ancestors_cache(name: str): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/query") | ||
async def query(): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/validate/hash") | ||
async def validate_hashes(): | ||
raise HTTPException(status_code=501, detail="Not implemented") | ||
|
||
|
||
@app.get("/report/models") | ||
async def get_model_subgraph(): | ||
raise HTTPException(status_code=501, detail="Not implemented") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
from fastapi.testclient import TestClient | ||
from matchbox.server import app | ||
|
||
client = TestClient(app) | ||
|
||
|
||
class TestMatchboxAPI: | ||
def test_healthcheck(self): | ||
response = client.get("/health") | ||
assert response.status_code == 200 | ||
assert response.json() == {"status": "OK"} | ||
|
||
@patch("matchbox.server.base.BackendManager.get_backend") | ||
def test_count_all_backend_items(self, get_backend): | ||
entity_counts = { | ||
"datasets": 1, | ||
"models": 2, | ||
"data": 3, | ||
"clusters": 4, | ||
"creates": 5, | ||
"merges": 6, | ||
"proposes": 7, | ||
} | ||
mock_backend = Mock() | ||
for e, c in entity_counts.items(): | ||
mock_e = Mock() | ||
mock_e.count = Mock(return_value=c) | ||
setattr(mock_backend, e, mock_e) | ||
get_backend.return_value = mock_backend | ||
|
||
response = client.get("/testing/count") | ||
assert response.status_code == 200 | ||
assert response.json() == {"entities": entity_counts} | ||
|
||
@patch("matchbox.server.base.BackendManager.get_backend") | ||
def test_count_backend_item(self, get_backend): | ||
mock_backend = Mock() | ||
mock_backend.models.count = Mock(return_value=20) | ||
get_backend.return_value = mock_backend | ||
|
||
response = client.get("/testing/count", params={"entity": "models"}) | ||
assert response.status_code == 200 | ||
assert response.json() == {"entities": {"models": 20}} | ||
|
||
# def test_clear_backend(): | ||
# response = client.post("/testing/clear") | ||
# assert response.status_code == 200 | ||
|
||
# def test_list_sources(): | ||
# response = client.get("/sources") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_source(): | ||
# response = client.get("/sources/test_source") | ||
# assert response.status_code == 200 | ||
|
||
# def test_add_source(): | ||
# response = client.post("/sources") | ||
# assert response.status_code == 200 | ||
|
||
# def test_list_models(): | ||
# response = client.get("/models") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_model(): | ||
# response = client.get("/models/test_model") | ||
# assert response.status_code == 200 | ||
|
||
# def test_add_model(): | ||
# response = client.post("/models") | ||
# assert response.status_code == 200 | ||
|
||
# def test_delete_model(): | ||
# response = client.delete("/models/test_model") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_results(): | ||
# response = client.get("/models/test_model/results") | ||
# assert response.status_code == 200 | ||
|
||
# def test_set_results(): | ||
# response = client.post("/models/test_model/results") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_truth(): | ||
# response = client.get("/models/test_model/truth") | ||
# assert response.status_code == 200 | ||
|
||
# def test_set_truth(): | ||
# response = client.post("/models/test_model/truth") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_ancestors(): | ||
# response = client.get("/models/test_model/ancestors") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_ancestors_cache(): | ||
# response = client.get("/models/test_model/ancestors_cache") | ||
# assert response.status_code == 200 | ||
|
||
# def test_set_ancestors_cache(): | ||
# response = client.post("/models/test_model/ancestors_cache") | ||
# assert response.status_code == 200 | ||
|
||
# def test_query(): | ||
# response = client.get("/query") | ||
# assert response.status_code == 200 | ||
|
||
# def test_validate_hashes(): | ||
# response = client.get("/validate/hash") | ||
# assert response.status_code == 200 | ||
|
||
# def test_get_model_subgraph(): | ||
# response = client.get("/report/models") | ||
# assert response.status_code == 200 |
Oops, something went wrong.