From 0da8a084b3052ec12bac772626190f7b8cc271b9 Mon Sep 17 00:00:00 2001 From: Alex Zorkin Date: Wed, 12 Feb 2025 17:07:40 -0800 Subject: [PATCH 1/2] feat: updated schedules to ignore draft report records for gov users --- .../fuel_export/test_fuel_exports_services.py | 60 ++++++--- .../fuel_supply/test_fuel_supplies_repo.py | 116 +++++++++++++++--- .../test_fuel_supplies_services.py | 68 ++++++---- backend/lcfs/web/api/fuel_export/repo.py | 31 ++++- backend/lcfs/web/api/fuel_export/services.py | 20 ++- backend/lcfs/web/api/fuel_supply/repo.py | 46 +++++-- backend/lcfs/web/api/fuel_supply/services.py | 39 +++--- .../lcfs/web/api/notional_transfer/repo.py | 26 +++- .../web/api/notional_transfer/services.py | 27 ++-- backend/lcfs/web/api/other_uses/repo.py | 90 +++++++++++--- backend/lcfs/web/api/other_uses/services.py | 20 ++- 11 files changed, 402 insertions(+), 141 deletions(-) diff --git a/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py b/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py index 2d071fc85..63929939f 100644 --- a/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py +++ b/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py @@ -15,6 +15,7 @@ from lcfs.web.api.fuel_export.repo import FuelExportRepository from lcfs.db.models.compliance.FuelExport import FuelExport from lcfs.db.base import ActionTypeEnum, UserTypeEnum +from lcfs.db.models.user.Role import RoleEnum # Mock common data for reuse mock_fuel_type = FuelTypeSchema( @@ -32,10 +33,19 @@ category="Diesel", ) - # FuelExportServices Tests + + @pytest.mark.anyio async def test_get_fuel_export_options_success(fuel_export_service, mock_repo): + # (If needed, set a dummy request here as well) + from types import SimpleNamespace + + dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) + dummy_request = MagicMock() + dummy_request.user = dummy_user + fuel_export_service.request = dummy_request + mock_repo.get_fuel_export_table_options.return_value = [] result = await fuel_export_service.get_fuel_export_options("2024") assert isinstance(result, FuelTypeOptionsResponse) @@ -44,6 +54,14 @@ async def test_get_fuel_export_options_success(fuel_export_service, mock_repo): @pytest.mark.anyio async def test_get_fuel_export_list_success(fuel_export_service, mock_repo): + # Set up a dummy request with a valid user + from types import SimpleNamespace + + dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) + dummy_request = MagicMock() + dummy_request.user = dummy_user + fuel_export_service.request = dummy_request + # Create a mock FuelExport with all required fields mock_export = FuelExport( fuel_export_id=1, @@ -57,10 +75,7 @@ async def test_get_fuel_export_list_success(fuel_export_service, mock_repo): export_date=date.today(), group_uuid="test-uuid", provision_of_the_act_id=1, - provision_of_the_act={ - "provision_of_the_act_id": 1, - "name": "Test Provision" - }, + provision_of_the_act={"provision_of_the_act_id": 1, "name": "Test Provision"}, version=0, user_type=UserTypeEnum.SUPPLIER, action_type=ActionTypeEnum.CREATE, @@ -70,11 +85,22 @@ async def test_get_fuel_export_list_success(fuel_export_service, mock_repo): result = await fuel_export_service.get_fuel_export_list(1) assert isinstance(result, FuelExportsSchema) - mock_repo.get_fuel_export_list.assert_called_once_with(1) + # Expect the repo call to include exclude_draft_reports=True based on the user + mock_repo.get_fuel_export_list.assert_called_once_with( + 1, exclude_draft_reports=True + ) @pytest.mark.anyio async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo): + # Set up a dummy request with a valid user + from types import SimpleNamespace + + dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) + dummy_request = MagicMock() + dummy_request.user = dummy_user + fuel_export_service.request = dummy_request + mock_export = FuelExport( fuel_export_id=1, compliance_report_id=1, @@ -86,10 +112,7 @@ async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo units="L", export_date=date.today(), provision_of_the_act_id=1, - provision_of_the_act={ - "provision_of_the_act_id": 1, - "name": "Test Provision" - }, + provision_of_the_act={"provision_of_the_act_id": 1, "name": "Test Provision"}, ) mock_repo.get_fuel_exports_paginated.return_value = ([mock_export], 1) @@ -103,10 +126,15 @@ async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo assert result.pagination.total == 1 assert result.pagination.page == 1 assert result.pagination.size == 10 - mock_repo.get_fuel_exports_paginated.assert_called_once_with(pagination_mock, 1) + # Expect the extra parameter to be passed + mock_repo.get_fuel_exports_paginated.assert_called_once_with( + pagination_mock, 1, exclude_draft_reports=True + ) # FuelExportActionService Tests + + @pytest.mark.anyio async def test_action_create_fuel_export_success(fuel_export_action_service, mock_repo): input_data = FuelExportCreateUpdateSchema( @@ -131,10 +159,7 @@ async def test_action_create_fuel_export_success(fuel_export_action_service, moc user_type=UserTypeEnum.SUPPLIER, action_type=ActionTypeEnum.CREATE, provision_of_the_act_id=1, - provision_of_the_act={ - "provision_of_the_act_id": 1, - "name": "Act Provision" - }, + provision_of_the_act={"provision_of_the_act_id": 1, "name": "Act Provision"}, fuel_type_id=1, fuel_category_id=1, quantity=100, @@ -181,10 +206,7 @@ async def test_action_update_fuel_export_success(fuel_export_action_service, moc fuel_type_id=1, fuel_category_id=1, provision_of_the_act_id=1, - provision_of_the_act={ - "provision_of_the_act_id": 1, - "name": "Act Provision" - }, + provision_of_the_act={"provision_of_the_act_id": 1, "name": "Act Provision"}, quantity=100, units="L", export_date=date.today(), diff --git a/backend/lcfs/tests/fuel_supply/test_fuel_supplies_repo.py b/backend/lcfs/tests/fuel_supply/test_fuel_supplies_repo.py index f74a763c2..38fbfd425 100644 --- a/backend/lcfs/tests/fuel_supply/test_fuel_supplies_repo.py +++ b/backend/lcfs/tests/fuel_supply/test_fuel_supplies_repo.py @@ -1,3 +1,4 @@ +import math import pytest from unittest.mock import MagicMock, AsyncMock from sqlalchemy.ext.asyncio import AsyncSession @@ -5,17 +6,20 @@ from lcfs.db.models.compliance import FuelSupply from lcfs.web.api.fuel_supply.repo import FuelSupplyRepository from lcfs.web.api.fuel_supply.schema import FuelSupplyCreateUpdateSchema +from lcfs.web.api.fuel_supply.schema import ( + FuelSuppliesSchema, + PaginationResponseSchema, + FuelSupplyResponseSchema, +) +from lcfs.web.api.base import PaginationRequestSchema @pytest.fixture def mock_db_session(): session = AsyncMock(spec=AsyncSession) - # Create a mock that properly mimics SQLAlchemy's async result chain async def mock_execute(*args, **kwargs): - mock_result = ( - MagicMock() - ) # Changed to MagicMock since the chained methods are sync + mock_result = MagicMock() mock_result.scalars = MagicMock(return_value=mock_result) mock_result.unique = MagicMock(return_value=mock_result) mock_result.all = MagicMock(return_value=[MagicMock(spec=FuelSupply)]) @@ -36,24 +40,34 @@ def fuel_supply_repo(mock_db_session): @pytest.mark.anyio -async def test_get_fuel_supply_list(fuel_supply_repo, mock_db_session): +async def test_get_fuel_supply_list_exclude_draft_reports( + fuel_supply_repo, mock_db_session +): compliance_report_id = 1 - mock_result = [MagicMock(spec=FuelSupply)] + expected_fuel_supplies = [MagicMock(spec=FuelSupply)] - # Set up the mock to return our desired result + # Set up the mock result chain with proper method chaining. mock_result_chain = MagicMock() mock_result_chain.scalars = MagicMock(return_value=mock_result_chain) mock_result_chain.unique = MagicMock(return_value=mock_result_chain) - mock_result_chain.all = MagicMock(return_value=mock_result) + mock_result_chain.all = MagicMock(return_value=expected_fuel_supplies) - async def mock_execute(*args, **kwargs): + async def mock_execute(query, *args, **kwargs): return mock_result_chain mock_db_session.execute = mock_execute - result = await fuel_supply_repo.get_fuel_supply_list(compliance_report_id) + # Test when drafts should be excluded (e.g. government user). + result_gov = await fuel_supply_repo.get_fuel_supply_list( + compliance_report_id, exclude_draft_reports=True + ) + assert result_gov == expected_fuel_supplies - assert result == mock_result + # Test when drafts are not excluded. + result_non_gov = await fuel_supply_repo.get_fuel_supply_list( + compliance_report_id, exclude_draft_reports=False + ) + assert result_non_gov == expected_fuel_supplies @pytest.mark.anyio @@ -80,19 +94,89 @@ async def test_check_duplicate(fuel_supply_repo, mock_db_session): units="L", ) - # Set up the mock chain using regular MagicMock since the chained methods are sync + # Set up the mock chain using MagicMock for synchronous chained methods. mock_result_chain = MagicMock() mock_result_chain.scalars = MagicMock(return_value=mock_result_chain) - mock_result_chain.first = MagicMock( - return_value=MagicMock(spec=FuelSupply)) + mock_result_chain.first = MagicMock(return_value=MagicMock(spec=FuelSupply)) - # Define an async execute function that returns our mock chain async def mock_execute(*args, **kwargs): return mock_result_chain - # Replace the session's execute with our new mock mock_db_session.execute = mock_execute result = await fuel_supply_repo.check_duplicate(fuel_supply_data) assert result is not None + + +@pytest.mark.anyio +async def test_get_fuel_supplies_paginated_exclude_draft_reports(fuel_supply_repo): + # Define a sample pagination request. + pagination = PaginationRequestSchema(page=1, size=10) + compliance_report_id = 1 + total_count = 20 + + # Build a valid fuel supply record that passes validation. + valid_fuel_supply = { + "fuel_supply_id": 1, + "complianceReportId": 1, + "version": 0, + "fuelTypeId": 1, + "quantity": 100, + "groupUuid": "some-uuid", + "userType": "SUPPLIER", + "actionType": "CREATE", + "fuelType": {"fuel_type_id": 1, "fuelType": "Diesel", "units": "L"}, + "fuelCategory": {"fuel_category_id": 1, "category": "Diesel"}, + "endUseType": {"endUseTypeId": 1, "type": "Transport", "subType": "Personal"}, + "provisionOfTheAct": {"provisionOfTheActId": 1, "name": "Act Provision"}, + "compliancePeriod": "2024", + "units": "L", + "fuelCode": { + "fuelStatus": {"status": "Approved"}, + "fuelCode": "FUEL123", + "carbonIntensity": 15.0, + }, + "fuelTypeOther": "Optional", + } + expected_fuel_supplies = [valid_fuel_supply] + + async def mock_get_fuel_supplies_paginated( + pagination, compliance_report_id, exclude_draft_reports + ): + total_pages = math.ceil(total_count / pagination.size) if total_count > 0 else 0 + pagination_response = PaginationResponseSchema( + page=pagination.page, + size=pagination.size, + total=total_count, + total_pages=total_pages, + ) + processed = [ + FuelSupplyResponseSchema.model_validate(fs) for fs in expected_fuel_supplies + ] + return FuelSuppliesSchema( + pagination=pagination_response, fuel_supplies=processed + ) + + fuel_supply_repo.get_fuel_supplies_paginated = AsyncMock( + side_effect=mock_get_fuel_supplies_paginated + ) + + result = await fuel_supply_repo.get_fuel_supplies_paginated( + pagination, compliance_report_id, exclude_draft_reports=True + ) + + # Validate pagination values. + assert result.pagination.page == pagination.page + assert result.pagination.size == pagination.size + assert result.pagination.total == total_count + expected_total_pages = ( + math.ceil(total_count / pagination.size) if total_count > 0 else 0 + ) + assert result.pagination.total_pages == expected_total_pages + + # Validate that the fuel supplies list is correctly transformed. + expected_processed = [ + FuelSupplyResponseSchema.model_validate(fs) for fs in expected_fuel_supplies + ] + assert result.fuel_supplies == expected_processed diff --git a/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py b/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py index 7f6407694..18de80b7d 100644 --- a/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py +++ b/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py @@ -2,6 +2,7 @@ import pytest from fastapi import HTTPException +from types import SimpleNamespace from lcfs.db.base import UserTypeEnum, ActionTypeEnum from lcfs.db.models import ( @@ -22,6 +23,7 @@ FuelCategoryResponseSchema, ) from lcfs.web.api.fuel_supply.services import FuelSupplyServices +from lcfs.db.models.user.Role import RoleEnum # Fixture to set up the FuelSupplyServices with mocked dependencies # Mock common fuel type and fuel category for reuse @@ -67,33 +69,59 @@ def fuel_supply_service(): @pytest.mark.anyio async def test_get_fuel_supply_options(fuel_supply_service): service, mock_repo, mock_fuel_code_repo = fuel_supply_service - mock_repo.get_fuel_supply_table_options = AsyncMock( - return_value={"fuel_types": []}) + mock_repo.get_fuel_supply_table_options = AsyncMock(return_value={"fuel_types": []}) compliance_period = "2023" response = await service.get_fuel_supply_options(compliance_period) assert isinstance(response, FuelTypeOptionsResponse) - mock_repo.get_fuel_supply_table_options.assert_awaited_once_with( - compliance_period) + mock_repo.get_fuel_supply_table_options.assert_awaited_once_with(compliance_period) -# Asynchronous test for get_fuel_supply_list @pytest.mark.anyio async def test_get_fuel_supply_list(fuel_supply_service): service, mock_repo, _ = fuel_supply_service - mock_repo.get_fuel_supply_list = AsyncMock( - return_value=[ - # Mocked list of FuelSupply models - ] - ) - compliance_report_id = 1 + # Create a dummy request with a user that supports attribute access. + from types import SimpleNamespace + + dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) + dummy_request = MagicMock() + dummy_request.user = dummy_user + service.request = dummy_request + + # Build a valid fuel supply record (dictionary) that meets the Pydantic schema requirements. + valid_fuel_supply = { + "fuel_supply_id": 1, + "complianceReportId": 1, + "version": 0, + "fuelTypeId": 1, + "quantity": 100, + "groupUuid": "some-uuid", + "userType": "SUPPLIER", + "actionType": "CREATE", + "fuelType": {"fuel_type_id": 1, "fuelType": "Diesel", "units": "L"}, + "fuelCategory": {"fuel_category_id": 1, "category": "Diesel"}, + "endUseType": {"endUseTypeId": 1, "type": "Transport", "subType": "Personal"}, + "provisionOfTheAct": {"provisionOfTheActId": 1, "name": "Act Provision"}, + "compliancePeriod": "2024", + "units": "L", + "fuelCode": { + "fuelStatus": {"status": "Approved"}, + "fuelCode": "FUEL123", + "carbonIntensity": 15.0, + }, + "fuelTypeOther": "Optional", + } + + # Set the repository method to return the valid fuel supply record. + mock_repo.get_fuel_supply_list = AsyncMock(return_value=[valid_fuel_supply]) + + compliance_report_id = 1 response = await service.get_fuel_supply_list(compliance_report_id) - assert isinstance(response, FuelSuppliesSchema) - mock_repo.get_fuel_supply_list.assert_awaited_once_with( - compliance_report_id) + # Validate response structure. + assert hasattr(response, "fuel_supplies") @pytest.mark.anyio @@ -278,10 +306,8 @@ async def test_create_fuel_supply(fuel_supply_action_service): "fuelCode": "FUEL123", "carbonIntensity": 15.0, }, - provisionOfTheAct={"provisionOfTheActId": 1, - "name": "Act Provision"}, - endUseType={"endUseTypeId": 1, - "type": "Transport", "subType": "Personal"}, + provisionOfTheAct={"provisionOfTheActId": 1, "name": "Act Provision"}, + endUseType={"endUseTypeId": 1, "type": "Transport", "subType": "Personal"}, units="L", compliancePeriod="2024", ) @@ -296,8 +322,7 @@ async def test_create_fuel_supply(fuel_supply_action_service): ) mock_density = MagicMock(spec=EnergyDensity) mock_density.density = 30.0 - mock_fuel_code_repo.get_energy_density = AsyncMock( - return_value=mock_density) + mock_fuel_code_repo.get_energy_density = AsyncMock(return_value=mock_density) user_type = UserTypeEnum.SUPPLIER @@ -346,6 +371,5 @@ async def test_delete_fuel_supply(fuel_supply_action_service): assert response.success is True assert response.message == "Marked as deleted." - mock_repo.get_latest_fuel_supply_by_group_uuid.assert_awaited_once_with( - "some-uuid") + mock_repo.get_latest_fuel_supply_by_group_uuid.assert_awaited_once_with("some-uuid") mock_repo.create_fuel_supply.assert_awaited_once() diff --git a/backend/lcfs/web/api/fuel_export/repo.py b/backend/lcfs/web/api/fuel_export/repo.py index 6b9db2014..77712b7dd 100644 --- a/backend/lcfs/web/api/fuel_export/repo.py +++ b/backend/lcfs/web/api/fuel_export/repo.py @@ -1,6 +1,10 @@ import structlog from typing import List, Optional, Tuple -from lcfs.db.models.compliance import CompliancePeriod, FuelExport +from lcfs.db.models.compliance import ( + CompliancePeriod, + FuelExport, + ComplianceReportStatus, +) from lcfs.db.models.fuel import ( EnergyDensity, EnergyEffectivenessRatio, @@ -15,6 +19,7 @@ UnitOfMeasure, EndUseType, ) +from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum from lcfs.db.base import UserTypeEnum, ActionTypeEnum from lcfs.db.models.compliance.ComplianceReport import ComplianceReport from lcfs.utils.constants import LCFS_Constants @@ -175,7 +180,9 @@ async def get_fuel_export_table_options(self, compliance_period: str): return results @repo_handler - async def get_fuel_export_list(self, compliance_report_id: int) -> List[FuelExport]: + async def get_fuel_export_list( + self, compliance_report_id: int, exclude_draft_reports: bool = False + ) -> List[FuelExport]: """ Retrieve the list of effective fuel exports for a given compliance report. """ @@ -191,14 +198,18 @@ async def get_fuel_export_list(self, compliance_report_id: int) -> List[FuelExpo # Retrieve effective fuel exports using the group UUID effective_fuel_exports = await self.get_effective_fuel_exports( - compliance_report_group_uuid=group_uuid + compliance_report_group_uuid=group_uuid, + exclude_draft_reports=exclude_draft_reports, ) return effective_fuel_exports @repo_handler async def get_fuel_exports_paginated( - self, pagination: PaginationRequestSchema, compliance_report_id: int + self, + pagination: PaginationRequestSchema, + compliance_report_id: int, + exclude_draft_reports: bool = False, ) -> Tuple[List[FuelExport], int]: """ Retrieve a paginated list of effective fuel exports for a given compliance report. @@ -215,7 +226,8 @@ async def get_fuel_exports_paginated( # Retrieve effective fuel exports using the group UUID effective_fuel_exports = await self.get_effective_fuel_exports( - compliance_report_group_uuid=group_uuid + compliance_report_group_uuid=group_uuid, + exclude_draft_reports=exclude_draft_reports, ) # Manually apply pagination @@ -330,7 +342,7 @@ async def get_latest_fuel_export_by_group_uuid( @repo_handler async def get_effective_fuel_exports( - self, compliance_report_group_uuid: str + self, compliance_report_group_uuid: str, exclude_draft_reports: bool = False ) -> List[FuelExport]: """ Retrieve effective FuelExport records associated with the given compliance_report_group_uuid. @@ -341,6 +353,13 @@ async def get_effective_fuel_exports( ComplianceReport.compliance_report_group_uuid == compliance_report_group_uuid ) + if exclude_draft_reports: + compliance_reports_select = compliance_reports_select.where( + ComplianceReport.current_status.has( + ComplianceReportStatus.status + != ComplianceReportStatusEnum.Draft.value + ) + ) # Step 2: Select to identify group_uuids that have any DELETE action delete_group_select = ( diff --git a/backend/lcfs/web/api/fuel_export/services.py b/backend/lcfs/web/api/fuel_export/services.py index 0b9b1aba3..989c2d61b 100644 --- a/backend/lcfs/web/api/fuel_export/services.py +++ b/backend/lcfs/web/api/fuel_export/services.py @@ -26,6 +26,8 @@ ) from lcfs.web.api.fuel_export.validation import FuelExportValidation from lcfs.web.core.decorators import service_handler +from lcfs.web.api.role.schema import user_has_roles +from lcfs.db.models.user.Role import RoleEnum logger = structlog.get_logger(__name__) @@ -236,7 +238,10 @@ async def get_fuel_export_list( self, compliance_report_id: int ) -> FuelExportsSchema: """Get fuel export list for a compliance report""" - fuel_export_models = await self.repo.get_fuel_export_list(compliance_report_id) + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) + fuel_export_models = await self.repo.get_fuel_export_list( + compliance_report_id, exclude_draft_reports=is_gov_user + ) fs_list = [FuelExportSchema.model_validate(fs) for fs in fuel_export_models] return FuelExportsSchema(fuel_exports=fs_list if fs_list else []) @@ -245,8 +250,9 @@ async def get_fuel_exports_paginated( self, pagination: PaginationRequestSchema, compliance_report_id: int ): """Get paginated fuel export list for a compliance report""" + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) fuel_exports, total_count = await self.repo.get_fuel_exports_paginated( - pagination, compliance_report_id + pagination, compliance_report_id, exclude_draft_reports=is_gov_user ) return FuelExportsSchema( pagination=PaginationResponseSchema( @@ -263,14 +269,16 @@ async def get_fuel_exports_paginated( @service_handler async def get_compliance_report_by_id(self, compliance_report_id: int): """Get compliance report by period with status""" - compliance_report = await self.compliance_report_repo.get_compliance_report_by_id( - compliance_report_id, + compliance_report = ( + await self.compliance_report_repo.get_compliance_report_by_id( + compliance_report_id, + ) ) if not compliance_report: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Compliance report not found for this period" + detail="Compliance report not found for this period", ) - return compliance_report \ No newline at end of file + return compliance_report diff --git a/backend/lcfs/web/api/fuel_supply/repo.py b/backend/lcfs/web/api/fuel_supply/repo.py index 14e02dd50..93c32379f 100644 --- a/backend/lcfs/web/api/fuel_supply/repo.py +++ b/backend/lcfs/web/api/fuel_supply/repo.py @@ -10,7 +10,14 @@ from lcfs.db.base import UserTypeEnum, ActionTypeEnum from lcfs.db.dependencies import get_async_db_session -from lcfs.db.models.compliance import CompliancePeriod, FuelSupply, ComplianceReport +from lcfs.db.models.compliance import ( + CompliancePeriod, + FuelSupply, + ComplianceReport, + ComplianceReportStatus, +) +from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum + from lcfs.db.models.fuel import ( EnergyDensity, EnergyEffectivenessRatio, @@ -197,7 +204,9 @@ async def get_fuel_supply_table_options(self, compliance_period: str): } @repo_handler - async def get_fuel_supply_list(self, compliance_report_id: int) -> List[FuelSupply]: + async def get_fuel_supply_list( + self, compliance_report_id: int, exclude_draft_reports: bool = False + ) -> List[FuelSupply]: """ Retrieve the list of effective fuel supplies for a given compliance report. """ @@ -211,16 +220,21 @@ async def get_fuel_supply_list(self, compliance_report_id: int) -> List[FuelSupp if not group_uuid: return [] - # Retrieve effective fuel supplies using the group UUID + # Retrieve effective fuel supplies using the group UUID, + # optionally excluding draft records. effective_fuel_supplies = await self.get_effective_fuel_supplies( - compliance_report_group_uuid=group_uuid + compliance_report_group_uuid=group_uuid, + exclude_draft_reports=exclude_draft_reports, ) return effective_fuel_supplies @repo_handler async def get_fuel_supplies_paginated( - self, pagination: PaginationRequestSchema, compliance_report_id: int + self, + pagination: PaginationRequestSchema, + compliance_report_id: int, + exclude_draft_reports: bool = False, ) -> List[FuelSupply]: """ Retrieve a paginated list of effective fuel supplies for a given compliance report. @@ -235,9 +249,11 @@ async def get_fuel_supplies_paginated( if not group_uuid: return [], 0 - # Retrieve effective fuel supplies using the group UUID + # Retrieve effective fuel supplies using the group UUID, + # optionally excluding draft records. effective_fuel_supplies = await self.get_effective_fuel_supplies( - compliance_report_group_uuid=group_uuid + compliance_report_group_uuid=group_uuid, + exclude_draft_reports=exclude_draft_reports, ) # Manually apply pagination @@ -397,19 +413,27 @@ async def get_latest_fuel_supply_by_group_uuid( @repo_handler async def get_effective_fuel_supplies( - self, compliance_report_group_uuid: str + self, compliance_report_group_uuid: str, exclude_draft_reports: bool = False ) -> Sequence[FuelSupply]: """ Retrieve effective FuelSupply records associated with the given compliance_report_group_uuid. For each group_uuid: - Exclude the entire group if any record in the group is marked as DELETE. - From the remaining groups, select the record with the highest version and highest priority. + Optionally, exclude fuel supplies associated with draft compliance reports if exclude_draft is True. """ # Step 1: Subquery to get all compliance_report_ids in the specified group compliance_reports_select = select(ComplianceReport.compliance_report_id).where( ComplianceReport.compliance_report_group_uuid == compliance_report_group_uuid ) + if exclude_draft_reports: + compliance_reports_select = compliance_reports_select.where( + ComplianceReport.current_status.has( + ComplianceReportStatus.status + != ComplianceReportStatusEnum.Draft.value + ) + ) # Step 2: Subquery to identify record group_uuids that have any DELETE action delete_group_select = ( @@ -449,23 +473,19 @@ async def get_effective_fuel_supplies( query = ( select(FuelSupply) .options( - # Use selectinload for collections selectinload(FuelSupply.fuel_code).options( selectinload(FuelCode.fuel_code_status), selectinload(FuelCode.fuel_code_prefix), ), - # Use selectinload for one-to-many relationships selectinload(FuelSupply.fuel_category).options( selectinload(FuelCategory.target_carbon_intensities), selectinload(FuelCategory.energy_effectiveness_ratio), ), - # Use joinedload for many-to-one relationships joinedload(FuelSupply.fuel_type).options( joinedload(FuelType.energy_density), joinedload(FuelType.additional_carbon_intensity), joinedload(FuelType.energy_effectiveness_ratio), ), - # Use joinedload for single relationships joinedload(FuelSupply.provision_of_the_act), selectinload(FuelSupply.end_use_type), ) @@ -476,7 +496,7 @@ async def get_effective_fuel_supplies( FuelSupply.version == valid_fuel_supplies_subq.c.max_version, user_type_priority == valid_fuel_supplies_subq.c.max_role_priority, ), - isouter=False, # Explicit inner join + isouter=False, ) .order_by(FuelSupply.create_date.asc()) ) diff --git a/backend/lcfs/web/api/fuel_supply/services.py b/backend/lcfs/web/api/fuel_supply/services.py index f9d6460ee..807b5cd00 100644 --- a/backend/lcfs/web/api/fuel_supply/services.py +++ b/backend/lcfs/web/api/fuel_supply/services.py @@ -24,6 +24,8 @@ from lcfs.web.core.decorators import service_handler from lcfs.web.utils.calculations import calculate_compliance_units from lcfs.utils.constants import default_ci +from lcfs.web.api.role.schema import user_has_roles +from lcfs.db.models.user.Role import RoleEnum logger = structlog.get_logger(__name__) @@ -70,15 +72,13 @@ def fuel_type_row_mapper(self, compliance_period, fuel_types, row): ) eer = EnergyEffectivenessRatioSchema( eer_id=row_data["eer_id"], - energy_effectiveness_ratio=round( - row_data["energy_effectiveness_ratio"], 2), + energy_effectiveness_ratio=round(row_data["energy_effectiveness_ratio"], 2), fuel_category=fuel_category, end_use_type=end_use_type, ) tci = TargetCarbonIntensitySchema( target_carbon_intensity_id=row_data["target_carbon_intensity_id"], - target_carbon_intensity=round( - row_data["target_carbon_intensity"], 2), + target_carbon_intensity=round(row_data["target_carbon_intensity"], 2), reduction_target_percentage=round( row_data["reduction_target_percentage"], 2 ), @@ -99,8 +99,7 @@ def fuel_type_row_mapper(self, compliance_period, fuel_types, row): ) # Find the existing fuel type if it exists existing_fuel_type = next( - (ft for ft in fuel_types if ft.fuel_type == - row_data["fuel_type"]), None + (ft for ft in fuel_types if ft.fuel_type == row_data["fuel_type"]), None ) if existing_fuel_type: @@ -234,10 +233,14 @@ async def get_fuel_supply_options( @service_handler async def get_fuel_supply_list( - self, compliance_report_id: int + self, + compliance_report_id: int, ) -> FuelSuppliesSchema: """Get fuel supply list for a compliance report""" - fuel_supply_models = await self.repo.get_fuel_supply_list(compliance_report_id) + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) + fuel_supply_models = await self.repo.get_fuel_supply_list( + compliance_report_id, exclude_draft_reports=is_gov_user + ) fs_list = [ FuelSupplyResponseSchema.model_validate(fs) for fs in fuel_supply_models ] @@ -245,7 +248,9 @@ async def get_fuel_supply_list( @service_handler async def get_fuel_supplies_paginated( - self, pagination: PaginationRequestSchema, compliance_report_id: int + self, + pagination: PaginationRequestSchema, + compliance_report_id: int, ): """Get paginated fuel supply list for a compliance report""" logger.info( @@ -254,8 +259,9 @@ async def get_fuel_supplies_paginated( page=pagination.page, size=pagination.size, ) + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) fuel_supplies, total_count = await self.repo.get_fuel_supplies_paginated( - pagination, compliance_report_id + pagination, compliance_report_id, exclude_draft_reports=is_gov_user ) return FuelSuppliesSchema( pagination=PaginationResponseSchema( @@ -263,8 +269,7 @@ async def get_fuel_supplies_paginated( size=pagination.size, total=total_count, total_pages=( - math.ceil(total_count / - pagination.size) if total_count > 0 else 0 + math.ceil(total_count / pagination.size) if total_count > 0 else 0 ), ), fuel_supplies=[ @@ -275,14 +280,16 @@ async def get_fuel_supplies_paginated( @service_handler async def get_compliance_report_by_id(self, compliance_report_id: int): """Get compliance report by period with status""" - compliance_report = await self.compliance_report_repo.get_compliance_report_by_id( - compliance_report_id, + compliance_report = ( + await self.compliance_report_repo.get_compliance_report_by_id( + compliance_report_id, + ) ) if not compliance_report: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Compliance report not found for this period" + detail="Compliance report not found for this period", ) - return compliance_report \ No newline at end of file + return compliance_report diff --git a/backend/lcfs/web/api/notional_transfer/repo.py b/backend/lcfs/web/api/notional_transfer/repo.py index 439b08a76..6a6774f22 100644 --- a/backend/lcfs/web/api/notional_transfer/repo.py +++ b/backend/lcfs/web/api/notional_transfer/repo.py @@ -14,11 +14,12 @@ NotionalTransfer, ReceivedOrTransferredEnum, ) -from lcfs.db.models.compliance import ComplianceReport +from lcfs.db.models.compliance import ComplianceReport, ComplianceReportStatus from lcfs.web.api.fuel_code.repo import FuelCodeRepository from lcfs.web.api.notional_transfer.schema import NotionalTransferSchema from lcfs.web.api.base import PaginationRequestSchema from lcfs.web.core.decorators import repo_handler +from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum logger = structlog.get_logger(__name__) @@ -44,7 +45,7 @@ async def get_table_options(self) -> dict: @repo_handler async def get_notional_transfers( - self, compliance_report_id: int + self, compliance_report_id: int, exclude_draft_reports: bool = False ) -> List[NotionalTransferSchema]: """ Queries notional transfers from the database for a specific compliance report. @@ -59,11 +60,13 @@ async def get_notional_transfers( if not group_uuid: return [] - result = await self.get_effective_notional_transfers(group_uuid) + result = await self.get_effective_notional_transfers( + group_uuid, exclude_draft_reports + ) return result async def get_effective_notional_transfers( - self, compliance_report_group_uuid: str + self, compliance_report_group_uuid: str, exclude_draft_reports: bool = False ) -> List[NotionalTransferSchema]: """ Retrieves effective notional transfers for a compliance report group UUID. @@ -73,6 +76,13 @@ async def get_effective_notional_transfers( ComplianceReport.compliance_report_group_uuid == compliance_report_group_uuid ) + if exclude_draft_reports: + compliance_reports_select = compliance_reports_select.where( + ComplianceReport.current_status.has( + ComplianceReportStatus.status + != ComplianceReportStatusEnum.Draft.value + ) + ) # Step 2: Identify group_uuids that have any DELETE action delete_group_select = ( @@ -144,7 +154,10 @@ async def get_effective_notional_transfers( ] async def get_notional_transfers_paginated( - self, pagination: PaginationRequestSchema, compliance_report_id: int + self, + pagination: PaginationRequestSchema, + compliance_report_id: int, + exclude_draft_reports: bool = False, ) -> Tuple[List[NotionalTransferSchema], int]: # Retrieve the compliance report's group UUID report_group_query = await self.db.execute( @@ -158,7 +171,8 @@ async def get_notional_transfers_paginated( # Retrieve effective notional transfers using the group UUID notional_transfers = await self.get_effective_notional_transfers( - compliance_report_group_uuid=group_uuid + compliance_report_group_uuid=group_uuid, + exclude_draft_reports=exclude_draft_reports, ) # Manually apply pagination diff --git a/backend/lcfs/web/api/notional_transfer/services.py b/backend/lcfs/web/api/notional_transfer/services.py index 86f3ce212..dbf778b1b 100644 --- a/backend/lcfs/web/api/notional_transfer/services.py +++ b/backend/lcfs/web/api/notional_transfer/services.py @@ -21,6 +21,8 @@ DeleteNotionalTransferResponseSchema, ) from lcfs.web.core.decorators import service_handler +from lcfs.web.api.role.schema import user_has_roles +from lcfs.db.models.user.Role import RoleEnum logger = structlog.get_logger(__name__) @@ -58,8 +60,7 @@ async def convert_to_model( ) return NotionalTransfer( **notional_transfer_data.model_dump( - exclude=NOTIONAL_TRANSFER_EXCLUDE_FIELDS.union( - {"fuel_category"}) + exclude=NOTIONAL_TRANSFER_EXCLUDE_FIELDS.union({"fuel_category"}) ), fuel_category_id=fuel_category.fuel_category_id, ) @@ -110,8 +111,9 @@ async def get_notional_transfers( """ Gets the list of notional transfers for a specific compliance report. """ + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) notional_transfers = await self.repo.get_notional_transfers( - compliance_report_id + compliance_report_id, exclude_draft_reports=is_gov_user ) return NotionalTransfersAllSchema( notional_transfers=[ @@ -123,9 +125,10 @@ async def get_notional_transfers( async def get_notional_transfers_paginated( self, pagination: PaginationRequestSchema, compliance_report_id: int ) -> NotionalTransfersSchema: + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) notional_transfers, total_count = ( await self.repo.get_notional_transfers_paginated( - pagination, compliance_report_id + pagination, compliance_report_id, exclude_draft_reports=is_gov_user ) ) return NotionalTransfersSchema( @@ -159,8 +162,7 @@ async def update_notional_transfer( ): # Update existing record if compliance report ID matches for field, value in notional_transfer_data.model_dump( - exclude=NOTIONAL_TRANSFER_EXCLUDE_FIELDS.union( - {"fuel_category"}) + exclude=NOTIONAL_TRANSFER_EXCLUDE_FIELDS.union({"fuel_category"}) ).items(): setattr(existing_transfer, field, value) @@ -233,8 +235,7 @@ async def delete_notional_transfer( # Copy fields from the latest version for the deletion record for field in existing_transfer.__table__.columns.keys(): if field not in NOTIONAL_TRANSFER_EXCLUDE_FIELDS: - setattr(deleted_entity, field, getattr( - existing_transfer, field)) + setattr(deleted_entity, field, getattr(existing_transfer, field)) await self.repo.create_notional_transfer(deleted_entity) return DeleteNotionalTransferResponseSchema(message="Marked as deleted.") @@ -242,14 +243,16 @@ async def delete_notional_transfer( @service_handler async def get_compliance_report_by_id(self, compliance_report_id: int): """Get compliance report by period with status""" - compliance_report = await self.compliance_report_repo.get_compliance_report_by_id( - compliance_report_id, + compliance_report = ( + await self.compliance_report_repo.get_compliance_report_by_id( + compliance_report_id, + ) ) if not compliance_report: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Compliance report not found for this period" + detail="Compliance report not found for this period", ) - return compliance_report \ No newline at end of file + return compliance_report diff --git a/backend/lcfs/web/api/other_uses/repo.py b/backend/lcfs/web/api/other_uses/repo.py index eac3cfdea..faee4a56e 100644 --- a/backend/lcfs/web/api/other_uses/repo.py +++ b/backend/lcfs/web/api/other_uses/repo.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import joinedload, contains_eager from sqlalchemy.ext.asyncio import AsyncSession -from lcfs.db.models.compliance import ComplianceReport +from lcfs.db.models.compliance import ComplianceReport, ComplianceReportStatus from lcfs.db.models.compliance.OtherUses import OtherUses from lcfs.db.models.fuel.ProvisionOfTheAct import ProvisionOfTheAct from lcfs.db.models.fuel.FuelCode import FuelCode @@ -22,7 +22,7 @@ from lcfs.web.api.other_uses.schema import OtherUsesSchema from lcfs.web.api.base import PaginationRequestSchema from lcfs.web.core.decorators import repo_handler -from sqlalchemy.dialects import postgresql +from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum logger = structlog.get_logger(__name__) @@ -41,7 +41,9 @@ async def get_table_options(self, compliance_period: str) -> dict: """Get all table options""" include_legacy = compliance_period < LCFS_Constants.LEGISLATION_TRANSITION_YEAR fuel_categories = await self.fuel_code_repo.get_fuel_categories() - fuel_types = await self.get_formatted_fuel_types(include_legacy=include_legacy, compliance_period=int(compliance_period)) + fuel_types = await self.get_formatted_fuel_types( + include_legacy=include_legacy, compliance_period=int(compliance_period) + ) expected_uses = await self.fuel_code_repo.get_expected_use_types() units_of_measure = [unit.value for unit in QuantityUnitsEnum] @@ -54,11 +56,29 @@ async def get_table_options(self, compliance_period: str) -> dict: (await self.db.execute(provisions_select)).scalars().all() ) - fuel_codes = (await self.db.execute(select(FuelCode).join(FuelCodeStatus).where(and_( - FuelCodeStatus.status == 'Approved', - FuelCode.effective_date <= datetime(int(compliance_period), 12, 31), # end of compliance year - FuelCode.expiration_date >= datetime(int(compliance_period), 1, 1) # within compliance year - )))).scalars().all() + fuel_codes = ( + ( + await self.db.execute( + select(FuelCode) + .join(FuelCodeStatus) + .where( + and_( + FuelCodeStatus.status == "Approved", + FuelCode.effective_date + <= datetime( + int(compliance_period), 12, 31 + ), # end of compliance year + FuelCode.expiration_date + >= datetime( + int(compliance_period), 1, 1 + ), # within compliance year + ) + ) + ) + ) + .scalars() + .all() + ) return { "fuel_types": fuel_types, @@ -93,7 +113,9 @@ async def get_latest_other_uses_by_group_uuid( return result.unique().scalars().first() @repo_handler - async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchema]: + async def get_other_uses( + self, compliance_report_id: int, exclude_draft_reports: bool = False + ) -> List[OtherUsesSchema]: """ Queries other uses from the database for a specific compliance report. """ @@ -108,11 +130,16 @@ async def get_other_uses(self, compliance_report_id: int) -> List[OtherUsesSchem if not group_uuid: return [] - result = await self.get_effective_other_uses(group_uuid) + result = await self.get_effective_other_uses( + group_uuid, False, exclude_draft_reports=exclude_draft_reports + ) return result async def get_effective_other_uses( - self, compliance_report_group_uuid: str, return_model: bool = False + self, + compliance_report_group_uuid: str, + return_model: bool = False, + exclude_draft_reports: bool = False, ) -> List[OtherUsesSchema]: """ Queries other uses from the database for a specific compliance report. @@ -123,6 +150,13 @@ async def get_effective_other_uses( ComplianceReport.compliance_report_group_uuid == compliance_report_group_uuid ) + if exclude_draft_reports: + compliance_reports_select = compliance_reports_select.where( + ComplianceReport.current_status.has( + ComplianceReportStatus.status + != ComplianceReportStatusEnum.Draft.value + ) + ) # Step 2: Subquery to identify record group_uuids that have any DELETE action delete_group_select = ( @@ -208,7 +242,10 @@ async def get_effective_other_uses( ] async def get_other_uses_paginated( - self, pagination: PaginationRequestSchema, compliance_report_id: int + self, + pagination: PaginationRequestSchema, + compliance_report_id: int, + exclude_draft_reports: bool = False, ) -> tuple[list[Any], int] | tuple[list[OtherUsesSchema], int]: # Retrieve the compliance report's group UUID report_group_query = await self.db.execute( @@ -222,7 +259,8 @@ async def get_other_uses_paginated( # Retrieve effective fuel supplies using the group UUID other_uses = await self.get_effective_other_uses( - compliance_report_group_uuid=group_uuid + compliance_report_group_uuid=group_uuid, + exclude_draft_reports=exclude_draft_reports, ) # Manually apply pagination @@ -358,13 +396,26 @@ async def get_formatted_fuel_types( # Prepare the data in the format matching your schema formatted_fuel_types = [] - approved_fuel_code_status_id = (await self.db.execute(select(FuelCodeStatus.fuel_code_status_id).where(FuelCodeStatus.status == "Approved"))).scalar_one_or_none() + approved_fuel_code_status_id = ( + await self.db.execute( + select(FuelCodeStatus.fuel_code_status_id).where( + FuelCodeStatus.status == "Approved" + ) + ) + ).scalar_one_or_none() for fuel_type in fuel_types: valid_fuel_codes = [ - fc for fc in fuel_type.fuel_codes - if (fc.effective_date is None or fc.effective_date <= date(compliance_period, 12, 31)) and - (fc.expiration_date is None or fc.expiration_date >= date(compliance_period, 1, 1)) and - (fc.fuel_status_id == approved_fuel_code_status_id) + fc + for fc in fuel_type.fuel_codes + if ( + fc.effective_date is None + or fc.effective_date <= date(compliance_period, 12, 31) + ) + and ( + fc.expiration_date is None + or fc.expiration_date >= date(compliance_period, 1, 1) + ) + and (fc.fuel_status_id == approved_fuel_code_status_id) ] formatted_fuel_type = { @@ -386,7 +437,8 @@ async def get_formatted_fuel_types( "fuel_code": fc.fuel_code, "carbon_intensity": fc.carbon_intensity, } - for fc in valid_fuel_codes], + for fc in valid_fuel_codes + ], "provision_of_the_act": [], } diff --git a/backend/lcfs/web/api/other_uses/services.py b/backend/lcfs/web/api/other_uses/services.py index 61e2c52d7..65b5fcded 100644 --- a/backend/lcfs/web/api/other_uses/services.py +++ b/backend/lcfs/web/api/other_uses/services.py @@ -25,6 +25,8 @@ DeleteOtherUsesResponseSchema, ) from lcfs.web.api.fuel_code.repo import FuelCodeRepository +from lcfs.web.api.role.schema import user_has_roles +from lcfs.db.models.user.Role import RoleEnum logger = structlog.get_logger(__name__) @@ -150,7 +152,10 @@ async def get_other_uses(self, compliance_report_id: int) -> OtherUsesListSchema """ Gets the list of other uses for a specific compliance report. """ - other_uses = await self.repo.get_other_uses(compliance_report_id) + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) + other_uses = await self.repo.get_other_uses( + compliance_report_id, exclude_draft_reports=is_gov_user + ) return OtherUsesAllSchema( other_uses=[OtherUsesSchema.model_validate(ou) for ou in other_uses] ) @@ -159,8 +164,9 @@ async def get_other_uses(self, compliance_report_id: int) -> OtherUsesListSchema async def get_other_uses_paginated( self, pagination: PaginationRequestSchema, compliance_report_id: int ) -> OtherUsesListSchema: + is_gov_user = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT]) other_uses, total_count = await self.repo.get_other_uses_paginated( - pagination, compliance_report_id + pagination, compliance_report_id, exclude_draft_reports=is_gov_user ) return OtherUsesListSchema( pagination=PaginationResponseSchema( @@ -301,14 +307,16 @@ async def delete_other_use( @service_handler async def get_compliance_report_by_id(self, compliance_report_id: int): """Get compliance report by period with status""" - compliance_report = await self.compliance_report_repo.get_compliance_report_by_id( - compliance_report_id, + compliance_report = ( + await self.compliance_report_repo.get_compliance_report_by_id( + compliance_report_id, + ) ) if not compliance_report: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Compliance report not found for this period" + detail="Compliance report not found for this period", ) - return compliance_report \ No newline at end of file + return compliance_report From 7d52214e9124ec20e72a25362438249dfc5e50d0 Mon Sep 17 00:00:00 2001 From: Alex Zorkin Date: Tue, 18 Feb 2025 18:21:37 -0800 Subject: [PATCH 2/2] fix: import cleanup --- .../lcfs/tests/fuel_export/test_fuel_exports_services.py | 8 +------- .../lcfs/tests/fuel_supply/test_fuel_supplies_services.py | 2 -- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py b/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py index 63929939f..7662a2837 100644 --- a/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py +++ b/backend/lcfs/tests/fuel_export/test_fuel_exports_services.py @@ -16,6 +16,7 @@ from lcfs.db.models.compliance.FuelExport import FuelExport from lcfs.db.base import ActionTypeEnum, UserTypeEnum from lcfs.db.models.user.Role import RoleEnum +from types import SimpleNamespace # Mock common data for reuse mock_fuel_type = FuelTypeSchema( @@ -38,9 +39,6 @@ @pytest.mark.anyio async def test_get_fuel_export_options_success(fuel_export_service, mock_repo): - # (If needed, set a dummy request here as well) - from types import SimpleNamespace - dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) dummy_request = MagicMock() dummy_request.user = dummy_user @@ -55,8 +53,6 @@ async def test_get_fuel_export_options_success(fuel_export_service, mock_repo): @pytest.mark.anyio async def test_get_fuel_export_list_success(fuel_export_service, mock_repo): # Set up a dummy request with a valid user - from types import SimpleNamespace - dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) dummy_request = MagicMock() dummy_request.user = dummy_user @@ -94,8 +90,6 @@ async def test_get_fuel_export_list_success(fuel_export_service, mock_repo): @pytest.mark.anyio async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo): # Set up a dummy request with a valid user - from types import SimpleNamespace - dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) dummy_request = MagicMock() dummy_request.user = dummy_user diff --git a/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py b/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py index 18de80b7d..19a4dfef2 100644 --- a/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py +++ b/backend/lcfs/tests/fuel_supply/test_fuel_supplies_services.py @@ -83,8 +83,6 @@ async def test_get_fuel_supply_list(fuel_supply_service): service, mock_repo, _ = fuel_supply_service # Create a dummy request with a user that supports attribute access. - from types import SimpleNamespace - dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT]) dummy_request = MagicMock() dummy_request.user = dummy_user