diff --git a/connect_ext_ppr/filters.py b/connect_ext_ppr/filters.py index e9770f1..373606e 100644 --- a/connect_ext_ppr/filters.py +++ b/connect_ext_ppr/filters.py @@ -1,7 +1,7 @@ from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, validator from fastapi_filter import FilterDepends, with_prefix from fastapi_filter.contrib.sqlalchemy import Filter @@ -9,44 +9,121 @@ Deployment, DeploymentRequest, MarketplaceConfiguration, ) from connect_ext_ppr.models.ppr import PPRVersion +from connect_ext_ppr.models.replicas import Product, Account from connect_ext_ppr.models.task import Task +# TODO: Make mixin +def restrict_sortable_fields_base(value, allowed_field_names): + if value is None: + return None + + for field_name in value: + field_name = field_name.replace("+", "").replace("-", "") # + + if field_name not in allowed_field_names: + raise ValueError(f"You may only sort by: {', '.join(allowed_field_names)}") + + return value + + +class ProductFilter(Filter): + name: Optional[str] + name__ilike: Optional[str] + + class Constants(Filter.Constants): + model = Product + + +class VendorFilter(Filter): + name: Optional[str] + name__ilike: Optional[str] + + class Constants(Filter.Constants): + model = Account + + class DeploymentFilter(Filter): + id: Optional[str] hub_id: Optional[str] + hub__name: Optional[str] # custom filter + hub__name__like: Optional[str] # custom filter product_id: Optional[str] + product: Optional[ProductFilter] = FilterDepends(with_prefix('product', ProductFilter)) status: Optional[str] vendor_id: Optional[str] + vendor: Optional[VendorFilter] = FilterDepends(with_prefix('vendor', VendorFilter)) + + order_by: Optional[list[str]] + custom_order_by: Optional[str] + + @validator("order_by") + @classmethod + def restrict_sortable_fields(cls, value): + return restrict_sortable_fields_base(value, ['id', 'product_id']) + + @validator("hub__name", "hub__name__like", "custom_order_by") + @classmethod + def validate_hub_name(cls, value): + pass class Constants(Filter.Constants): model = Deployment class PPRVersionFilter(Filter): - id: Optional[str] version: Optional[int] + status: Optional[str] + description: Optional[str] + description__ilike: Optional[str] order_by: Optional[List[str]] + @validator("order_by") + @classmethod + def restrict_sortable_fields(cls, value): + return restrict_sortable_fields_base(value, ['version']) + class Constants(Filter.Constants): model = PPRVersion class DeploymentRequestFilter(Filter): - status: Optional[str] - delegate_l2: Optional[bool] - + """For specific deployment""" + id: Optional[str] + manually: Optional[bool] ppr: Optional[PPRVersionFilter] = FilterDepends(with_prefix('ppr', PPRVersionFilter)) + status: Optional[str] order_by: Optional[List[str]] + @validator("order_by") + @classmethod + def restrict_sortable_fields(cls, value): + return restrict_sortable_fields_base(value, ['id']) + class Constants(Filter.Constants): model = DeploymentRequest +class DeploymentRequestExtendedFilter(DeploymentRequestFilter): + """For all deployments""" + delegate_l2: Optional[bool] + + deployment_id: Optional[str] + deployment: Optional[DeploymentFilter] = FilterDepends( + with_prefix('deployment', DeploymentFilter), + ) + + @validator("order_by") + @classmethod + def restrict_sortable_fields(cls, value): + return restrict_sortable_fields_base(value, ['id', 'deployment_id']) + + class MarketplaceConfigurationFilter(Filter): + """For specific DeploymentRequest""" marketplace: Optional[str] - ppr: Optional[PPRVersionFilter] = FilterDepends(with_prefix('ppr', PPRVersionFilter)) order_by: Optional[List[str]] @@ -54,15 +131,9 @@ class Constants(Filter.Constants): model = MarketplaceConfiguration -class DeploymentRequestExtendedFilter(DeploymentRequestFilter): - id: Optional[str] - - deployment: Optional[DeploymentFilter] = FilterDepends( - with_prefix('deployment', DeploymentFilter), - ) - - class Constants(Filter.Constants): - model = DeploymentRequest +class MarketplaceConfigurationExtendedFilter(MarketplaceConfigurationFilter): + """For Deployment""" + ppr: Optional[PPRVersionFilter] = FilterDepends(with_prefix('ppr', PPRVersionFilter)) class PricingBatchFilter(BaseModel): @@ -70,7 +141,15 @@ class PricingBatchFilter(BaseModel): class TaskFilter(Filter): + id: Optional[str] status: Optional[str] + order_by: Optional[List[str]] + + @validator("order_by") + @classmethod + def restrict_sortable_fields(cls, value): + return restrict_sortable_fields_base(value, ['id']) + class Constants(Filter.Constants): model = Task diff --git a/connect_ext_ppr/webapp.py b/connect_ext_ppr/webapp.py index f4a6840..d8a0b4a 100644 --- a/connect_ext_ppr/webapp.py +++ b/connect_ext_ppr/webapp.py @@ -37,6 +37,7 @@ from connect_ext_ppr.filters import ( DeploymentFilter, DeploymentRequestExtendedFilter, DeploymentRequestFilter, MarketplaceConfigurationFilter, PPRVersionFilter, PricingBatchFilter, TaskFilter, + MarketplaceConfigurationExtendedFilter, ) from connect_ext_ppr.models.configuration import Configuration from connect_ext_ppr.models.deployment import ( @@ -202,7 +203,7 @@ def list_deployment_requests( DeploymentRequest.deployment, DeploymentRequest.ppr, ).filter( DeploymentRequest.deployment_id.in_(deployments), - ).order_by(desc(DeploymentRequest.created_at)) + ) deployment_requests = dr_filter.filter(deployment_requests) deployment_requests = dr_filter.sort(deployment_requests) deployment_requests = apply_pagination(deployment_requests, db, pagination_params, response) @@ -265,8 +266,9 @@ def list_deployment_request_tasks( dr = get_deployment_request_by_id(depl_req_id, db, installation) if dr: task_list = [] - qs = db.query(Task).filter_by(deployment_request_id=dr.id).order_by(Task.id) + qs = db.query(Task).filter_by(deployment_request_id=dr.id) qs = task_filter.filter(qs) + qs = task_filter.sort(qs) for task in apply_pagination(qs, db, pagination_params, response): task_list.append(get_task_schema(task)) @@ -404,8 +406,9 @@ def list_requests_for_deployment( db .query(DeploymentRequest) .filter_by(deployment_id=deployment_id) - .order_by(desc(DeploymentRequest.id)) - ) + ).join(DeploymentRequest.ppr) + qs = dr_filter.filter(qs) + qs = dr_filter.sort(qs) for dr in apply_pagination(qs, db, pagination_params, response): response_list.append(get_deployment_request_schema(dr, hub)) @@ -426,11 +429,24 @@ def get_deployments( installation: dict = Depends(get_installation), ): deployments = db.query(Deployment).filter_by(account_id=installation['owner']['id']) - deployments = deployment_filter.filter(deployments) - deployments = apply_pagination(deployments, db, pagination_params, response) + deployments = deployment_filter.filter(deployments.join(Deployment.product)) listings = get_all_listing_info(client) vendors = [li['vendor'] for li in listings] hubs = [hub['hub'] for li in listings for hub in li['contract']['marketplace']['hubs']] + + if deployment_filter.hub__name: # custom filter by hub name + dep_ids = [dep.id for dep in deployments if filter_object_list_by_id(hubs, dep.hub_id)['name'] == deployment_filter.hub__name] + deployments = deployments.filter(Deployment.id.in_(dep_ids)) + + if deployment_filter.hub__name__like: # custom filter by hub name + dep_ids = [dep.id for dep in deployments if deployment_filter.hub__name__like in filter_object_list_by_id(hubs, dep.hub_id)['name']] + deployments = deployments.filter(Deployment.id.in_(dep_ids)) + + deployments = deployment_filter.sort(deployments) + if deployment_filter.custom_order_by == 'product__name': # custom sort by product name + deployments = deployments.order_by(Product.name) + deployments = apply_pagination(deployments, db, pagination_params, response) + response_list = [] for dep in deployments: vendor = filter_object_list_by_id(vendors, dep.vendor_id) @@ -438,6 +454,7 @@ def get_deployments( response_list.append( get_deployment_schema(dep, dep.product, vendor, hub), ) + return response_list @router.get( @@ -610,7 +627,6 @@ def get_pprs( .filter_by(deployment=deployment_id) .join(File, PPRVersion.file == File.id) .outerjoin(Configuration, PPRVersion.configuration == Configuration.id) - .order_by(desc(PPRVersion.version)) ) ppr_file_conf_qs = ppr_filter.filter(ppr_file_conf_qs) ppr_file_conf_qs = ppr_filter.sort(ppr_file_conf_qs) @@ -674,7 +690,9 @@ def add_ppr( def get_marketplaces_by_deployment( self, deployment_id: str, - m_filter: MarketplaceConfigurationFilter = FilterDepends(MarketplaceConfigurationFilter), + m_filter: MarketplaceConfigurationExtendedFilter = FilterDepends( + MarketplaceConfigurationExtendedFilter, + ), pagination_params: PaginationParams = Depends(), response: Response = None, client: ConnectClient = Depends(get_installation_client),