Skip to content

Commit

Permalink
Merge pull request #3013 from jbaptperez/fix/tests
Browse files Browse the repository at this point in the history
Fix - Broken tests
  • Loading branch information
elegantmoose authored Oct 17, 2024
2 parents 0524d12 + a196363 commit 06042d6
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 28 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,16 @@ jobs:
uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c
with:
python-version: ${{ matrix.python-version }}
- name: Setup Node.js
uses: actions/setup-node@v3
with:
node-version: '20'
- name: Install dependencies
run: |
pip install --upgrade virtualenv
pip install tox
npm --prefix plugins/magma install
npm --prefix plugins/magma run build
- name: Run tests
env:
TOXENV: ${{ matrix.toxenv }}
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ We use the basic feature branch GIT flow. Fork this repository and create a feat
# Run the tests
Tests can be run by executing:
```
python -m pytest
python -m pytest --asyncio-mode=auto
```
This will run all unit tests in your current development environment. Depending on the level of the change, you might need to run the test suite on various versions of Python. The unit testing pipeline will run the entire suite across multiple Python versions that we support when you submit your PR.

Expand Down
2 changes: 1 addition & 1 deletion app/api/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def enable(self):
self.app_svc.application.router.add_route('*', '/api/rest', self.rest_core)
self.app_svc.application.router.add_route('GET', '/api/{index}', self.rest_core_info)
self.app_svc.application.router.add_route('GET', '/file/download_exfil', self.download_exfil_file)
self.app_svc.application.router.add_route('GET', '/{tail:(?!plugin/).*}', self.handle_catch)
self.app_svc.application.router.add_route('GET', '/{tail:(?!plugin/|api/v2/).*}', self.handle_catch)

async def validate_login(self, request):
return await self.auth_svc.login_user(request)
Expand Down
3 changes: 1 addition & 2 deletions app/api/v2/handlers/health_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from aiohttp import web

import app
from app.api.v2 import security
from app.api.v2.handlers.base_api import BaseApi
from app.api.v2.schemas.caldera_info_schemas import CalderaInfoSchema

Expand All @@ -16,7 +15,7 @@ def __init__(self, services):

def add_routes(self, app: web.Application):
router = app.router
router.add_get('/health', security.authentication_exempt(self.get_health_info))
router.add_get('/health', self.get_health_info)

@aiohttp_apispec.docs(tags=['health'],
summary='Health endpoints returns the status of Caldera',
Expand Down
2 changes: 1 addition & 1 deletion app/api/v2/handlers/payload_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def post_payloads(self, request: web.Request):
tags=['payloads'],
summary='Delete a payload',
description='Deletes a given payload.',
responses = {
responses={
204: {"description": "Payload has been properly deleted."},
404: {"description": "Payload not found."},
})
Expand Down
1 change: 1 addition & 0 deletions app/api/v2/managers/fact_source_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from app.api.v2.managers.base_api_manager import BaseApiManager


class FactSourceApiManager(BaseApiManager):
def __init__(self, data_svc, file_svc, knowledge_svc):
super().__init__(data_svc=data_svc, file_svc=file_svc)
Expand Down
8 changes: 4 additions & 4 deletions app/api/v2/schemas/payload_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@


class PayloadQuerySchema(schema.Schema):
sort = fields.Boolean(required=False, default=False)
exclude_plugins = fields.Boolean(required=False, default=False)
add_path = fields.Boolean(required=False, default=False)
sort = fields.Boolean(required=False, load_default=False)
exclude_plugins = fields.Boolean(required=False, load_default=False)
add_path = fields.Boolean(required=False, load_default=False)


class PayloadSchema(schema.Schema):
payloads = fields.List(fields.String())


class PayloadCreateRequestSchema(schema.Schema):
file = fields.Raw(type="file", required=True)
file = fields.Raw(required=True, metadata={'type': 'file'})


class PayloadDeleteRequestSchema(schema.Schema):
Expand Down
2 changes: 1 addition & 1 deletion app/service/data_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ async def _verify_adversary_profiles(self):
def _get_plugin_name(self, filename):
plugin_path = pathlib.PurePath(filename).parts
return plugin_path[1] if 'plugins' in plugin_path else ''

async def get_facts_from_source(self, fact_source_id):
fact_sources = await self.locate('sources', match=dict(id=fact_source_id))
if len(fact_sources) == 0:
Expand Down
7 changes: 3 additions & 4 deletions tests/api/v2/handlers/test_health_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@pytest.fixture
def expected_caldera_info():
return {
'access': 'RED',
'application': 'Caldera',
'plugins': [],
'version': app.get_version()
Expand All @@ -20,8 +21,6 @@ async def test_get_health(self, api_v2_client, api_cookies, expected_caldera_inf
output_info = await resp.json()
assert output_info == expected_caldera_info

async def test_unauthorized_get_health(self, api_v2_client, expected_caldera_info):
async def test_unauthorized_get_health(self, api_v2_client):
resp = await api_v2_client.get('/api/v2/health')
assert resp.status == HTTPStatus.OK
output_info = await resp.json()
assert output_info == expected_caldera_info
assert resp.status == HTTPStatus.UNAUTHORIZED
49 changes: 44 additions & 5 deletions tests/api/v2/handlers/test_payloads_api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,53 @@
import os
import tempfile
from http import HTTPStatus

import pytest


@pytest.fixture
def expected_payload_file_paths():
"""
Generates (and deletes) real dummy files because the payload API looks for payload files in
"data/payloads" and/or in "plugins/<plugin-name>/payloads".
:return: A set of relative paths of dummy payloads.
"""
directory = "data/payloads"
os.makedirs(directory, exist_ok=True)

file_paths = set()
current_working_dir = os.getcwd()

try:
for _ in range(3):
fd, file_path = tempfile.mkstemp(prefix="payload_", dir=directory)
os.close(fd)
relative_path = os.path.relpath(file_path, start=current_working_dir)
file_paths.add(relative_path)
yield file_paths
finally:
for file_path in file_paths:
os.remove(file_path)


@pytest.fixture
def expected_payload_file_names(expected_payload_file_paths):
return {os.path.basename(path) for path in expected_payload_file_paths}


class TestPayloadsApi:

async def test_get_payloads(self, api_v2_client, api_cookies):
async def test_get_payloads(self, api_v2_client, api_cookies, expected_payload_file_names):
resp = await api_v2_client.get('/api/v2/payloads', cookies=api_cookies)
payloads_list = await resp.json()
assert len(payloads_list) > 0
payload = payloads_list[0]
assert type(payload) is str
payload_file_names = await resp.json()
assert len(payload_file_names) >= len(expected_payload_file_names)

filtered_payload_file_names = { # Excluding any other real files in data/payloads...
file_name for file_name in payload_file_names
if file_name in expected_payload_file_names
}

assert filtered_payload_file_names == expected_payload_file_names

async def test_unauthorized_get_payloads(self, api_v2_client):
resp = await api_v2_client.get('/api/v2/payloads')
Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import os.path

import jinja2
import pytest
import random
import string
Expand All @@ -14,8 +15,8 @@
from unittest import mock
from aiohttp_apispec import validation_middleware
from aiohttp import web
import aiohttp_jinja2
from pathlib import Path

from app.api.v2.handlers.agent_api import AgentApi
from app.api.v2.handlers.ability_api import AbilityApi
from app.api.v2.handlers.objective_api import ObjectiveApi
Expand All @@ -29,6 +30,7 @@
from app.api.v2.handlers.planner_api import PlannerApi
from app.api.v2.handlers.health_api import HealthApi
from app.api.v2.handlers.schedule_api import ScheduleApi
from app.api.v2.handlers.payload_api import PayloadApi
from app.objects.c_obfuscator import Obfuscator
from app.objects.c_objective import Objective
from app.objects.c_planner import PlannerSchema
Expand Down Expand Up @@ -356,6 +358,7 @@ def make_app(svcs):
PlannerApi(svcs).add_routes(app)
HealthApi(svcs).add_routes(app)
ScheduleApi(svcs).add_routes(app)
PayloadApi(svcs).add_routes(app)
return app

async def initialize():
Expand Down Expand Up @@ -392,6 +395,10 @@ async def initialize():
)
app_svc.application.middlewares.append(apispec_request_validation_middleware)
app_svc.application.middlewares.append(validation_middleware)
templates = ['plugins/%s/templates' % p.lower() for p in app_svc.get_config('plugins')]
templates.append('plugins/magma/dist')
templates.append("templates")
aiohttp_jinja2.setup(app_svc.application, loader=jinja2.FileSystemLoader(templates))
return app_svc

app_svc = await initialize()
Expand Down
6 changes: 4 additions & 2 deletions tests/objects/test_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_link_knowledge_svc_synchronization(self, event_loop, executor, ability,
knowledge_base_r = event_loop.run_until_complete(knowledge_svc.get_relationships(dict(edge='has_admin')))
assert len(knowledge_base_r) == 1

def test_create_relationship_source_fact(self, event_loop, ability, executor, operation, knowledge_svc, fire_event_mock):
def test_create_relationship_source_fact(self, event_loop, ability, executor, operation, data_svc, knowledge_svc, fire_event_mock):
test_executor = executor(name='psh', platform='windows')
test_ability = ability(ability_id='123', executors=[test_executor])
fact1 = Fact(trait='remote.host.fqdn', value='dc')
Expand All @@ -149,6 +149,7 @@ def test_create_relationship_source_fact(self, event_loop, ability, executor, op
adversary=Adversary(name='sample', adversary_id='XYZ', atomic_ordering=[],
description='test'),
source=Source(id='test-source', facts=[fact1]))
event_loop.run_until_complete(data_svc.store(operation.source))
event_loop.run_until_complete(operation._init_source())
event_loop.run_until_complete(link1.create_relationships([relationship], operation))

Expand All @@ -161,7 +162,7 @@ def test_create_relationship_source_fact(self, event_loop, ability, executor, op
assert len(fact_store_operation) == 1
assert len(fact_store_operation_source[0].collected_by) == 2

def test_save_discover_seeded_fact_not_in_command(self, event_loop, ability, executor, operation, knowledge_svc, fire_event_mock):
def test_save_discover_seeded_fact_not_in_command(self, event_loop, ability, executor, operation, knowledge_svc, data_svc, fire_event_mock):
test_executor = executor(name='psh', platform='windows')
test_ability = ability(ability_id='123', executors=[test_executor])
fact1 = Fact(trait='remote.host.fqdn', value='dc')
Expand All @@ -172,6 +173,7 @@ def test_save_discover_seeded_fact_not_in_command(self, event_loop, ability, exe
adversary=Adversary(name='sample', adversary_id='XYZ', atomic_ordering=[],
description='test'),
source=Source(id='test-source', facts=[fact1, fact2]))
event_loop.run_until_complete(data_svc.store(operation.source))
event_loop.run_until_complete(operation._init_source())
event_loop.run_until_complete(link.save_fact(operation, fact2, 1, relationship))

Expand Down
1 change: 1 addition & 0 deletions tests/objects/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def test_without_learning_parser(self, event_loop, app_svc, contact_svc, data_sv

def test_facts(self, event_loop, app_svc, contact_svc, file_svc, data_svc, learning_svc, fire_event_mock,
op_with_learning_and_seeded, make_test_link, make_test_result, knowledge_svc):
event_loop.run_until_complete(data_svc.store(op_with_learning_and_seeded.source))
test_link = make_test_link(9876)
op_with_learning_and_seeded.add_link(test_link)

Expand Down
7 changes: 1 addition & 6 deletions tests/web_server/test_core_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ async def test_home(aiohttp_client):
assert resp.content_type == 'text/html'


async def test_access_denied(aiohttp_client):
resp = await aiohttp_client.get('/enter')
assert resp.status == HTTPStatus.UNAUTHORIZED


async def test_login(aiohttp_client):
resp = await aiohttp_client.post('/enter', allow_redirects=False, data=dict(username='admin', password='admin'))
assert resp.status == HTTPStatus.FOUND
Expand Down Expand Up @@ -152,7 +147,7 @@ async def handle_login_redirect(self, request, **kwargs):
assert resp.status == HTTPStatus.UNAUTHORIZED
assert await resp.text() == 'Automatic rejection'

resp = await aiohttp_client.get('/', allow_redirects=False)
resp = await aiohttp_client.get('/api/v2', allow_redirects=False)
assert resp.status == HTTPStatus.UNAUTHORIZED
assert await resp.text() == 'Automatic rejection'

Expand Down

0 comments on commit 06042d6

Please sign in to comment.