Skip to content

Commit

Permalink
Merge pull request #1998 from bcgov/fix/alex-draft-schedule-filter-1931
Browse files Browse the repository at this point in the history
Feat: Draft Schedule Records Exclusions - 1931
  • Loading branch information
AlexZorkin authored Feb 19, 2025
2 parents a554365 + 7d52214 commit 7c89e0a
Show file tree
Hide file tree
Showing 11 changed files with 394 additions and 141 deletions.
54 changes: 35 additions & 19 deletions backend/lcfs/tests/fuel_export/test_fuel_exports_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from lcfs.web.api.fuel_export.repo import FuelExportRepository
from lcfs.db.models.compliance.FuelExport import FuelExport
from lcfs.db.base import ActionTypeEnum, UserTypeEnum
from lcfs.db.models.user.Role import RoleEnum
from types import SimpleNamespace

# Mock common data for reuse
mock_fuel_type = FuelTypeSchema(
Expand All @@ -32,10 +34,16 @@
category="Diesel",
)


# FuelExportServices Tests


@pytest.mark.anyio
async def test_get_fuel_export_options_success(fuel_export_service, mock_repo):
dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT])
dummy_request = MagicMock()
dummy_request.user = dummy_user
fuel_export_service.request = dummy_request

mock_repo.get_fuel_export_table_options.return_value = []
result = await fuel_export_service.get_fuel_export_options("2024")
assert isinstance(result, FuelTypeOptionsResponse)
Expand All @@ -44,6 +52,12 @@ async def test_get_fuel_export_options_success(fuel_export_service, mock_repo):

@pytest.mark.anyio
async def test_get_fuel_export_list_success(fuel_export_service, mock_repo):
# Set up a dummy request with a valid user
dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT])
dummy_request = MagicMock()
dummy_request.user = dummy_user
fuel_export_service.request = dummy_request

# Create a mock FuelExport with all required fields
mock_export = FuelExport(
fuel_export_id=1,
Expand All @@ -57,10 +71,7 @@ async def test_get_fuel_export_list_success(fuel_export_service, mock_repo):
export_date=date.today(),
group_uuid="test-uuid",
provision_of_the_act_id=1,
provision_of_the_act={
"provision_of_the_act_id": 1,
"name": "Test Provision"
},
provision_of_the_act={"provision_of_the_act_id": 1, "name": "Test Provision"},
version=0,
user_type=UserTypeEnum.SUPPLIER,
action_type=ActionTypeEnum.CREATE,
Expand All @@ -70,11 +81,20 @@ async def test_get_fuel_export_list_success(fuel_export_service, mock_repo):
result = await fuel_export_service.get_fuel_export_list(1)

assert isinstance(result, FuelExportsSchema)
mock_repo.get_fuel_export_list.assert_called_once_with(1)
# Expect the repo call to include exclude_draft_reports=True based on the user
mock_repo.get_fuel_export_list.assert_called_once_with(
1, exclude_draft_reports=True
)


@pytest.mark.anyio
async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo):
# Set up a dummy request with a valid user
dummy_user = SimpleNamespace(id=1, role_names=[RoleEnum.GOVERNMENT])
dummy_request = MagicMock()
dummy_request.user = dummy_user
fuel_export_service.request = dummy_request

mock_export = FuelExport(
fuel_export_id=1,
compliance_report_id=1,
Expand All @@ -86,10 +106,7 @@ async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo
units="L",
export_date=date.today(),
provision_of_the_act_id=1,
provision_of_the_act={
"provision_of_the_act_id": 1,
"name": "Test Provision"
},
provision_of_the_act={"provision_of_the_act_id": 1, "name": "Test Provision"},
)
mock_repo.get_fuel_exports_paginated.return_value = ([mock_export], 1)

Expand All @@ -103,10 +120,15 @@ async def test_get_fuel_exports_paginated_success(fuel_export_service, mock_repo
assert result.pagination.total == 1
assert result.pagination.page == 1
assert result.pagination.size == 10
mock_repo.get_fuel_exports_paginated.assert_called_once_with(pagination_mock, 1)
# Expect the extra parameter to be passed
mock_repo.get_fuel_exports_paginated.assert_called_once_with(
pagination_mock, 1, exclude_draft_reports=True
)


# FuelExportActionService Tests


@pytest.mark.anyio
async def test_action_create_fuel_export_success(fuel_export_action_service, mock_repo):
input_data = FuelExportCreateUpdateSchema(
Expand All @@ -131,10 +153,7 @@ async def test_action_create_fuel_export_success(fuel_export_action_service, moc
user_type=UserTypeEnum.SUPPLIER,
action_type=ActionTypeEnum.CREATE,
provision_of_the_act_id=1,
provision_of_the_act={
"provision_of_the_act_id": 1,
"name": "Act Provision"
},
provision_of_the_act={"provision_of_the_act_id": 1, "name": "Act Provision"},
fuel_type_id=1,
fuel_category_id=1,
quantity=100,
Expand Down Expand Up @@ -181,10 +200,7 @@ async def test_action_update_fuel_export_success(fuel_export_action_service, moc
fuel_type_id=1,
fuel_category_id=1,
provision_of_the_act_id=1,
provision_of_the_act={
"provision_of_the_act_id": 1,
"name": "Act Provision"
},
provision_of_the_act={"provision_of_the_act_id": 1, "name": "Act Provision"},
quantity=100,
units="L",
export_date=date.today(),
Expand Down
116 changes: 100 additions & 16 deletions backend/lcfs/tests/fuel_supply/test_fuel_supplies_repo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import math
import pytest
from unittest.mock import MagicMock, AsyncMock
from sqlalchemy.ext.asyncio import AsyncSession

from lcfs.db.models.compliance import FuelSupply
from lcfs.web.api.fuel_supply.repo import FuelSupplyRepository
from lcfs.web.api.fuel_supply.schema import FuelSupplyCreateUpdateSchema
from lcfs.web.api.fuel_supply.schema import (
FuelSuppliesSchema,
PaginationResponseSchema,
FuelSupplyResponseSchema,
)
from lcfs.web.api.base import PaginationRequestSchema


@pytest.fixture
def mock_db_session():
session = AsyncMock(spec=AsyncSession)

# Create a mock that properly mimics SQLAlchemy's async result chain
async def mock_execute(*args, **kwargs):
mock_result = (
MagicMock()
) # Changed to MagicMock since the chained methods are sync
mock_result = MagicMock()
mock_result.scalars = MagicMock(return_value=mock_result)
mock_result.unique = MagicMock(return_value=mock_result)
mock_result.all = MagicMock(return_value=[MagicMock(spec=FuelSupply)])
Expand All @@ -36,24 +40,34 @@ def fuel_supply_repo(mock_db_session):


@pytest.mark.anyio
async def test_get_fuel_supply_list(fuel_supply_repo, mock_db_session):
async def test_get_fuel_supply_list_exclude_draft_reports(
fuel_supply_repo, mock_db_session
):
compliance_report_id = 1
mock_result = [MagicMock(spec=FuelSupply)]
expected_fuel_supplies = [MagicMock(spec=FuelSupply)]

# Set up the mock to return our desired result
# Set up the mock result chain with proper method chaining.
mock_result_chain = MagicMock()
mock_result_chain.scalars = MagicMock(return_value=mock_result_chain)
mock_result_chain.unique = MagicMock(return_value=mock_result_chain)
mock_result_chain.all = MagicMock(return_value=mock_result)
mock_result_chain.all = MagicMock(return_value=expected_fuel_supplies)

async def mock_execute(*args, **kwargs):
async def mock_execute(query, *args, **kwargs):
return mock_result_chain

mock_db_session.execute = mock_execute

result = await fuel_supply_repo.get_fuel_supply_list(compliance_report_id)
# Test when drafts should be excluded (e.g. government user).
result_gov = await fuel_supply_repo.get_fuel_supply_list(
compliance_report_id, exclude_draft_reports=True
)
assert result_gov == expected_fuel_supplies

assert result == mock_result
# Test when drafts are not excluded.
result_non_gov = await fuel_supply_repo.get_fuel_supply_list(
compliance_report_id, exclude_draft_reports=False
)
assert result_non_gov == expected_fuel_supplies


@pytest.mark.anyio
Expand All @@ -80,19 +94,89 @@ async def test_check_duplicate(fuel_supply_repo, mock_db_session):
units="L",
)

# Set up the mock chain using regular MagicMock since the chained methods are sync
# Set up the mock chain using MagicMock for synchronous chained methods.
mock_result_chain = MagicMock()
mock_result_chain.scalars = MagicMock(return_value=mock_result_chain)
mock_result_chain.first = MagicMock(
return_value=MagicMock(spec=FuelSupply))
mock_result_chain.first = MagicMock(return_value=MagicMock(spec=FuelSupply))

# Define an async execute function that returns our mock chain
async def mock_execute(*args, **kwargs):
return mock_result_chain

# Replace the session's execute with our new mock
mock_db_session.execute = mock_execute

result = await fuel_supply_repo.check_duplicate(fuel_supply_data)

assert result is not None


@pytest.mark.anyio
async def test_get_fuel_supplies_paginated_exclude_draft_reports(fuel_supply_repo):
# Define a sample pagination request.
pagination = PaginationRequestSchema(page=1, size=10)
compliance_report_id = 1
total_count = 20

# Build a valid fuel supply record that passes validation.
valid_fuel_supply = {
"fuel_supply_id": 1,
"complianceReportId": 1,
"version": 0,
"fuelTypeId": 1,
"quantity": 100,
"groupUuid": "some-uuid",
"userType": "SUPPLIER",
"actionType": "CREATE",
"fuelType": {"fuel_type_id": 1, "fuelType": "Diesel", "units": "L"},
"fuelCategory": {"fuel_category_id": 1, "category": "Diesel"},
"endUseType": {"endUseTypeId": 1, "type": "Transport", "subType": "Personal"},
"provisionOfTheAct": {"provisionOfTheActId": 1, "name": "Act Provision"},
"compliancePeriod": "2024",
"units": "L",
"fuelCode": {
"fuelStatus": {"status": "Approved"},
"fuelCode": "FUEL123",
"carbonIntensity": 15.0,
},
"fuelTypeOther": "Optional",
}
expected_fuel_supplies = [valid_fuel_supply]

async def mock_get_fuel_supplies_paginated(
pagination, compliance_report_id, exclude_draft_reports
):
total_pages = math.ceil(total_count / pagination.size) if total_count > 0 else 0
pagination_response = PaginationResponseSchema(
page=pagination.page,
size=pagination.size,
total=total_count,
total_pages=total_pages,
)
processed = [
FuelSupplyResponseSchema.model_validate(fs) for fs in expected_fuel_supplies
]
return FuelSuppliesSchema(
pagination=pagination_response, fuel_supplies=processed
)

fuel_supply_repo.get_fuel_supplies_paginated = AsyncMock(
side_effect=mock_get_fuel_supplies_paginated
)

result = await fuel_supply_repo.get_fuel_supplies_paginated(
pagination, compliance_report_id, exclude_draft_reports=True
)

# Validate pagination values.
assert result.pagination.page == pagination.page
assert result.pagination.size == pagination.size
assert result.pagination.total == total_count
expected_total_pages = (
math.ceil(total_count / pagination.size) if total_count > 0 else 0
)
assert result.pagination.total_pages == expected_total_pages

# Validate that the fuel supplies list is correctly transformed.
expected_processed = [
FuelSupplyResponseSchema.model_validate(fs) for fs in expected_fuel_supplies
]
assert result.fuel_supplies == expected_processed
Loading

0 comments on commit 7c89e0a

Please sign in to comment.