Skip to content

Commit

Permalink
feat(coverage): add class to call coverage api
Browse files Browse the repository at this point in the history
Changelog
* Rename BaseHttpClient in ApiRegionClient
* Fix session definition in ApiBaseClient
* Update tests
* Add class CoverageApiClient to call coverage api from navitia
* Add tests for the latest class
  • Loading branch information
jonperron committed Apr 6, 2024
1 parent 77386e2 commit 83df58b
Show file tree
Hide file tree
Showing 16 changed files with 239 additions and 43 deletions.
Empty file.
11 changes: 11 additions & 0 deletions navitia_client/client/apis/api_base_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from requests import Response, Session # type: ignore


class ApiBaseClient:
def __init__(self, auth_token: str, base_navitia_url: str) -> None:
self.base_navitia_url: str = base_navitia_url
self.session = Session()
self.session.headers.update({"Authorization": auth_token})

def get_navitia_api(self, endpoint: str) -> Response:
return self.session.get(endpoint)
64 changes: 64 additions & 0 deletions navitia_client/client/apis/coverage_apis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from datetime import datetime
from typing import Any, Sequence

from navitia_client.client.apis.api_base_client import ApiBaseClient
from navitia_client.entities.administrative_region import Region


class CoverageApiClient(ApiBaseClient):
@staticmethod
def _get_regions_from_response(raw_regions_response: Any) -> Sequence[Region]:
regions = []
for region in raw_regions_response:
regions.append(
Region(
id=region.get("id"),
name=region.get("name"),
dataset_created_at=datetime.fromisoformat(
region.get("dataset_created_at")
)
if region.get("dataset_created_at")
else None,
end_production_date=datetime.strptime(
region.get("end_production_date"), "%Y%m%d"
)
if region.get("end_production_date")
else None,
last_load_at=datetime.fromisoformat(region.get("last_load_at"))
if region.get("last_load_at")
else None,
shape=region.get("shape"),
start_production_date=datetime.strptime(
region.get("start_production_date"), "%Y%m%d"
)
if region.get("end_production_date")
else None,
status=region.get("status"),
)
)
return regions

def list_covered_areas(self) -> Sequence[Region]:
results = self.get_navitia_api(f"{self.base_navitia_url}/coverage")
result_regions = results.json()["regions"]
regions = CoverageApiClient._get_regions_from_response(result_regions)
return regions

def get_region_by_id(self, region_id: str) -> Sequence[Region]:
results = self.get_navitia_api(f"{self.base_navitia_url}/coverage/{region_id}")
result_regions = results.json()["regions"]
regions = CoverageApiClient._get_regions_from_response(result_regions)

return regions

def get_region_by_coordinates(self, lon: float, lat: float) -> Sequence[Region]:
results = self.get_navitia_api(
f"{self.base_navitia_url}/coverage/"
+ "{"
+ "{0};{1}".format(lon, lat)
+ "}"
)
result_regions = results.json()["regions"]
regions = CoverageApiClient._get_regions_from_response(result_regions)

return regions
11 changes: 0 additions & 11 deletions navitia_client/client/base_client.py

This file was deleted.

16 changes: 16 additions & 0 deletions navitia_client/client/navitia_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from navitia_client.client.apis.coverage_apis import CoverageApiClient

BASE_NAVITIA_URL: str = "https://api.navitia.io/v1/"


@dataclass
class NavitiaClient:
auth_token: str
base_navitia_url: str = BASE_NAVITIA_URL

@property
def coverage(self) -> CoverageApiClient:
return CoverageApiClient(
auth_token=self.auth_token, base_navitia_url=self.base_navitia_url
)
3 changes: 1 addition & 2 deletions navitia_client/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from .physical_mode import PhysicalModeId, CommercialMode, PhysicalMode
from .place import Place, PlaceEmbeddedType
from .pt_object import PtObject, PtObjectEmbeddedType
from .line import Line
from .route import Route
from .line_and_route import Line, Route
from .disruption import DisruptionStatus, Disruption
from .context import Context
from .pt_datetime import PTDatetime
Expand Down
12 changes: 12 additions & 0 deletions navitia_client/entities/administrative_region.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from .base_entity import BaseEntity
from .coord import Coord


@dataclass
class Region(BaseEntity):
dataset_created_at: Optional[datetime]
end_production_date: Optional[datetime]
last_load_at: Optional[datetime]
shape: Optional[str]
start_production_date: Optional[datetime]
status: Optional[str]


@dataclass
class AdministrativeRegion(BaseEntity):
label: str
Expand Down
2 changes: 1 addition & 1 deletion navitia_client/entities/disruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Sequence

from .pt_object import PtObject
from .route import Route
from .line_and_route import Route
from .stop_area import StopPoint


Expand Down
2 changes: 1 addition & 1 deletion navitia_client/entities/equipment_reports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from .line import Line
from .line_and_route import Line
from .equipment import StopAreaEquipments


Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass

from typing import Sequence

from .base_entity import BaseEntity
from .physical_mode import CommercialMode, PhysicalMode
from .route import Route
from .place import Place


@dataclass
Expand All @@ -12,6 +13,13 @@ class Line(BaseEntity):
color: str
opening_time: str
closing_time: str
routes: Sequence[Route]
routes: Sequence["Route"]
commercial_mode: CommercialMode
physical_modes: Sequence[PhysicalMode]


@dataclass
class Route(BaseEntity):
is_frequence: bool
line: Line
direction: Place
3 changes: 1 addition & 2 deletions navitia_client/entities/pt_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from .base_entity import BaseEntity
from .physical_mode import CommercialMode
from .network import Network
from .line import Line
from .route import Route
from .line_and_route import Line, Route
from .stop_area import StopArea, StopPoint
from .trip import Trip

Expand Down
12 changes: 0 additions & 12 deletions navitia_client/entities/route.py

This file was deleted.

Empty file added tests/client/apis/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions tests/client/apis/test_api_base_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from navitia_client.client.apis.api_base_client import ApiBaseClient


def test_http_base_client() -> None:
# Given
auth_token = "foobar"

# When
client = ApiBaseClient(auth_token=auth_token)

# Then
assert isinstance(client, ApiBaseClient)
110 changes: 110 additions & 0 deletions tests/client/apis/test_coverage_apis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest

from navitia_client.client.apis.coverage_apis import CoverageApiClient
from navitia_client.entities.administrative_region import Region


@pytest.fixture
def coverage_apis():
return CoverageApiClient(auth_token="foobar")


@patch.object(CoverageApiClient, "get_navitia_api")
def test_list_covered_areas(
mock_get_navitia_api: MagicMock, coverage_apis: CoverageApiClient
) -> None:
# Given
mock_response = MagicMock()
mock_response.json.return_value = {
"regions": [
{
"id": "region1",
"name": "Region 1",
"dataset_created_at": "2022-01-01T00:00:00",
"end_production_date": "20221231",
"last_load_at": "2022-01-01T00:00:00",
"shape": "shape_data",
"start_production_date": "20220101",
"status": "active",
}
]
}
mock_get_navitia_api.return_value = mock_response

# When
regions = coverage_apis.list_covered_areas()

# Then
assert len(regions) == 1
assert isinstance(regions[0], Region)
assert regions[0].id == "region1"
assert regions[0].name == "Region 1"
assert regions[0].dataset_created_at == datetime(2022, 1, 1, 0, 0)
assert regions[0].end_production_date == datetime(2022, 12, 31, 0, 0)
assert regions[0].last_load_at == datetime(2022, 1, 1, 0, 0)
assert regions[0].shape == "shape_data"
assert regions[0].start_production_date == datetime(2022, 1, 1, 0, 0)
assert regions[0].status == "active"


@patch.object(CoverageApiClient, "get_navitia_api")
def test_get_region_by_id(
mock_get_navitia_api: MagicMock, coverage_apis: CoverageApiClient
) -> None:
# Given
mock_response = MagicMock()
mock_response.json.return_value = {
"regions": [
{
"id": "region1",
"name": "Region 1",
"dataset_created_at": "2022-01-01T00:00:00",
"end_production_date": "20221231",
"last_load_at": "2022-01-01T00:00:00",
"shape": "shape_data",
"start_production_date": "20220101",
"status": "active",
}
]
}
mock_get_navitia_api.return_value = mock_response

# When
regions = coverage_apis.get_region_by_id("12")

# Then
assert len(regions) == 1
assert isinstance(regions[0], Region)


@patch.object(CoverageApiClient, "get_navitia_api")
def test_get_region_by_coordinates(
mock_get_navitia_api: MagicMock, coverage_apis: CoverageApiClient
) -> None:
# Given
mock_response = MagicMock()
mock_response.json.return_value = {
"regions": [
{
"id": "region1",
"name": "Region 1",
"dataset_created_at": "2022-01-01T00:00:00",
"end_production_date": "20221231",
"last_load_at": "2022-01-01T00:00:00",
"shape": "shape_data",
"start_production_date": "20220101",
"status": "active",
}
]
}
mock_get_navitia_api.return_value = mock_response

# When
regions = coverage_apis.get_region_by_coordinates(12.5, 13.2)

# Then
assert len(regions) == 1
assert isinstance(regions[0], Region)
12 changes: 0 additions & 12 deletions tests/client/test_base_client.py

This file was deleted.

0 comments on commit 83df58b

Please sign in to comment.