diff --git a/.env.template b/.env.template
index 78735791..a801e445 100644
--- a/.env.template
+++ b/.env.template
@@ -3,6 +3,7 @@
 ###############################################################################################
 # these values are used in the local docker env. You can use "localhost" hostname
 # if you run the application without docker
+POSTGRES_DRIVER=postgresql
 POSTGRES_HOSTNAME=postgres_bloom
 POSTGRES_USER=bloom_user
 POSTGRES_PASSWORD=bloom
diff --git a/backend/bloom/config.py b/backend/bloom/config.py
index 1b875151..9ae8d4b8 100644
--- a/backend/bloom/config.py
+++ b/backend/bloom/config.py
@@ -40,6 +40,7 @@ class Settings(BaseSettings):
                              default=5432)
     postgres_db:str = Field(min_length=1,max_length=32,pattern=r'^(?:[a-zA-Z]|_)[\w\d_]*$')
 
+    postgres_schema:str = Field(default='public')
     srid: int = Field(default=4326)
     spire_token:str = Field(default='')
     data_folder:str=Field(default=str(Path(__file__).parent.parent.parent.joinpath('./data')))
diff --git a/backend/bloom/container.py b/backend/bloom/container.py
index 87b5c359..d5e77c90 100644
--- a/backend/bloom/container.py
+++ b/backend/bloom/container.py
@@ -13,6 +13,7 @@ from bloom.services.metrics import MetricsService
 from bloom.usecase.GenerateAlerts import GenerateAlerts
 
 from dependency_injector import containers, providers
+import redis
 
 
 class UseCases(containers.DeclarativeContainer):
@@ -25,7 +26,6 @@ class UseCases(containers.DeclarativeContainer):
 
     vessel_repository = providers.Factory(
         VesselRepository,
-        session_factory=db.provided.session,
    )
 
    alert_repository = providers.Factory(
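With `session_factory` removed from the provider, repositories are now built around an externally owned `Session`, so every call site opens the session first and passes it in. A sketch of the new calling convention, mirroring the routers further down (`db.session()` is assumed to yield a context-managed Session):

```python
use_cases = UseCases()
db = use_cases.db()
with db.session() as session:                                  # caller owns the transaction
    vessel_repository = use_cases.vessel_repository(session)   # Session injected per use
    vessels = vessel_repository.list()
    session.commit()                                           # repositories only flush
```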
diff --git a/backend/bloom/infra/database/sql_model.py b/backend/bloom/infra/database/sql_model.py
index 57a3f855..2f8dab2b 100644
--- a/backend/bloom/infra/database/sql_model.py
+++ b/backend/bloom/infra/database/sql_model.py
@@ -21,6 +21,7 @@ class Vessel(Base):
     __tablename__ = "dim_vessel"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     mmsi = Column("mmsi", Integer)
     ship_name = Column("ship_name", String, nullable=False)
@@ -46,6 +47,7 @@ class Vessel(Base):
 
 class Alert(Base):
     __tablename__ = "alert"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True, index=True)
     timestamp = Column("timestamp", DateTime)
     mpa_id = Column("mpa_id", Integer)
@@ -54,6 +56,7 @@ class Alert(Base):
 
 class Port(Base):
     __tablename__ = "dim_port"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True, index=True)
     name = Column("name", String, nullable=False)
     locode = Column("locode", String, nullable=False)
@@ -70,6 +73,7 @@ class Port(Base):
 
 class SpireAisData(Base):
     __tablename__ = "spire_ais_data"
+    __table_args__ = {'schema': settings.postgres_schema}
 
     id = Column("id", Integer, primary_key=True)
     spire_update_statement = Column("spire_update_statement", DateTime(timezone=True))
@@ -107,6 +111,7 @@ class SpireAisData(Base):
 
 class Zone(Base):
     __tablename__ = "dim_zone"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     category = Column("category", String, nullable=False)
     sub_category = Column("sub_category", String)
@@ -119,6 +124,7 @@ class Zone(Base):
 
 class WhiteZone(Base):
     __tablename__ = "dim_white_zone"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     geometry = Column("geometry", Geometry(geometry_type="GEOMETRY", srid=settings.srid))
     created_at = Column("created_at", DateTime(timezone=True), server_default=func.now())
@@ -127,6 +133,7 @@ class WhiteZone(Base):
 
 class VesselPosition(Base):
     __tablename__ = "vessel_positions"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     timestamp = Column("timestamp", DateTime(timezone=True), nullable=False)
     accuracy = Column("accuracy", String)
@@ -146,6 +153,7 @@ class VesselPosition(Base):
 
 class VesselData(Base):
     __tablename__ = "vessel_data"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     timestamp = Column("timestamp", DateTime(timezone=True), nullable=False)
     ais_class = Column("ais_class", String)
@@ -164,6 +172,7 @@ class VesselData(Base):
 
 class VesselVoyage(Base):
     __tablename__ = "vessel_voyage"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     timestamp = Column("timestamp", DateTime(timezone=True), nullable=False)
     destination = Column("destination", String)
@@ -175,6 +184,7 @@ class VesselVoyage(Base):
 
 class Excursion(Base):
     __tablename__ = "fct_excursion"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     vessel_id = Column("vessel_id", Integer, ForeignKey("dim_vessel.id"), nullable=False)
     departure_port_id = Column("departure_port_id", Integer, ForeignKey("dim_port.id"))
@@ -199,6 +209,7 @@ class Excursion(Base):
 
 class Segment(Base):
     __tablename__ = "fct_segment"
+    __table_args__ = {'schema': settings.postgres_schema}
     id = Column("id", Integer, primary_key=True)
     excursion_id = Column("excursion_id", Integer, ForeignKey("fct_excursion.id"), nullable=False)
     timestamp_start = Column("timestamp_start", DateTime(timezone=True))
@@ -224,6 +235,7 @@ class Segment(Base):
 
 class TaskExecution(Base):
     __tablename__ = "tasks_executions"
+    __table_args__ = {'schema': settings.postgres_schema}
     task_name = Column("task_name", String, primary_key=True)
     point_in_time = Column("point_in_time", DateTime(timezone=True))
     created_at = Column("created_at", DateTime(timezone=True), server_default=func.now())
@@ -234,6 +246,7 @@ class RelSegmentZone(Base):
     __tablename__ = "rel_segment_zone"
     __table_args__ = (
         PrimaryKeyConstraint('segment_id', 'zone_id'),
+        {'schema': settings.postgres_schema}
     )
     segment_id = Column("segment_id", Integer, ForeignKey("fct_segment.id"), nullable=False)
     zone_id = Column("zone_id", Integer, ForeignKey("dim_zone.id"), nullable=False)
@@ -253,6 +266,7 @@ class RelSegmentZone(Base):
 
 class MetricsVesselInActivity(Base):
     __table__ = vessel_in_activity_request
+    __table_args__ = {'schema': settings.postgres_schema}
     #vessel_id: Mapped[Optional[int]]
     #total_time_at_sea: Mapped[Optional[timedelta]]
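One caveat with per-model schemas: `__table_args__ = {'schema': ...}` registers each table in the MetaData under a qualified key such as `public.dim_vessel`, while unqualified string targets like `ForeignKey("dim_vessel.id")` fall back to `MetaData.schema` rather than the referencing table's own schema during resolution, so they may fail to resolve once a non-default schema is configured. If that bites, the reference can carry the same setting; a hedged sketch, not applied above:

```python
from bloom.config import settings
from sqlalchemy import Column, ForeignKey, Integer

# hypothetical fix: qualify the referenced table with the shared schema setting
vessel_id = Column(
    "vessel_id",
    Integer,
    ForeignKey(f"{settings.postgres_schema}.dim_vessel.id"),
    nullable=False,
)
```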
diff --git a/backend/bloom/infra/repositories/repository_port.py b/backend/bloom/infra/repositories/repository_port.py
index 264dc6e1..cc6a2458 100644
--- a/backend/bloom/infra/repositories/repository_port.py
+++ b/backend/bloom/infra/repositories/repository_port.py
@@ -12,23 +12,16 @@
 from sqlalchemy import func, or_, and_, select, update, asc, text
 from sqlalchemy.orm import Session
 
+from bloom.infra.repository import GenericRepository, GenericSqlRepository
+from abc import ABC, abstractmethod
 
-class PortRepository:
-    def __init__(self, session_factory: Callable) -> None:
-        self.session_factory = session_factory
-
-    def get_port_by_id(self, session: Session, port_id: int) -> Union[Port, None]:
-        entity = session.get(sql_model.Port, port_id)
-        if entity is not None:
-            return PortRepository.map_to_domain(entity)
-        else:
-            return None
+class PortRepositoryBase(GenericRepository[Port], ABC):
+    @abstractmethod
+    def get_empty_geometry_buffer_ports(self, session: Session) -> list[Port]:
+        raise NotImplementedError()
 
-    def get_all_ports(self, session: Session) -> List[Port]:
-        q = session.query(sql_model.Port)
-        if not q:
-            return []
-        return [PortRepository.map_to_domain(entity) for entity in q]
+
+class PortRepository(GenericSqlRepository[Port, sql_model.Port], PortRepositoryBase):
+    def __init__(self, session: Session) -> None:
+        super().__init__(session=session, model_cls=sql_model.Port, schema_cls=Port)
 
     def get_empty_geometry_buffer_ports(self, session: Session) -> list[Port]:
         stmt = select(sql_model.Port).where(sql_model.Port.geometry_buffer.is_(None))
@@ -36,101 +29,126 @@ def get_empty_geometry_buffer_ports(self, session: Session) -> list[Port]:
         if not q:
             return []
         return [PortRepository.map_to_domain(entity) for entity in q]
-
-    def get_ports_updated_created_after(self, session: Session, created_updated_after: datetime) -> list[Port]:
-        stmt = select(sql_model.Port).where(or_(sql_model.Port.created_at >= created_updated_after,
-                                                sql_model.Port.updated_at >= created_updated_after))
-        q = session.execute(stmt).scalars()
-        if not q:
-            return []
-        return [PortRepository.map_to_domain(entity) for entity in q]
-
-    def update_geometry_buffer(self, session: Session, port_id: int, buffer: Polygon) -> None:
-        session.execute(update(sql_model.Port), [{"id": port_id, "geometry_buffer": from_shape(buffer)}])
-
-    def batch_update_geometry_buffer(self, session: Session, id_buffers: list[dict[str, Any]]) -> None:
-        items = [{"id": item["id"], "geometry_buffer": from_shape(item["geometry_buffer"])} for item in id_buffers]
-        session.execute(update(sql_model.Port), items)
-
-    def create_port(self, session: Session, port: Port) -> Port:
-        orm_port = PortRepository.map_to_sql(port)
-        session.add(orm_port)
-        return PortRepository.map_to_domain(orm_port)
-
-    def batch_create_port(self, session: Session, ports: list[Port]) -> list[Port]:
-        orm_list = [PortRepository.map_to_sql(port) for port in ports]
-        session.add_all(orm_list)
-        return [PortRepository.map_to_domain(orm) for orm in orm_list]
-
-    def find_port_by_position_in_port_buffer(self, session: Session, position: Point) -> Union[Port, None]:
-        stmt = select(sql_model.Port).where(
-            func.ST_contains(sql_model.Port.geometry_buffer, from_shape(position, srid=settings.srid)) == True)
-        port = session.execute(stmt).scalar()
-        if not port:
-            return None
-        else:
-            return PortRepository.map_to_domain(port)
-
-    def find_port_by_distance(self,
-                              session: Session,
-                              longitude: float,
-                              latitude: float,
-                              threshold_distance_to_port: float) -> Union[Port, None]:
-        position = Point(longitude, latitude)
-        stmt = select(sql_model.Port).where(
-            and_(
-                func.ST_within(from_shape(position, srid=settings.srid),
-                               sql_model.Port.geometry_buffer) == True,
-                func.ST_distance(from_shape(position, srid=settings.srid),
-                                 sql_model.Port.geometry_point) < threshold_distance_to_port
-            )
-        ).order_by(asc(func.ST_distance(from_shape(position, srid=settings.srid),
-                                        sql_model.Port.geometry_point)))
-        result = session.execute(stmt).scalars()
-        return [PortRepository.map_to_domain(e) for e in result]
-
-    def get_closest_port_in_range(self, session: Session, longitude: float, latitude: float, range: float) -> Union[
-        tuple[int, float], None]:
-        res = session.execute(text("""SELECT id,ST_Distance(ST_POINT(:longitude,:latitude, 4326)::geography, geometry_point::geography)
-            FROM dim_port WHERE ST_Within(ST_POINT(:longitude,:latitude, 4326),geometry_buffer) = true
-            AND ST_Distance(ST_POINT(:longitude,:latitude, 4326)::geography, geometry_point::geography) < :range
-            ORDER by ST_Distance(ST_POINT(:longitude,:latitude, 4326)::geography, geometry_point::geography) ASC LIMIT 1"""),
-                              {"longitude": longitude, "latitude": latitude, "range": range}).first()
-        return res
-
-    @staticmethod
-    def map_to_domain(orm_port: sql_model.Port) -> Port:
-        return Port(
-            id=orm_port.id,
-            name=orm_port.name,
-            locode=orm_port.locode,
-            url=orm_port.url,
-            country_iso3=orm_port.country_iso3,
-            latitude=orm_port.latitude,
-            longitude=orm_port.longitude,
-            geometry_point=to_shape(orm_port.geometry_point),
-            geometry_buffer=to_shape(orm_port.geometry_buffer)
-            if orm_port.geometry_buffer is not None
-            else None,
-            has_excursion=orm_port.has_excursion,
-            created_at=orm_port.created_at,
-            updated_at=orm_port.updated_at,
-        )
-
-    @staticmethod
-    def map_to_sql(port: Port) -> sql_model.Port:
-        return sql_model.Port(
-            name=port.name,
-            locode=port.locode,
-            url=port.url,
-            country_iso3=port.country_iso3,
-            latitude=port.latitude,
-            longitude=port.longitude,
-            geometry_point=from_shape(port.geometry_point),
-            geometry_buffer=from_shape(port.geometry_buffer)
-            if port.geometry_buffer is not None
-            else None,
-            has_excursion=port.has_excursion,
-            created_at=port.created_at,
-            updated_at=port.updated_at,
-        )
+
+    @staticmethod
+    def map_to_domain(orm_port: sql_model.Port) -> Port:
+        return Port(
+            id=orm_port.id,
+            name=orm_port.name,
+            locode=orm_port.locode,
+            url=orm_port.url,
+            country_iso3=orm_port.country_iso3,
+            latitude=orm_port.latitude,
+            longitude=orm_port.longitude,
+            geometry_point=to_shape(orm_port.geometry_point),
+            geometry_buffer=to_shape(orm_port.geometry_buffer)
+            if orm_port.geometry_buffer is not None
+            else None,
+            has_excursion=orm_port.has_excursion,
+            created_at=orm_port.created_at,
+            updated_at=orm_port.updated_at,
+        )
+
+    @staticmethod
+    def map_to_model(port: Port) -> sql_model.Port:
+        return sql_model.Port(
+            name=port.name,
+            locode=port.locode,
+            url=port.url,
+            country_iso3=port.country_iso3,
+            latitude=port.latitude,
+            longitude=port.longitude,
+            geometry_point=from_shape(port.geometry_point),
+            geometry_buffer=from_shape(port.geometry_buffer)
+            if port.geometry_buffer is not None
+            else None,
+            has_excursion=port.has_excursion,
+            created_at=port.created_at,
+            updated_at=port.updated_at,
+        )
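The CRUD plumbing now lives in the generic base (added in `backend/bloom/infra/repository.py` below), leaving `PortRepository` with port-specific queries plus the two mapping hooks. Callers switch to the generic API; a short usage sketch, assuming `db.session()` as in the routers:

```python
with db.session() as session:
    ports = PortRepository(session)
    all_ports = ports.list()                   # unfiltered listing
    fr_ports = ports.list(country_iso3="FRA")  # keyword filters become WHERE clauses
    one = ports.get_by_id(42)                  # None when the id is unknown
```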
diff --git a/backend/bloom/infra/repositories/repository_vessel.py b/backend/bloom/infra/repositories/repository_vessel.py
index c82b5c37..256631c1 100644
--- a/backend/bloom/infra/repositories/repository_vessel.py
+++ b/backend/bloom/infra/repositories/repository_vessel.py
@@ -1,5 +1,5 @@
 from contextlib import AbstractContextManager
-from typing import Any, Generator, Union
+from typing import Union
 
 from bloom.domain.vessel import Vessel
 from bloom.infra.database import sql_model
@@ -7,7 +7,68 @@
 from sqlalchemy import func, select, update, and_
 from sqlalchemy.orm import Session
 
+from bloom.infra.repository import GenericRepository, GenericSqlRepository
+from abc import ABC, abstractmethod
+
+
+class VesselRepositoryBase(GenericRepository[Vessel], ABC):
+    @abstractmethod
+    def set_tracking(self, vessel_ids: list[int], tracking_activated: bool,
+                     tracking_status: str) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def check_mmsi_integrity(self) -> list[tuple[int, int]]:
+        raise NotImplementedError()
+
+
+class VesselRepository(GenericSqlRepository[Vessel, sql_model.Vessel], VesselRepositoryBase):
+    def __init__(self, session: Session) -> None:
+        super().__init__(session=session, model_cls=sql_model.Vessel, schema_cls=Vessel)
+
+    def set_tracking(self, vessel_ids: list[int], tracking_activated: bool,
+                     tracking_status: str) -> None:
+        updates = [{"id": id, "tracking_activated": tracking_activated, "tracking_status": tracking_status} for id in
+                   vessel_ids]
+        self._session.execute(update(sql_model.Vessel), updates)
+
+    def check_mmsi_integrity(self) -> list[tuple[int, int]]:
+        # Find distinct MMSI values that have more than one active record
+        stmt = select(sql_model.Vessel.mmsi, func.count(sql_model.Vessel.id).label("count")).group_by(
+            sql_model.Vessel.mmsi).having(
+            func.count(sql_model.Vessel.id) > 1).where(
+            sql_model.Vessel.tracking_activated == True)
+        return self._session.execute(stmt).all()
+
+    def map_to_domain(self, model: sql_model.Vessel) -> Vessel:
+        return Vessel(
+            id=model.id,
+            mmsi=model.mmsi,
+            ship_name=model.ship_name,
+            width=model.width,
+            length=model.length,
+            country_iso3=model.country_iso3,
+            type=model.type,
+            imo=model.imo,
+            cfr=model.cfr,
+            external_marking=model.external_marking,
+            ircs=model.ircs,
+            tracking_activated=model.tracking_activated,
+            tracking_status=model.tracking_status,
+            home_port_id=model.home_port_id,
+            created_at=model.created_at,
+            updated_at=model.updated_at,
+            details=model.details,
+            check=model.check,
+            length_class=model.length_class,
+        )
+
+    def map_to_model(self, schema: Vessel) -> sql_model.Vessel:
+        return sql_model.Vessel(**schema.__dict__)
+
+
+"""
 class VesselRepository:
     def __init__(
         self,
@@ -32,9 +93,9 @@ def get_activated_vessel_by_mmsi(self, session: Session, mmsi: int) -> Union[Ves
         return VesselRepository.map_to_domain(vessel)
 
     def get_vessels_list(self, session: Session) -> list[Vessel]:
-        """
+        """"""
         Lists all active vessels
-        """
+        """"""
         stmt = select(sql_model.Vessel).where(sql_model.Vessel.tracking_activated == True)
         e = session.execute(stmt).scalars()
         if not e:
@@ -42,9 +103,9 @@ def get_vessels_list(self, session: Session) -> list[Vessel]:
         return [VesselRepository.map_to_domain(vessel) for vessel in e]
 
     def get_all_vessels_list(self, session: Session) -> list[Vessel]:
-        """
+        """"""
         Lists all vessels, active or inactive
-        """
+        """"""
         stmt = select(sql_model.Vessel)
 
         e = session.execute(stmt).scalars()
@@ -81,27 +142,27 @@ def check_mmsi_integrity(self, session: Session) -> list[(int, int)]:
         return session.execute(stmt).all()
 
     @staticmethod
-    def map_to_domain(sql_vessel: sql_model.Vessel) -> Vessel:
+    def map_to_domain(model: sql_model.Vessel) -> Vessel:
         return Vessel(
-            id=sql_vessel.id,
-            mmsi=sql_vessel.mmsi,
-            ship_name=sql_vessel.ship_name,
-            width=sql_vessel.width,
-            length=sql_vessel.length,
-            country_iso3=sql_vessel.country_iso3,
-            type=sql_vessel.type,
-            imo=sql_vessel.imo,
-            cfr=sql_vessel.cfr,
-            external_marking=sql_vessel.external_marking,
-            ircs=sql_vessel.ircs,
-            tracking_activated=sql_vessel.tracking_activated,
-            tracking_status=sql_vessel.tracking_status,
-            home_port_id=sql_vessel.home_port_id,
-            created_at=sql_vessel.created_at,
-            updated_at=sql_vessel.updated_at,
-            details=sql_vessel.details,
-            check=sql_vessel.check,
-            length_class=sql_vessel.length_class,
+            id=model.id,
+            mmsi=model.mmsi,
+            ship_name=model.ship_name,
+            width=model.width,
+            length=model.length,
+            country_iso3=model.country_iso3,
+            type=model.type,
+            imo=model.imo,
+            cfr=model.cfr,
+            external_marking=model.external_marking,
+            ircs=model.ircs,
+            tracking_activated=model.tracking_activated,
+            tracking_status=model.tracking_status,
+            home_port_id=model.home_port_id,
+            created_at=model.created_at,
+            updated_at=model.updated_at,
+            details=model.details,
+            check=model.check,
+            length_class=model.length_class,
         )
 
     @staticmethod
@@ -127,3 +188,4 @@ def map_to_sql(vessel: Vessel) -> sql_model.Vessel:
             check=vessel.check,
             length_class=vessel.length_class,
         )
+    """
\ No newline at end of file
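`set_tracking` keeps SQLAlchemy's bulk-UPDATE-by-primary-key form: `update(Entity)` executed with a list of parameter dicts that each include the primary key. A self-contained sketch of the same pattern (hypothetical `Item` model, in-memory SQLite):

```python
from sqlalchemy import Column, Integer, String, create_engine, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):  # hypothetical stand-in for sql_model.Vessel
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    status = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Item(id=1, status="old"), Item(id=2, status="old")])
    session.flush()
    # one executemany-style UPDATE; every dict must carry the primary key
    session.execute(update(Item), [{"id": 1, "status": "new"},
                                   {"id": 2, "status": "new"}])
    session.commit()
```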
diff --git a/backend/bloom/infra/repository.py b/backend/bloom/infra/repository.py
new file mode 100644
index 00000000..b47f7778
--- /dev/null
+++ b/backend/bloom/infra/repository.py
@@ -0,0 +1,124 @@
+from typing import TypeVar, Type, Generic, Optional, List, Any
+from abc import ABC, abstractmethod
+from sqlalchemy import select
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+from sqlalchemy.sql.expression import Select, and_
+
+SCHEMA = TypeVar("SCHEMA", bound=BaseModel)
+MODEL = TypeVar("MODEL", bound=Any)
+
+
+class GenericRepository(Generic[SCHEMA], ABC):
+
+    @abstractmethod
+    def get_by_id(self, id: int) -> Optional[SCHEMA]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def list(self, **filters) -> List[SCHEMA]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def add(self, record: SCHEMA) -> SCHEMA:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def batch_add(self, records: List[SCHEMA]) -> List[SCHEMA]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def update(self, record: SCHEMA) -> SCHEMA:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def batch_update(self, records: List[SCHEMA]) -> List[SCHEMA]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def delete(self, id: int) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def batch_delete(self, ids: List[int]) -> None:
+        raise NotImplementedError()
+
+
+class GenericSqlRepository(GenericRepository[SCHEMA], Generic[SCHEMA, MODEL], ABC):
+    def __init__(self,
+                 session: Session,
+                 model_cls: Type[MODEL],
+                 schema_cls: Type[SCHEMA]) -> None:
+        self._session = session
+        self._model_cls = model_cls
+        self._schema_cls = schema_cls
+
+    def _construct_get_stmt(self, id: int) -> Select:
+        stmt = select(self._model_cls).where(self._model_cls.id == id)
+        return stmt
+
+    def get_by_id(self, id: int) -> Optional[SCHEMA]:
+        stmt = self._construct_get_stmt(id)
+        model = self._session.execute(stmt).scalar_one_or_none()
+        return self.map_to_domain(model) if model is not None else None
+
+    def _construct_list_stmt(self, **filters) -> Select:
+        stmt = select(self._model_cls)
+        where_clauses = []
+        for c, v in filters.items():
+            if not hasattr(self._model_cls, c):
+                raise ValueError(f"Invalid column name {c}")
+            where_clauses.append(getattr(self._model_cls, c) == v)
+
+        if len(where_clauses) == 1:
+            stmt = stmt.where(where_clauses[0])
+        elif len(where_clauses) > 1:
+            stmt = stmt.where(and_(*where_clauses))
+        return stmt
+
+    def list(self, **filters) -> List[SCHEMA]:
+        stmt = self._construct_list_stmt(**filters)
+        return [self.map_to_domain(item) for item in self._session.execute(stmt).scalars()]
+
+    def add(self, record: SCHEMA) -> SCHEMA:
+        model = self.map_to_model(record)
+        self._session.add(model)
+        self._session.flush()
+        self._session.refresh(model)
+        return self.map_to_domain(model)
+
+    def batch_add(self, records: List[SCHEMA]) -> List[SCHEMA]:
+        models = [self.map_to_model(record) for record in records]
+        self._session.add_all(models)
+        self._session.flush()
+        for model in models:
+            self._session.refresh(model)
+        return [self.map_to_domain(model) for model in models]
+
+    def update(self, record: SCHEMA) -> SCHEMA:
+        model = self._session.merge(self.map_to_model(record))
+        self._session.flush()
+        return self.map_to_domain(model)
+
+    def batch_update(self, records: List[SCHEMA]) -> List[SCHEMA]:
+        models = [self._session.merge(self.map_to_model(record)) for record in records]
+        self._session.flush()
+        return [self.map_to_domain(model) for model in models]
+
+    def delete(self, id: int) -> None:
+        model = self._session.get(self._model_cls, id)
+        if model is not None:
+            self._session.delete(model)
+            self._session.flush()
+
+    def batch_delete(self, ids: List[int]) -> None:
+        for id in ids:
+            self.delete(id)
+
+    @abstractmethod
+    def map_to_domain(self, model: MODEL) -> SCHEMA:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def map_to_model(self, schema: SCHEMA) -> MODEL:
+        raise NotImplementedError()
\ No newline at end of file
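The base class flushes but never commits, so generated ids and server defaults are visible immediately while transaction control stays with the caller (`load_dim_vessel_from_csv.py` below commits only after its integrity check). A runnable sketch of wiring a repository onto the base, with a hypothetical `Item` domain/model pair and in-memory SQLite:

```python
from pydantic import BaseModel
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

from bloom.infra.repository import GenericSqlRepository

Base = declarative_base()

class ItemModel(Base):                      # hypothetical ORM model
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)

class Item(BaseModel):                      # hypothetical domain schema
    id: int | None = None
    name: str

class ItemRepository(GenericSqlRepository[Item, ItemModel]):
    def __init__(self, session: Session) -> None:
        super().__init__(session=session, model_cls=ItemModel, schema_cls=Item)

    def map_to_domain(self, model: ItemModel) -> Item:
        return Item(id=model.id, name=model.name)

    def map_to_model(self, schema: Item) -> ItemModel:
        return ItemModel(id=schema.id, name=schema.name)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    repo = ItemRepository(session)
    created = repo.add(Item(name="buoy"))   # flushed, so the id is already set
    assert created.id is not None
    assert repo.list(name="buoy")
    session.commit()                        # durability is the caller's decision
```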
diff --git a/backend/bloom/routers/v1/ports.py b/backend/bloom/routers/v1/ports.py
index be0fe278..acb310ff 100644
--- a/backend/bloom/routers/v1/ports.py
+++ b/backend/bloom/routers/v1/ports.py
@@ -39,11 +39,11 @@ async def list_ports( request:Request,
             return payload
     else:
         use_cases = UseCases()
-        port_repository = use_cases.port_repository()
         db = use_cases.db()
         with db.session() as session:
+            port_repository = use_cases.port_repository(session)
             json_data = [json.loads(p.model_dump_json() if p else "{}")
-                         for p in port_repository.get_all_ports(session)]
+                         for p in port_repository.list()]
             rd.set(endpoint, json.dumps(json_data))
             rd.expire(endpoint,settings.redis_cache_expiration)
             logger.debug(f"{endpoint} elapsed Time: {time.time()-start}")
@@ -55,7 +55,7 @@ async def get_port(port_id:int,
                    key: str = Depends(X_API_KEY_HEADER)):
     check_apikey(key)
     use_cases = UseCases()
-    port_repository = use_cases.port_repository()
     db = use_cases.db()
     with db.session() as session:
-        return port_repository.get_port_by_id(session,port_id)
\ No newline at end of file
+        port_repository = use_cases.port_repository(session)
+        return port_repository.get_by_id(port_id)
\ No newline at end of file
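Since the cache write here is two round-trips (`set` then `expire`), there is a brief window where the key exists without a TTL. redis-py can set the expiry atomically in the same command, which a later cleanup could use:

```python
# equivalent to rd.set(...) followed by rd.expire(...), but a single atomic command
rd.set(endpoint, json.dumps(json_data), ex=settings.redis_cache_expiration)
```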
diff --git a/backend/bloom/routers/v1/vessels.py b/backend/bloom/routers/v1/vessels.py
index bf396c65..22cd07c5 100644
--- a/backend/bloom/routers/v1/vessels.py
+++ b/backend/bloom/routers/v1/vessels.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Request
 from redis import Redis
 from bloom.config import settings
 from bloom.container import UseCases
@@ -36,24 +36,23 @@ async def list_vessels(nocache:bool=False,key: str = Depends(X_API_KEY_HEADER)):
         return payload
     else:
         use_cases = UseCases()
-        vessel_repository = use_cases.vessel_repository()
         db = use_cases.db()
         with db.session() as session:
-
+            vessel_repository = use_cases.vessel_repository(session)
             json_data = [json.loads(v.model_dump_json() if v else "{}")
-                         for v in vessel_repository.get_vessels_list(session)]
+                         for v in vessel_repository.list(tracking_activated=True)]
             rd.set(endpoint, json.dumps(json_data))
             rd.expire(endpoint,settings.redis_cache_expiration)
             return json_data
 
 @router.get("/vessels/{vessel_id}")
-async def get_vessel(vessel_id: int,key: str = Depends(X_API_KEY_HEADER)):
+async def get_vessel(request: Request, vessel_id: int,key: str = Depends(X_API_KEY_HEADER)):
     check_apikey(key)
     use_cases = UseCases()
-    vessel_repository = use_cases.vessel_repository()
     db = use_cases.db()
     with db.session() as session:
-        return vessel_repository.get_vessel_by_id(session,vessel_id)
+        vessel_repository = use_cases.vessel_repository(session)
+        return vessel_repository.get_by_id(id=vessel_id)
 
 @router.get("/vessels/all/positions/last")
 async def list_all_vessel_last_position(nocache:bool=False,key: str = Depends(X_API_KEY_HEADER)):
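The old `get_vessels_list` returned only active vessels; the generic `list()` preserves that behavior through keyword filters, which `_construct_list_stmt` turns into equality predicates on the model's columns:

```python
active = vessel_repository.list(tracking_activated=True)  # WHERE tracking_activated = true
vessel_repository.list(no_such_column=1)                  # rejected up front with ValueError
```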
diff --git a/backend/bloom/services/geo.py b/backend/bloom/services/geo.py
index 81f3d674..976e19a2 100644
--- a/backend/bloom/services/geo.py
+++ b/backend/bloom/services/geo.py
@@ -29,10 +29,10 @@ def find_positions_in_port_buffer(vessel_positions: List[tuple]) -> List[tuple]:
 
     # Get all ports from DataBase
     use_cases = UseCases()
-    port_repository = use_cases.port_repository()
     db = use_cases.db()
     with db.session() as session:
-        ports = port_repository.get_all_ports(session)
+        port_repository = use_cases.port_repository(session)
+        ports = port_repository.list()
 
     df_ports = pd.DataFrame(
         [[p.id, p.name, p.geometry_buffer] for p in ports],
diff --git a/backend/bloom/tasks/load_dim_vessel_from_csv.py b/backend/bloom/tasks/load_dim_vessel_from_csv.py
index 1d0a1e96..abbce380 100644
--- a/backend/bloom/tasks/load_dim_vessel_from_csv.py
+++ b/backend/bloom/tasks/load_dim_vessel_from_csv.py
@@ -34,7 +34,6 @@ def map_to_domain(row: pd.Series) -> Vessel:
 
 def run(csv_file_name: str) -> None:
     use_cases = UseCases()
-    vessel_repository = use_cases.vessel_repository()
     db = use_cases.db()
 
     inserted_ports = []
@@ -43,11 +42,12 @@ def run(csv_file_name: str) -> None:
     df = pd.read_csv(csv_file_name, sep=",")
     vessels = df.apply(map_to_domain, axis=1)
     with db.session() as session:
+        vessel_repository = use_cases.vessel_repository(session)
         ports_inserts = []
         ports_updates = []
         # For each record in the CSV file
         for vessel in vessels:
-            if vessel.id and vessel_repository.get_vessel_by_id(session, vessel.id):
+            if vessel.id and vessel_repository.get_by_id(vessel.id):
                 # if the id field is not empty:
                 # look up the matching record in the dim_vessel table
                 # and update it from the CSV data.
@@ -57,20 +57,20 @@ def run(csv_file_name: str) -> None:
                 # insert the CSV data into the dim_vessel table;
                 ports_inserts.append(vessel)
         # Batch inserts / updates
-        inserted_ports = vessel_repository.batch_create_vessel(session, ports_inserts)
-        vessel_repository.batch_update_vessel(session, ports_updates)
+        inserted_ports = vessel_repository.batch_add(ports_inserts)
+        vessel_repository.batch_update(ports_updates)
         # At the end of processing:
         # dim_vessel records whose MMSI is absent from the CSV file are updated
         # with tracking_activated=FALSE
         csv_mmsi = list(df['mmsi'])
         deleted_ports = list(
-            filter(lambda v: v.mmsi not in csv_mmsi, vessel_repository.get_all_vessels_list(session)))
-        vessel_repository.set_tracking(session, [v.id for v in deleted_ports], False,
+            filter(lambda v: v.mmsi not in csv_mmsi, vessel_repository.list()))
+        vessel_repository.set_tracking([v.id for v in deleted_ports], False,
                                        "Suppression logique suite import nouveau fichier CSV")
         # the task then checks that only one record has tracking_activated==True
         # for each distinct MMSI value.
-        integrity_errors = vessel_repository.check_mmsi_integrity(session)
+        integrity_errors = vessel_repository.check_mmsi_integrity()
         if not integrity_errors:
             session.commit()
         else:
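One small performance note on the deactivation pass: `v.mmsi not in csv_mmsi` scans a Python list for every vessel, which is quadratic across a large fleet; a set gives constant-time membership:

```python
csv_mmsi = set(df['mmsi'])   # O(1) membership instead of a list scan
deleted_ports = [v for v in vessel_repository.list() if v.mmsi not in csv_mmsi]
```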
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 3c5c6ecc..e3a5d535 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -46,6 +46,8 @@ dependencies = [
     "fastapi[standard]>=0.115.0,<1.0.0",
     "uvicorn~=0.32",
     "redis~=5.0",
+    "pytest>=8.3.3",
+    "pytest-env>=1.1.5",
 ]
 name = "bloom"
 version = "0.1.0"
@@ -157,3 +159,14 @@ target-version = "py310"
 
 [tool.ruff.mccabe]
 max-complexity = 10
+
+[tool.pytest.ini_options]
+env = [
+    "POSTGRES_DRIVER=sqlite",
+    "POSTGRES_USER=",
+    "POSTGRES_PASSWORD=",
+    "POSTGRES_HOSTNAME=",
+    "POSTGRES_PORT=",
+    "POSTGRES_DB=:memory:",
+]
+
diff --git a/docker/Dockerfile b/docker/Dockerfile
index e0160a76..c718a293 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -23,9 +23,9 @@ COPY ./backend/ ${PROJECT_DIR}/backend
 COPY docker/rsyslog.conf /etc/rsyslog.conf
 
 # Install requirements package for python with poetry
-ARG POETRY_VERSION=1.8.2
-ENV POETRY_VERSION=${POETRY_VERSION}
-RUN pip install --upgrade pip && pip install --user "poetry==$POETRY_VERSION"
+#ARG POETRY_VERSION=1.8.2
+#ENV POETRY_VERSION=${POETRY_VERSION}
+#RUN pip install --upgrade pip && pip install --user "poetry==$POETRY_VERSION"
 ENV PATH="${PATH}:/root/.local/bin"
 
 COPY ./backend/pyproject.toml ./backend/alembic.ini ./backend/
@@ -37,7 +37,8 @@ ENV UV_PROJECT_ENVIRONMENT=${VIRTUAL_ENV}
 RUN \
     cd backend &&\
     uv venv ${VIRTUAL_ENV} &&\
-    echo ". ${VIRTUAL_ENV}/bin/activate" >> /root/.bashrc &&\
+    echo ". ${VIRTUAL_ENV}/bin/activate" >> ~/.bashrc &&\
+    . ${VIRTUAL_ENV}/bin/activate &&\
     uv sync
 
 # Launch cron services