Skip to content

Commit

Permalink
Merge branch 'release-1.0.0' into LCRS-1937-DefaultCI
Browse files Browse the repository at this point in the history
  • Loading branch information
areyeslo authored Feb 25, 2025
2 parents f7ea189 + 9bf2fa0 commit dcdf13e
Show file tree
Hide file tree
Showing 13 changed files with 536 additions and 149 deletions.
6 changes: 6 additions & 0 deletions backend/lcfs/db/models/compliance/ComplianceReportStatus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ class ComplianceReportStatusEnum(enum.Enum):
Reassessed = "Reassessed"
Rejected = "Rejected"

def underscore_value(self) -> str:
    """Return this status's display value with every space turned into an underscore."""
    # split(" ") keeps empty segments, so joining on "_" is exactly
    # equivalent to str.replace(" ", "_") even for consecutive spaces.
    return "_".join(self.value.split(" "))


class ComplianceReportStatus(BaseModel, EffectiveDates):
__tablename__ = "compliance_report_status"
Expand Down
61 changes: 30 additions & 31 deletions backend/lcfs/web/api/compliance_report/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import defaultdict
from datetime import datetime
from typing import List, Optional, Dict, Union, Tuple
from typing import List, Optional, Dict, Union

import structlog
from fastapi import Depends
from sqlalchemy import func, select, and_, asc, desc, update, or_, String, cast
from sqlalchemy import func, select, and_, asc, desc, update, String, cast
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, contains_eager, aliased
from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.inspection import inspect

from lcfs.db.dependencies import get_async_db_session
Expand Down Expand Up @@ -73,22 +73,33 @@ def apply_filters(self, pagination, conditions):
filter_type = filter.filter_type
if filter.field == "status":
field = cast(
get_field_for_filter(
ComplianceReportListView, "report_status"),
get_field_for_filter(ComplianceReportListView, "report_status"),
String,
)
# Check if filter_value is a comma-separated string
if isinstance(filter_value, str) and "," in filter_value:
filter_value = filter_value.split(",") # Convert to list

if isinstance(filter_value, list):
filter_value = [value.replace(" ", "_")
for value in filter_value]

def underscore_string(val):
    """
    Normalize a status filter value to its underscored string form.

    Enum members are first unwrapped to their string ``.value``; then
    every space is replaced with an underscore.
    """
    text = val.value if isinstance(val, ComplianceReportStatusEnum) else val
    return text.replace(" ", "_")

filter_value = [underscore_string(val) for val in filter_value]
filter_type = "set"
else:
if isinstance(filter_value, ComplianceReportStatusEnum):
filter_value = filter_value.value
filter_value = filter_value.replace(" ", "_")

elif filter.field == "type":
field = get_field_for_filter(
ComplianceReportListView, "report_type")
field = get_field_for_filter(ComplianceReportListView, "report_type")
elif filter.field == "organization":
field = get_field_for_filter(
ComplianceReportListView, "organization_name"
Expand All @@ -98,12 +109,10 @@ def apply_filters(self, pagination, conditions):
ComplianceReportListView, "compliance_period"
)
else:
field = get_field_for_filter(
ComplianceReportListView, filter.field)
field = get_field_for_filter(ComplianceReportListView, filter.field)

conditions.append(
apply_filter_conditions(
field, filter_value, filter_option, filter_type)
apply_filter_conditions(field, filter_value, filter_option, filter_type)
)

@repo_handler
Expand Down Expand Up @@ -158,8 +167,7 @@ async def get_compliance_period(self, period: str) -> CompliancePeriod:
Retrieve a compliance period from the database
"""
result = await self.db.scalar(
select(CompliancePeriod).where(
CompliancePeriod.description == period)
select(CompliancePeriod).where(CompliancePeriod.description == period)
)
return result

Expand Down Expand Up @@ -198,8 +206,7 @@ async def get_compliance_report_status_by_desc(
Retrieve the compliance report status ID from the database based on the description.
Replaces spaces with underscores in the status description.
"""
status_enum = status.replace(
" ", "_") # frontend sends status with spaces
status_enum = status.replace(" ", "_") # frontend sends status with spaces
result = await self.db.execute(
select(ComplianceReportStatus).where(
ComplianceReportStatus.status
Expand Down Expand Up @@ -386,17 +393,15 @@ async def get_reports_paginated(
self.apply_filters(pagination, conditions)

# Pagination and offset setup
offset = 0 if (pagination.page < 1) else (
pagination.page - 1) * pagination.size
offset = 0 if (pagination.page < 1) else (pagination.page - 1) * pagination.size
limit = pagination.size

# Build the main query
query = query.where(and_(*conditions))

# Apply sorting from pagination
if len(pagination.sort_orders) < 1:
field = get_field_for_filter(
ComplianceReportListView, "update_date")
field = get_field_for_filter(ComplianceReportListView, "update_date")
query = query.order_by(desc(field))
for order in pagination.sort_orders:
sort_method = asc if order.direction == "asc" else desc
Expand Down Expand Up @@ -731,15 +736,13 @@ def aggregate_quantities(
isinstance(record, FuelSupply)
and record.fuel_type.fossil_derived == fossil_derived
):
fuel_category = self._format_category(
record.fuel_category.category)
fuel_category = self._format_category(record.fuel_category.category)
fuel_quantities[fuel_category] += record.quantity
elif (
isinstance(record, OtherUses)
and record.fuel_type.fossil_derived == fossil_derived
):
fuel_category = self._format_category(
record.fuel_category.category)
fuel_category = self._format_category(record.fuel_category.category)
fuel_quantities[fuel_category] += record.quantity_supplied

return dict(fuel_quantities)
Expand Down Expand Up @@ -891,15 +894,11 @@ async def get_compliance_report_group_id(self, report_id):

@repo_handler
async def get_changelog_data(
self,
pagination: PaginationRequestSchema,
compliance_report_id: int,
selection
self, pagination: PaginationRequestSchema, compliance_report_id: int, selection
):

conditions = [selection.compliance_report_id == compliance_report_id]
offset = 0 if pagination.page < 1 else (
pagination.page - 1) * pagination.size
offset = 0 if pagination.page < 1 else (pagination.page - 1) * pagination.size
limit = pagination.size

# Create an alias for the previous version row.
Expand Down
37 changes: 19 additions & 18 deletions backend/lcfs/web/api/compliance_report/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Union, Type

import structlog
from fastapi import Depends, Request
from fastapi import Depends

from lcfs.db.models.compliance.ComplianceReport import (
ComplianceReport,
Expand All @@ -29,7 +29,6 @@
from lcfs.web.api.organization_snapshot.services import OrganizationSnapshotService
from lcfs.web.core.decorators import service_handler
from lcfs.web.exception.exceptions import DataNotFoundException, ServiceException
from lcfs.db.base import ActionTypeEnum

logger = structlog.get_logger(__name__)

Expand All @@ -39,7 +38,6 @@ def __init__(
self,
repo: ComplianceReportRepository = Depends(),
snapshot_services: OrganizationSnapshotService = Depends(),

) -> None:
self.repo = repo
self.snapshot_services = snapshot_services
Expand All @@ -66,8 +64,7 @@ async def create_compliance_report(
report_data.status
)
if not draft_status:
raise DataNotFoundException(
f"Status '{report_data.status}' not found.")
raise DataNotFoundException(f"Status '{report_data.status}' not found.")

# Generate a new group_uuid for the new report series
group_uuid = str(uuid.uuid4())
Expand Down Expand Up @@ -173,7 +170,10 @@ async def create_supplemental_report(

@service_handler
async def get_compliance_reports_paginated(
self, pagination, organization_id: int = None, bceid_user: bool = False
self,
pagination,
organization_id: int = None,
bceid_user: bool = False,
):
"""Fetches all compliance reports"""
if bceid_user:
Expand Down Expand Up @@ -208,8 +208,8 @@ async def get_compliance_reports_paginated(

def _mask_report_status(self, reports: List) -> List:
recommended_statuses = {
ComplianceReportStatusEnum.Recommended_by_analyst.value,
ComplianceReportStatusEnum.Recommended_by_manager.value,
ComplianceReportStatusEnum.Recommended_by_analyst.underscore_value(),
ComplianceReportStatusEnum.Recommended_by_manager.underscore_value(),
}

masked_reports = []
Expand Down Expand Up @@ -263,8 +263,7 @@ async def get_compliance_report_by_id(

if apply_masking:
# Apply masking to each report in the chain
masked_chain = self._mask_report_status(
compliance_report_chain)
masked_chain = self._mask_report_status(compliance_report_chain)
# Apply history masking to each report in the chain
masked_chain = [
self._mask_report_status_for_history(report, apply_masking)
Expand Down Expand Up @@ -317,7 +316,7 @@ def _model_to_dict(self, record) -> dict:
"""Safely convert a model to a dict, skipping lazy-loaded attributes that raise errors."""
result = {}
for key, value in record.__dict__.items():
if key == '_sa_instance_state':
if key == "_sa_instance_state":
continue
try:
result[key] = value
Expand All @@ -330,10 +329,11 @@ async def get_changelog_data(
self,
pagination: PaginationResponseSchema,
compliance_report_id: int,
selection: Type[Union[FuelSupply, OtherUses,
NotionalTransfer, FuelExport]]
selection: Type[Union[FuelSupply, OtherUses, NotionalTransfer, FuelExport]],
):
changelog, total_count = await self.repo.get_changelog_data(pagination, compliance_report_id, selection)
changelog, total_count = await self.repo.get_changelog_data(
pagination, compliance_report_id, selection
)

groups = {}
for record in changelog:
Expand All @@ -359,12 +359,13 @@ async def get_changelog_data(
changelog = [record for group in groups.values() for record in group]

return {
'pagination': PaginationResponseSchema(
"pagination": PaginationResponseSchema(
total=total_count,
page=pagination.page,
size=pagination.size,
total_pages=math.ceil(
total_count / pagination.size) if pagination.size else 0,
total_pages=(
math.ceil(total_count / pagination.size) if pagination.size else 0
),
),
'changelog': changelog,
"changelog": changelog,
}
43 changes: 28 additions & 15 deletions backend/lcfs/web/api/final_supply_equipment/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,14 @@ async def import_async(
socket_connect_timeout=5,
)

await _update_progress(redis_client, job_id, 5, "Initializing services...")
await _update_progress(
redis_client, job_id, 5, "Initializing services..."
)

if overwrite:
await _update_progress(redis_client, job_id, 10, "Deleting old data...")
await _update_progress(
redis_client, job_id, 10, "Deleting old data..."
)
await fse_service.delete_all(compliance_report_id)
org_code = user.organization.organization_code
await fse_repo.reset_seq_by_org(org_code)
Expand All @@ -165,7 +169,9 @@ async def import_async(
)
clamav_service.scan_file(file)

await _update_progress(redis_client, job_id, 20, "Loading Excel sheet...")
await _update_progress(
redis_client, job_id, 20, "Loading Excel sheet..."
)

try:
sheet = _load_sheet(file)
Expand All @@ -188,7 +194,9 @@ async def import_async(
valid_intended_users = await fse_repo.get_intended_user_types()
valid_use_types = await fse_repo.get_intended_use_types()
valid_use_type_names = {obj.type for obj in valid_use_types}
valid_user_type_names = {obj.type_name for obj in valid_intended_users}
valid_user_type_names = {
obj.type_name for obj in valid_intended_users
}

# Iterate through all data rows, skipping the header
for row_idx, row in enumerate(
Expand All @@ -210,6 +218,10 @@ async def import_async(
errors=errors,
)

# Check if the entire row is empty
if all(cell is None for cell in row):
continue

# Validate row
error = _validate_row(
row, row_idx, valid_use_type_names, valid_user_type_names
Expand All @@ -222,7 +234,9 @@ async def import_async(
# Parse row data and insert into DB
try:
fse_data = _parse_row(row, compliance_report_id)
await fse_service.create_final_supply_equipment(fse_data, user)
await fse_service.create_final_supply_equipment(
fse_data, user
)
created += 1
except Exception as ex:
logger.error(str(ex))
Expand All @@ -239,7 +253,9 @@ async def import_async(
rejected=rejected,
errors=errors,
)
logger.debug(f"Completed importing FSE data, {created} rows created")
logger.debug(
f"Completed importing FSE data, {created} rows created"
)

return {
"success": True,
Expand All @@ -266,6 +282,7 @@ async def import_async(
finally:
await engine.dispose()


def _load_sheet(file: UploadFile) -> Worksheet:
"""
Loads and returns the 'FSE' worksheet from the provided Excel file.
Expand Down Expand Up @@ -307,10 +324,6 @@ def _validate_row(
notes,
) = row

# Check if the entire row is empty
if all(cell is None for cell in row):
return f"Row {row_idx}: Row is empty"

missing_fields = []
if supply_from_date is None:
missing_fields.append("Supply from date")
Expand Down Expand Up @@ -404,15 +417,15 @@ def _parse_row(
supply_from_date=supply_from_date,
supply_to_date=supply_to_date,
kwh_usage=kwh_usage,
serial_nbr=serial_number or "",
manufacturer=manufacturer or "",
model=model or "",
serial_nbr=str(serial_number) or "",
manufacturer=str(manufacturer) or "",
model=str(model) or "",
level_of_equipment=level_of_equipment or "",
ports=PortsEnum(ports) if ports else None,
intended_uses=intended_uses,
intended_users=intended_users,
street_address=street_address or "",
city=city or "",
street_address=str(street_address) or "",
city=str(city) or "",
postal_code=postal_code or "",
latitude=latitude,
longitude=longitude,
Expand Down
Loading

0 comments on commit dcdf13e

Please sign in to comment.