diff --git a/backend_py/primary/primary/routers/explore.py b/backend_py/primary/primary/routers/explore/router.py similarity index 75% rename from backend_py/primary/primary/routers/explore.py rename to backend_py/primary/primary/routers/explore/router.py index effda0837..decc337d8 100644 --- a/backend_py/primary/primary/routers/explore.py +++ b/backend_py/primary/primary/routers/explore/router.py @@ -8,43 +8,24 @@ from primary.services.sumo_access.sumo_inspector import SumoInspector from primary.services.utils.authenticated_user import AuthenticatedUser -router = APIRouter() - - -class FieldInfo(BaseModel): - field_identifier: str - - -class CaseInfo(BaseModel): - uuid: str - name: str - status: str - user: str +from . import schemas +router = APIRouter() -class EnsembleInfo(BaseModel): - name: str - realization_count: int -class EnsembleDetails(BaseModel): - name: str - field_identifier: str - case_name: str - case_uuid: str - realizations: Sequence[int] @router.get("/fields") async def get_fields( authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user), -) -> List[FieldInfo]: +) -> List[schemas.FieldInfo]: """ Get list of fields """ sumo_inspector = SumoInspector(authenticated_user.get_sumo_access_token()) field_ident_arr = await sumo_inspector.get_fields_async() - ret_arr = [FieldInfo(field_identifier=field_ident.identifier) for field_ident in field_ident_arr] + ret_arr = [schemas.FieldInfo(field_identifier=field_ident.identifier) for field_ident in field_ident_arr] return ret_arr @@ -53,14 +34,14 @@ async def get_fields( async def get_cases( authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user), field_identifier: str = Query(description="Field identifier"), -) -> List[CaseInfo]: +) -> List[schemas.CaseInfo]: """Get list of cases for specified field""" sumo_inspector = SumoInspector(authenticated_user.get_sumo_access_token()) case_info_arr = await sumo_inspector.get_cases_async(field_identifier=field_identifier) - ret_arr: List[CaseInfo] = [] + ret_arr: List[schemas.CaseInfo] = [] - ret_arr = [CaseInfo(uuid=ci.uuid, name=ci.name, status=ci.status, user=ci.user) for ci in case_info_arr] + ret_arr = [schemas.CaseInfo(uuid=ci.uuid, name=ci.name, status=ci.status, user=ci.user) for ci in case_info_arr] return ret_arr @@ -69,14 +50,14 @@ async def get_cases( async def get_ensembles( authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user), case_uuid: str = Path(description="Sumo case uuid"), -) -> List[EnsembleInfo]: +) -> List[schemas.EnsembleInfo]: """Get list of ensembles for a case""" case_inspector = CaseInspector.from_case_uuid(authenticated_user.get_sumo_access_token(), case_uuid) iteration_info_arr = await case_inspector.get_iterations_async() print(iteration_info_arr) - return [EnsembleInfo(name=it.name, realization_count=it.realization_count) for it in iteration_info_arr] + return [schemas.EnsembleInfo(name=it.name, realization_count=it.realization_count) for it in iteration_info_arr] @router.get("/cases/{case_uuid}/ensembles/{ensemble_name}") @@ -84,7 +65,7 @@ async def get_ensemble_details( authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user), case_uuid: str = Path(description="Sumo case uuid"), ensemble_name: str = Path(description="Ensemble name"), -) -> EnsembleDetails: +) -> schemas.EnsembleDetails: """Get more detailed information for an ensemble""" case_inspector = CaseInspector.from_case_uuid(authenticated_user.get_sumo_access_token(), case_uuid) @@ -95,7 +76,7 @@ async def get_ensemble_details( if len(field_identifiers) != 1: raise NotImplementedError("Multiple field identifiers not supported") - return EnsembleDetails( + return schemas.EnsembleDetails( name=ensemble_name, case_name=case_name, case_uuid=case_uuid, diff --git a/backend_py/primary/primary/routers/explore/schemas.py b/backend_py/primary/primary/routers/explore/schemas.py new file mode 100644 index 000000000..4d61c069e --- /dev/null +++ b/backend_py/primary/primary/routers/explore/schemas.py @@ -0,0 +1,27 @@ +from typing import List, Sequence + +from pydantic import BaseModel + + +class FieldInfo(BaseModel): + field_identifier: str + + +class CaseInfo(BaseModel): + uuid: str + name: str + status: str + user: str + + +class EnsembleInfo(BaseModel): + name: str + realization_count: int + + +class EnsembleDetails(BaseModel): + name: str + field_identifier: str + case_name: str + case_uuid: str + realizations: Sequence[int] \ No newline at end of file diff --git a/backend_py/primary/tests/integration/conftest.py b/backend_py/primary/tests/integration/conftest.py index ca58bd6ba..73041ca45 100644 --- a/backend_py/primary/tests/integration/conftest.py +++ b/backend_py/primary/tests/integration/conftest.py @@ -15,13 +15,20 @@ @dataclass class SumoTestEnsemble: + field_identifier: str case_uuid: str + case_name: str ensemble_name: str @pytest.fixture(name="sumo_test_ensemble_prod", scope="session") def fixture_sumo_test_ensemble_prod() -> SumoTestEnsemble: - return SumoTestEnsemble(case_uuid="485041ce-ad72-48a3-ac8c-484c0ed95cf8", ensemble_name="iter-0") + return SumoTestEnsemble( + field_identifier="DROGON", + case_name="webviz_ahm_case", + case_uuid="485041ce-ad72-48a3-ac8c-484c0ed95cf8", + ensemble_name="iter-0", + ) @pytest.fixture(name="test_user", scope="session") diff --git a/backend_py/primary/tests/integration/routers/explore/test_explore.py b/backend_py/primary/tests/integration/routers/explore/test_explore.py index 2a3a27e00..d42caae83 100644 --- a/backend_py/primary/tests/integration/routers/explore/test_explore.py +++ b/backend_py/primary/tests/integration/routers/explore/test_explore.py @@ -1,9 +1,31 @@ -from primary.routers.explore import router, FieldInfo, CaseInfo,EnsembleInfo, EnsembleDetails +from primary.routers.explore import router +from primary.routers.explore import schemas +async def test_get_fields(test_user, sumo_test_ensemble_prod) -> None: + fields = await router.get_fields(test_user) + assert all(isinstance(f, schemas.FieldInfo) for f in fields) + assert any(f.field_identifier == sumo_test_ensemble_prod.field_identifier for f in fields) + +async def test_get_cases(test_user, sumo_test_ensemble_prod) -> None: + cases = await router.get_cases(test_user, sumo_test_ensemble_prod.field_identifier) + assert all(isinstance(c, schemas.CaseInfo) for c in cases) + assert any(c.uuid == sumo_test_ensemble_prod.case_uuid for c in cases) -async def test_get_fields(test_user) -> None: - print(dir(router)) - field_list = await router.get_fields(test_user) - assert isinstance(field_list, list[FieldInfo]) +async def test_get_ensembles(test_user, sumo_test_ensemble_prod) -> None: + ensembles = await router.get_ensembles(test_user, sumo_test_ensemble_prod.case_uuid) + assert all(isinstance(e, schemas.EnsembleInfo) for e in ensembles) + assert any(e.name == sumo_test_ensemble_prod.ensemble_name for e in ensembles) + + +async def test_get_ensemble_details(test_user, sumo_test_ensemble_prod) -> None: + ensemble_details = await router.get_ensemble_details( + test_user, sumo_test_ensemble_prod.case_uuid, sumo_test_ensemble_prod.ensemble_name + ) + assert isinstance(ensemble_details, schemas.EnsembleDetails) + assert ensemble_details.name == sumo_test_ensemble_prod.ensemble_name + assert ensemble_details.field_identifier == sumo_test_ensemble_prod.field_identifier + assert ensemble_details.case_uuid == sumo_test_ensemble_prod.case_uuid + assert ensemble_details.case_name == sumo_test_ensemble_prod.case_name + assert len(ensemble_details.realizations) == 100