Skip to content

Commit

Permalink
pytest fixes after change
Browse files Browse the repository at this point in the history
  • Loading branch information
areyeslo committed Feb 25, 2025
1 parent 38d630f commit f7ea189
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 195 deletions.
2 changes: 0 additions & 2 deletions backend/lcfs/db/models/fuel/FuelCategory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional
from pydantic import computed_field
from sqlalchemy import Column, Integer, Text, Enum, Numeric
from sqlalchemy.orm import relationship

Expand Down
2 changes: 0 additions & 2 deletions backend/lcfs/db/models/fuel/FuelType.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import enum
from typing import Optional

from pydantic import computed_field
from sqlalchemy import Column, Integer, Text, Boolean, Enum, Numeric, text
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def test_create_fuel_export_success(

# Call the method under test with compliance_period
result = await fuel_export_action_service.create_fuel_export(
fe_data, user_type, "2024"
fe_data, user_type
)
# Assertions
assert result == FuelExportSchema.model_validate(created_export)
Expand Down Expand Up @@ -238,7 +238,7 @@ async def test_update_fuel_export_success_existing_report(

# Call the method under test with compliance_period
result = await fuel_export_action_service.update_fuel_export(
fe_data, user_type, "2024"
fe_data, user_type
)

# Assertions
Expand Down Expand Up @@ -318,7 +318,7 @@ async def test_update_fuel_export_create_new_version(

# Call the method under test with compliance_period
result = await fuel_export_action_service.update_fuel_export(
fe_data, user_type, "2024"
fe_data, user_type
)

# Assertions
Expand Down Expand Up @@ -481,7 +481,6 @@ async def test_compliance_units_calculation(
case, fuel_export_action_service, mock_repo, mock_fuel_code_repo
):
# Mock repository methods
mock_repo.get_compliance_period_id = AsyncMock(return_value=1)
mock_repo.create_fuel_export = AsyncMock()
mock_repo.get_fuel_export_by_id = AsyncMock()

Expand Down Expand Up @@ -555,11 +554,10 @@ async def create_fuel_export_side_effect(fuel_export: FuelExport):
# Set up repository mocks
mock_created_export = await create_fuel_export_side_effect(FuelExport())
mock_repo.create_fuel_export.return_value = mock_created_export
mock_repo.get_fuel_export_by_id.return_value = mock_created_export

# Call service method
result = await fuel_export_action_service.create_fuel_export(
fe_data, UserTypeEnum.SUPPLIER, compliance_period=fe_data.compliance_period
fe_data, UserTypeEnum.SUPPLIER
)

# Assertions
Expand All @@ -572,11 +570,7 @@ async def create_fuel_export_side_effect(fuel_export: FuelExport):
assert result.quantity == fe_data.quantity

# Verify mock calls
mock_repo.get_compliance_period_id.assert_awaited_once_with(
fe_data.compliance_period
)
mock_repo.create_fuel_export.assert_awaited_once()
mock_repo.get_fuel_export_by_id.assert_awaited_once_with(1, 1)
mock_fuel_code_repo.get_standardized_fuel_data.assert_awaited_once_with(
fuel_type_id=fe_data.fuel_type_id,
fuel_category_id=fe_data.fuel_category_id,
Expand Down
3 changes: 1 addition & 2 deletions backend/lcfs/tests/fuel_export/test_fuel_exports_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ async def mock_get_effective_exports(*args, **kwargs):
@pytest.mark.anyio
async def test_get_fuel_export_by_id_success(fuel_export_repo, mock_db):
fuel_export_id = 1
compliance_period_id = 1
expected_fuel_export = FuelExport(fuel_export_id=fuel_export_id)

mock_result = MagicMock()
Expand All @@ -115,7 +114,7 @@ async def test_get_fuel_export_by_id_success(fuel_export_repo, mock_db):
)
mock_db.execute.return_value = mock_result

result = await fuel_export_repo.get_fuel_export_by_id(fuel_export_id, compliance_period_id)
result = await fuel_export_repo.get_fuel_export_by_id(fuel_export_id)

mock_db.execute.assert_called_once()
mock_result.unique.assert_called_once()
Expand Down
12 changes: 1 addition & 11 deletions backend/lcfs/tests/fuel_export/test_fuel_exports_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,14 @@ async def test_action_create_fuel_export_success(fuel_export_action_service, moc
fuel_category=mock_fuel_category.dict(),
)
mock_repo.create_fuel_export = AsyncMock(return_value=mock_created_export)
mock_repo.get_compliance_period_id = AsyncMock(return_value=1)
mock_repo.get_fuel_export_by_id = AsyncMock(return_value=mock_created_export)

result = await fuel_export_action_service.create_fuel_export(
input_data,
UserTypeEnum.SUPPLIER,
compliance_period=input_data.compliance_period,
UserTypeEnum.SUPPLIER
)

assert isinstance(result, FuelExportSchema)
mock_repo.create_fuel_export.assert_awaited_once()
mock_repo.get_compliance_period_id.assert_awaited_once()
mock_repo.get_fuel_export_by_id.assert_awaited_once()


@pytest.mark.anyio
Expand Down Expand Up @@ -219,20 +214,15 @@ async def test_action_update_fuel_export_success(fuel_export_action_service, moc
return_value=mock_existing_export
)
mock_repo.update_fuel_export = AsyncMock(return_value=mock_existing_export)
mock_repo.get_compliance_period_id = AsyncMock(return_value=1)
mock_repo.get_fuel_export_by_id = AsyncMock(return_value=mock_existing_export)

result = await fuel_export_action_service.update_fuel_export(
input_data,
UserTypeEnum.SUPPLIER,
compliance_period=input_data.compliance_period,
)

assert isinstance(result, FuelExportSchema)
mock_repo.get_fuel_export_version_by_user.assert_awaited_once()
mock_repo.update_fuel_export.assert_awaited_once()
mock_repo.get_compliance_period_id.assert_awaited_once()
mock_repo.get_fuel_export_by_id.assert_awaited_once()


@pytest.mark.anyio
Expand Down
2 changes: 1 addition & 1 deletion backend/lcfs/web/api/allocation_agreement/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AllocationAgreementCreateSchema(BaseSchema):
transaction_partner_phone: str
fuel_type: str
fuel_type_other: Optional[str] = None
ci_of_fuel: float
ci_of_fuel: Optional[float] = 0
provision_of_the_act: str
quantity: int = Field(..., gt=0, description="Quantity must be greater than 0")
units: str
Expand Down
24 changes: 9 additions & 15 deletions backend/lcfs/web/api/fuel_code/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,12 @@ async def get_formatted_fuel_types(
conditions.append(FuelType.is_legacy == False)


# Build the query with filtered fuel_codes
# Build the query with filtered fuel_codes and compliance period joins
query = (
select(FuelType)
.outerjoin(FuelType.fuel_instances)
.outerjoin(FuelInstance.fuel_category)
.outerjoin(FuelType.fuel_codes)
.outerjoin(FuelType.default_carbon_intensities)
)

# Add compliance period dependent joins if period is provided
Expand All @@ -131,6 +130,13 @@ async def get_formatted_fuel_types(
EnergyEffectivenessRatio.fuel_category_id == FuelCategory.fuel_category_id
)
)
.outerjoin(
DefaultCarbonIntensity,
and_(
DefaultCarbonIntensity.fuel_type_id == FuelType.fuel_type_id,
DefaultCarbonIntensity.compliance_period_id == compliance_period_id
)
)
)

query = (
Expand All @@ -151,19 +157,10 @@ async def get_formatted_fuel_types(
# Prepare the data in the format matching your schema
formatted_fuel_types = []
for fuel_type in fuel_types:
# Get default CI for specific compliance period if provided
default_ci = None
if compliance_period_id:
default_ci = next(
(dci.default_carbon_intensity
for dci in fuel_type.default_carbon_intensities
if dci.compliance_period_id == compliance_period_id),
None
)
formatted_fuel_type = {
"fuel_type_id": fuel_type.fuel_type_id,
"fuel_type": fuel_type.fuel_type,
"default_carbon_intensity": default_ci,
"default_carbon_intensity": fuel_type.default_carbon_intensity,
"units": fuel_type.units if fuel_type.units else None,
"unrecognized": fuel_type.unrecognized,
"fuel_categories": [
Expand Down Expand Up @@ -1008,9 +1005,6 @@ async def get_standardized_fuel_data(
effective_carbon_intensity = fuel_code.carbon_intensity
# Other Fuel uses the Default CI of the Category
elif fuel_type.unrecognized:
fuel_category = await self.get_fuel_category_by(
fuel_category_id=fuel_category_id
)
effective_carbon_intensity = await self.get_category_carbon_intensity(
fuel_category_id=fuel_category_id,
compliance_period=compliance_period
Expand Down
37 changes: 5 additions & 32 deletions backend/lcfs/web/api/fuel_export/actions_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def _populate_fuel_export_fields(

@service_handler
async def create_fuel_export(
self, fe_data: FuelExportCreateUpdateSchema, user_type: UserTypeEnum, compliance_period: str,
self, fe_data: FuelExportCreateUpdateSchema, user_type: UserTypeEnum
) -> FuelExportSchema:
"""
Create a new fuel export record.
Expand All @@ -105,9 +105,6 @@ async def create_fuel_export(
Returns the newly created fuel export record as a response schema.
"""
# Get compliance period ID
compliance_period_id = await self.repo.get_compliance_period_id(compliance_period)

# Assign a unique group UUID for the new fuel export
new_group_uuid = str(uuid.uuid4())
fuel_export = FuelExport(
Expand All @@ -123,18 +120,11 @@ async def create_fuel_export(

# Save the populated fuel export record
created_export = await self.repo.create_fuel_export(fuel_export)

# Fetch with compliance period
result = await self.repo.get_fuel_export_by_id(
created_export.fuel_export_id,
compliance_period_id
)

return FuelExportSchema.model_validate(result)
return FuelExportSchema.model_validate(created_export)

@service_handler
async def update_fuel_export(
self, fe_data: FuelExportCreateUpdateSchema, user_type: UserTypeEnum, compliance_period: str,
self, fe_data: FuelExportCreateUpdateSchema, user_type: UserTypeEnum
) -> FuelExportSchema:
"""
Update an existing fuel export record or create a new version if necessary.
Expand All @@ -146,9 +136,6 @@ async def update_fuel_export(
Returns the updated or new version of the fuel export record.
"""
# Get compliance period ID
compliance_period_id = await self.repo.get_compliance_period_id(compliance_period)

existing_export = await self.repo.get_fuel_export_version_by_user(
fe_data.group_uuid, fe_data.version, user_type
)
Expand All @@ -169,14 +156,7 @@ async def update_fuel_export(
)

updated_export = await self.repo.update_fuel_export(existing_export)

# Fetch with compliance period
result = await self.repo.get_fuel_export_by_id(
updated_export.fuel_export_id,
compliance_period_id
)

return FuelExportSchema.model_validate(result)
return FuelExportSchema.model_validate(updated_export)

elif existing_export:
# Create a new version if compliance report ID differs
Expand Down Expand Up @@ -204,14 +184,7 @@ async def update_fuel_export(

# Save the new version
new_export = await self.repo.create_fuel_export(fuel_export)

# Fetch with compliance period
result = await self.repo.get_fuel_export_by_id(
new_export.fuel_export_id,
compliance_period_id
)

return FuelExportSchema.model_validate(result)
return FuelExportSchema.model_validate(new_export)

raise HTTPException(
status_code=404, detail="Fuel export record not found.")
Expand Down
60 changes: 13 additions & 47 deletions backend/lcfs/web/api/fuel_export/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,6 @@ def __init__(self, db: AsyncSession = Depends(get_async_db_session)):
joinedload(FuelExport.end_use_type),
)

@repo_handler
async def get_compliance_period_id(self, compliance_period: str) -> int:
"""Get compliance period ID from description"""
query = (
select(CompliancePeriod.compliance_period_id)
.where(CompliancePeriod.description == compliance_period)
)
result = await self.db.execute(query)
return result.scalar_one_or_none()

@repo_handler
async def get_default_carbon_intensity(self, fuel_type_id: int, compliance_period_id: int) -> Optional[float]:
"""
Get default carbon intensity for a specific fuel type and compliance period
"""
query = (
select(DefaultCarbonIntensity.default_carbon_intensity)
.where(
and_(
DefaultCarbonIntensity.fuel_type_id == fuel_type_id,
DefaultCarbonIntensity.compliance_period_id == compliance_period_id
)
)
)
result = await self.db.execute(query)
default_ci = result.scalar_one_or_none()
return float(default_ci) if default_ci is not None else None


@repo_handler
async def get_fuel_export_table_options(self, compliance_period: str):
Expand Down Expand Up @@ -157,14 +129,17 @@ async def get_fuel_export_table_options(self, compliance_period: str):
DefaultCarbonIntensity,
and_(
DefaultCarbonIntensity.fuel_type_id == FuelType.fuel_type_id,
DefaultCarbonIntensity.compliance_period_id == subquery_compliance_period_id
DefaultCarbonIntensity.compliance_period_id
== subquery_compliance_period_id,
),
)
.outerjoin(
CategoryCarbonIntensity,
and_(
CategoryCarbonIntensity.fuel_category_id == FuelCategory.fuel_category_id,
CategoryCarbonIntensity.compliance_period_id == subquery_compliance_period_id
CategoryCarbonIntensity.fuel_category_id
== FuelCategory.fuel_category_id,
CategoryCarbonIntensity.compliance_period_id
== subquery_compliance_period_id,
),
)
.outerjoin(
Expand All @@ -184,8 +159,8 @@ async def get_fuel_export_table_options(self, compliance_period: str):
EnergyDensity,
and_(
EnergyDensity.fuel_type_id == FuelType.fuel_type_id,
EnergyDensity.compliance_period_id == subquery_compliance_period_id
)
EnergyDensity.compliance_period_id == subquery_compliance_period_id,
),
)
.outerjoin(UnitOfMeasure, EnergyDensity.uom_id == UnitOfMeasure.uom_id)
.outerjoin(
Expand All @@ -194,7 +169,8 @@ async def get_fuel_export_table_options(self, compliance_period: str):
EnergyEffectivenessRatio.fuel_category_id
== FuelCategory.fuel_category_id,
EnergyEffectivenessRatio.fuel_type_id == FuelInstance.fuel_type_id,
EnergyEffectivenessRatio.compliance_period_id == subquery_compliance_period_id
EnergyEffectivenessRatio.compliance_period_id
== subquery_compliance_period_id,
),
)
.outerjoin(
Expand Down Expand Up @@ -293,23 +269,13 @@ async def get_fuel_exports_paginated(
return paginated_exports, total_count

@repo_handler
async def get_fuel_export_by_id(self, fuel_export_id: int, compliance_period_id:int) -> FuelExport:
async def get_fuel_export_by_id(self, fuel_export_id: int) -> FuelExport:
"""
Retrieve a fuel export row from the database with compliance period filtering
Retrieve a fuel supply row from the database
"""
query = self.query.where(FuelExport.fuel_export_id == fuel_export_id)
result = await self.db.execute(query)
fuel_export = result.unique().scalar_one_or_none()

if fuel_export and fuel_export.fuel_type and compliance_period_id:
default_ci = await self.get_default_carbon_intensity(
fuel_export.fuel_type.fuel_type_id,
compliance_period_id
)
# Add default_ci to the fuel type object
setattr(fuel_export.fuel_type, 'default_carbon_intensity', default_ci)

return fuel_export
return result.unique().scalar_one_or_none()

@repo_handler
async def update_fuel_export(self, fuel_export: FuelExport) -> FuelExport:
Expand Down
Loading

0 comments on commit f7ea189

Please sign in to comment.