Skip to content

Commit

Permalink
Vulnerabilities test cases work, refactor for improved clarity on fil…
Browse files Browse the repository at this point in the history
…ter_helpers.filter_vulnerabilities()
  • Loading branch information
JCantu248 committed Nov 15, 2024
1 parent d7cb6cf commit c09850a
Show file tree
Hide file tree
Showing 5 changed files with 406 additions and 131 deletions.
136 changes: 46 additions & 90 deletions backend/src/xfd_django/xfd_api/helpers/filter_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,104 +112,60 @@ def filter_vulnerabilities(
vulnerabilities: QuerySet, vulnerability_filters: VulnerabilityFilters
):
"""
Filter vulnerabilitie
Filter vulnerabilities based on given filters.
Arguments:
vulnerabilities: A list of all vulnerabilities, sorted
vulnerability_filters: Value to filter the vulnberabilities table by
vulnerabilities: A list of all vulnerabilities, sorted.
vulnerability_filters: Value to filter the vulnerabilities table by.
Returns:
object: a list of Vulnerability objects
QuerySet: A filtered list of Vulnerability objects.
"""
try:
if vulnerability_filters.id:
vulnerability_by_id = Vulnerability.objects.values("id").get(
id=vulnerability_filters.id
)
if not vulnerability_by_id:
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided id"
)
vulnerabilities = vulnerabilities.filter(id=vulnerability_by_id)
# Initialize a query that includes all vulnerabilities
query = vulnerabilities

if vulnerability_filters.title:
vulnerabilities_by_title = Vulnerability.objects.values("id").filter(
title=vulnerability_filters.title
)
if not vulnerabilities_by_title.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided title"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_title)
# Apply filters based on the provided criteria
if vulnerability_filters.id:
query = query.filter(id=vulnerability_filters.id)

if vulnerability_filters.domain:
vulnerabilities_by_domain = Vulnerability.objects.values("id").filter(
domain=vulnerability_filters.domain
)
if not vulnerabilities_by_domain.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided domain"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain)
if vulnerability_filters.title:
query = query.filter(title=vulnerability_filters.title)

if vulnerability_filters.severity:
vulnerabilities_by_severity = Vulnerability.objects.values("id").filter(
severity=vulnerability_filters.severity
)
if not vulnerabilities_by_severity.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided severity level"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_severity)
if vulnerability_filters.domain:
query = query.filter(domain=vulnerability_filters.domain)

if vulnerability_filters.cpe:
vulnerabilities_by_cpe = Vulnerability.objects.values("id").filter(
cpe=vulnerability_filters.cpe
)
if not vulnerabilities_by_cpe.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided Cpe"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_cpe)
if vulnerability_filters.severity:
query = query.filter(severity=vulnerability_filters.severity)

if vulnerability_filters.state:
vulnerabilities_by_state = Vulnerability.objects.values("id").filter(
state=vulnerability_filters.state
)
if not vulnerabilities_by_state.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided state"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_state)
if vulnerability_filters.cpe:
query = query.filter(cpe=vulnerability_filters.cpe)

if vulnerability_filters.organization:
domains = Domain.objects.all()
domains_by_organization = Domain.objects.values("id").filter(
organization_id=vulnerability_filters.organization
)
if not domains_by_organization.exists():
raise Vulnerability.DoesNotExist(
"No Organization-Domain found with the provided organization ID"
)
domains = domains.filter(id__in=domains_by_organization)
vulnerabilities_by_domain = Vulnerability.objects.values("id").filter(
id__in=domains
)
if not vulnerabilities_by_domain.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided organization ID"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_domain)
if vulnerability_filters.state:
query = query.filter(state=vulnerability_filters.state)

if vulnerability_filters.isKev:
vulnerabilities_by_is_kev = Vulnerability.objects.values("id").filter(
isKev=vulnerability_filters.isKev
if vulnerability_filters.organization:
# Fetch domains based on the organization ID
domains_by_organization = Domain.objects.filter(
organization_id=vulnerability_filters.organization
)

if not domains_by_organization.exists():
raise Vulnerability.DoesNotExist(
"No Organization-Domain found with the provided organization ID"
)
if not vulnerabilities_by_is_kev.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided isKev value"
)
vulnerabilities = vulnerabilities.filter(id__in=vulnerabilities_by_is_kev)
return vulnerabilities
except Domain.DoesNotExist as e:
raise e
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Filter vulnerabilities based on the found domains
query = query.filter(domain__in=domains_by_organization)

if (
vulnerability_filters.isKev is not None
): # Check for None to distinguish between True/False
query = query.filter(isKev=vulnerability_filters.isKev)

# If the queryset is empty, raise a not found exception (404)
if not query.exists():
raise Vulnerability.DoesNotExist(
"No Vulnerabilities found with the provided filters."
)

return query
40 changes: 18 additions & 22 deletions backend/src/xfd_django/xfd_api/tests/test_domain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
# Standard Python Libraries
from datetime import datetime
import logging
import secrets

# Configure logging
logging.basicConfig(level=logging.DEBUG) # Set the logging level to DEBUG
logger = logging.getLogger(__name__)


# Third-Party Libraries
from fastapi.testclient import TestClient
import pytest
Expand All @@ -33,6 +27,7 @@

@pytest.mark.django_db(transaction=True)
def test_get_domain_by_id():
# Get domain by Id.
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -49,11 +44,12 @@ def test_get_domain_by_id():
data = response.json()

assert response.status_code == 200
assert data["id"] == test_id


@pytest.mark.django_db(transaction=True)
def test_filter_domain_by_ip(capfd):
# Filter domains by ip
def test_search_domain_by_ip():
# Search domains by ip
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -75,8 +71,8 @@ def test_filter_domain_by_ip(capfd):


@pytest.mark.django_db(transaction=True)
def test_filter_domain_by_port():
# Test filter domains by port
def test_search_domain_by_port():
# Test search domains by port
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -98,8 +94,8 @@ def test_filter_domain_by_port():


@pytest.mark.django_db(transaction=True)
def test_filter_domain_by_service():
# Test filter domains by service_id
def test_search_domain_by_service():
# Test search domains by service_id
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -122,8 +118,8 @@ def test_filter_domain_by_service():


@pytest.mark.django_db(transaction=True)
def test_filter_domain_by_organization():
# Test filter domains by organization
def test_search_domain_by_organization():
# Test search domains by organization
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -148,8 +144,8 @@ def test_filter_domain_by_organization():


@pytest.mark.django_db(transaction=True)
def test_filter_domain_by_organization_name():
# Test filter domains by organization
def test_search_domain_by_organization_name():
# Test search domains by organization
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -176,8 +172,8 @@ def test_filter_domain_by_organization_name():


@pytest.mark.django_db(transaction=True)
def test_filter_domain_by_vulnerabilities():
# Test filter domains by vulnerabilities
def test_search_domain_by_vulnerabilities():
# Test search domains by vulnerabilities
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -204,8 +200,8 @@ def test_filter_domain_by_vulnerabilities():


@pytest.mark.django_db(transaction=True)
def test_filter_domains_multiple_criteria():
# Test filter domains by multiple criteria
def test_search_domains_multiple_criteria():
# Test search domains by multiple criteria
user = User.objects.create(
firstName="",
lastName="",
Expand All @@ -231,8 +227,8 @@ def test_filter_domains_multiple_criteria():


@pytest.mark.django_db(transaction=True)
def test_filter_domains_does_not_exist():
# Test filter domains if record does not exist
def test_search_domains_does_not_exist():
# Test search domains if record does not exist
user = User.objects.create(
firstName="",
lastName="",
Expand Down
Loading

0 comments on commit c09850a

Please sign in to comment.