diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index db6de2f06..a255370b9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -185,7 +185,13 @@ jobs: PYTHONUNBUFFERED=true python -m unittest --verbose - name: Stop development server run: | + docker compose logs app > ${{ runner.temp }}/iriswebapp_app.log docker compose down + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: Test API iriswebapp_app logs + path: ${{ runner.temp }}/iriswebapp_app.log test-database-migration: name: Database migration tests diff --git a/source/app/blueprints/access_controls.py b/source/app/blueprints/access_controls.py index ded068a1a..d39c66583 100644 --- a/source/app/blueprints/access_controls.py +++ b/source/app/blueprints/access_controls.py @@ -48,7 +48,7 @@ from app.datamgmt.manage.manage_access_control_db import user_has_client_access from app.datamgmt.manage.manage_users_db import get_user from app.iris_engine.access_control.iris_user import iris_current_user -from app.iris_engine.access_control.utils import ac_fast_check_user_has_case_access +from app.business.access_controls import ac_fast_check_user_has_case_access from app.iris_engine.access_control.utils import ac_get_effective_permissions_of_user from app.iris_engine.utils.tracker import track_activity from app.models.authorization import Permissions diff --git a/source/app/blueprints/graphql/permissions.py b/source/app/blueprints/graphql/permissions.py index b86dcb2cc..0977e47ad 100644 --- a/source/app/blueprints/graphql/permissions.py +++ b/source/app/blueprints/graphql/permissions.py @@ -25,7 +25,7 @@ from app.blueprints.access_controls import get_case_access_from_api from app.iris_engine.access_control.iris_user import iris_current_user from app.iris_engine.access_control.utils import ac_get_effective_permissions_of_user -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access class PermissionDeniedError(Exception): diff --git a/source/app/blueprints/pages/manage/manage_cases_routes.py b/source/app/blueprints/pages/manage/manage_cases_routes.py index 7ca380dbf..af7f52a52 100644 --- a/source/app/blueprints/pages/manage/manage_cases_routes.py +++ b/source/app/blueprints/pages/manage/manage_cases_routes.py @@ -34,7 +34,7 @@ from app.datamgmt.manage.manage_cases_db import get_case_protagonists from app.datamgmt.manage.manage_common import get_severities_list from app.forms import AddCaseForm -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.iris_engine.access_control.utils import ac_current_user_has_permission from app.models.authorization import CaseAccessLevel from app.models.authorization import Permissions diff --git a/source/app/blueprints/rest/alerts_routes.py b/source/app/blueprints/rest/alerts_routes.py index a14759010..d92795d47 100644 --- a/source/app/blueprints/rest/alerts_routes.py +++ b/source/app/blueprints/rest/alerts_routes.py @@ -1019,6 +1019,7 @@ def alert_comment_edit(alert_id, com_id): @alerts_rest_blueprint.route('/alerts//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/alerts/{alert_identifier}/comments') @ac_api_requires(Permissions.alerts_write) def case_comment_add(alert_id): """ diff --git a/source/app/blueprints/rest/case/case_assets_routes.py b/source/app/blueprints/rest/case/case_assets_routes.py index f98bac59a..484867108 100644 --- a/source/app/blueprints/rest/case/case_assets_routes.py +++ b/source/app/blueprints/rest/case/case_assets_routes.py @@ -47,7 +47,7 @@ from app.datamgmt.manage.manage_attribute_db import get_default_custom_attributes from app.datamgmt.manage.manage_users_db import get_user_cases_fast from app.datamgmt.states import get_assets_state -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.utils.tracker import track_activity from app.models.models import AnalysisStatus @@ -349,6 +349,7 @@ def case_comment_asset_list(cur_id, caseid): @case_assets_rest_blueprint.route('/case/assets//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/assets/{asset_identifier}/comments') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_comment_asset_add(cur_id, caseid): diff --git a/source/app/blueprints/rest/case/case_evidences_routes.py b/source/app/blueprints/rest/case/case_evidences_routes.py index f3475cf16..e67637757 100644 --- a/source/app/blueprints/rest/case/case_evidences_routes.py +++ b/source/app/blueprints/rest/case/case_evidences_routes.py @@ -152,6 +152,7 @@ def case_comment_evidence_list(cur_id, caseid): @case_evidences_rest_blueprint.route('/case/evidences//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/evidences/{evidence_identifier}/comments') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_comment_evidence_add(cur_id, caseid): diff --git a/source/app/blueprints/rest/case/case_ioc_routes.py b/source/app/blueprints/rest/case/case_ioc_routes.py index 789fd237e..fb4fd7d3d 100644 --- a/source/app/blueprints/rest/case/case_ioc_routes.py +++ b/source/app/blueprints/rest/case/case_ioc_routes.py @@ -45,7 +45,7 @@ from app.datamgmt.case.case_iocs_db import get_tlps_dict from app.datamgmt.manage.manage_attribute_db import get_default_custom_attributes from app.datamgmt.states import get_ioc_state -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.utils.tracker import track_activity from app.models.authorization import CaseAccessLevel @@ -259,6 +259,7 @@ def case_comment_ioc_list(cur_id, caseid): @case_ioc_rest_blueprint.route('/case/ioc//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/iocs/{ioc_identifier}/comments') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_comment_ioc_add(cur_id, caseid): diff --git a/source/app/blueprints/rest/case/case_notes_routes.py b/source/app/blueprints/rest/case/case_notes_routes.py index 2c13dd83f..43f879c5f 100644 --- a/source/app/blueprints/rest/case/case_notes_routes.py +++ b/source/app/blueprints/rest/case/case_notes_routes.py @@ -404,6 +404,7 @@ def case_comment_note_list(cur_id, caseid): @case_notes_rest_blueprint.route('/case/notes//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/notes/{note_identifier}/comments') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_comment_note_add(cur_id, caseid): diff --git a/source/app/blueprints/rest/case/case_routes.py b/source/app/blueprints/rest/case/case_routes.py index bd180364e..f91b944fe 100644 --- a/source/app/blueprints/rest/case/case_routes.py +++ b/source/app/blueprints/rest/case/case_routes.py @@ -38,9 +38,8 @@ from app.datamgmt.manage.manage_groups_db import get_group_with_members from app.datamgmt.manage.manage_users_db import get_user from app.datamgmt.manage.manage_users_db import get_users_list_restricted_from_case -from app.datamgmt.manage.manage_users_db import set_user_case_access +from app.business.access_controls import set_user_case_access, ac_fast_check_user_has_case_access from app.business.cases import cases_export_to_json -from app.iris_engine.access_control.utils import ac_fast_check_user_has_case_access from app.iris_engine.access_control.utils import ac_set_case_access_for_users from app.iris_engine.utils.tracker import track_activity from app.models.models import CaseStatus @@ -243,21 +242,29 @@ def user_cac_set_case(caseid): try: - success, logs = set_user_case_access(user.id, data.get('case_id'), data.get('access_level')) + case_identifier = data.get('case_id') + access_level = data.get('access_level') + + if user.id is None or type(user.id) is not int: + return response_error('Invalid user id') + if case_identifier is None or type(case_identifier) is not int: + return response_error('Invalid case id') + if access_level is None or type(access_level) is not int: + return response_error('Invalid access level') + if CaseAccessLevel.has_value(access_level) is False: + return response_error('Invalid access level') + + set_user_case_access(user.id, case_identifier, access_level) track_activity('case access set to {} for user {}'.format(data.get('access_level'), user.name), caseid) add_obj_history_entry(case, 'access changed to {} for user {}'.format(data.get('access_level'), user.name)) db.session.commit() + return response_success(msg=f'Case access set to {access_level} for user {user.id}') except Exception as e: log.error(f'Error while setting case access for user: {e}') log.error(traceback.format_exc()) - return response_error(msg=str(e)) - - if success: - return response_success(msg=logs) - - return response_error(msg=logs) + return response_error(str(e)) @case_rest_blueprint.route('/case/update-status', methods=['POST']) diff --git a/source/app/blueprints/rest/case/case_tasks_routes.py b/source/app/blueprints/rest/case/case_tasks_routes.py index 6856e6815..e0b586fa4 100644 --- a/source/app/blueprints/rest/case/case_tasks_routes.py +++ b/source/app/blueprints/rest/case/case_tasks_routes.py @@ -177,6 +177,7 @@ def case_comment_task_list(cur_id: int, caseid: int): @case_tasks_rest_blueprint.route('/case/tasks//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/tasks/{task_identifier}/comments') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_comment_task_add(cur_id: int, caseid: int): diff --git a/source/app/blueprints/rest/case/case_timeline_routes.py b/source/app/blueprints/rest/case/case_timeline_routes.py index d8500bc3a..c5cf17b40 100644 --- a/source/app/blueprints/rest/case/case_timeline_routes.py +++ b/source/app/blueprints/rest/case/case_timeline_routes.py @@ -125,6 +125,7 @@ def case_comment_edit(cur_id, com_id, caseid): @case_timeline_rest_blueprint.route('/case/timeline/events//comments/add', methods=['POST']) +@endpoint_deprecated('POST', '/api/v2/events/{event_identifier}/comments') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_comment_add(cur_id, caseid): diff --git a/source/app/blueprints/rest/case_comments.py b/source/app/blueprints/rest/case_comments.py index ef77df391..816e38144 100644 --- a/source/app/blueprints/rest/case_comments.py +++ b/source/app/blueprints/rest/case_comments.py @@ -25,6 +25,7 @@ from app.blueprints.responses import response_success from app.business.comments import comments_update_for_case from app.business.errors import BusinessProcessingError +from app.iris_engine.access_control.iris_user import iris_current_user def case_comment_update(comment_id, object_type, caseid): @@ -32,7 +33,7 @@ def case_comment_update(comment_id, object_type, caseid): comment_schema = CommentSchema() rq_t = request.get_json() comment_text = rq_t.get('comment_text') - comment = comments_update_for_case(comment_text, comment_id, object_type, caseid) + comment = comments_update_for_case(iris_current_user, comment_text, comment_id, object_type, caseid) return response_success("Comment edited", data=comment_schema.dump(comment)) except BusinessProcessingError as e: return response_error(e.get_message(), data=e.get_data()) diff --git a/source/app/blueprints/rest/manage/manage_assets_routes.py b/source/app/blueprints/rest/manage/manage_assets_routes.py index 595af98de..72c418423 100644 --- a/source/app/blueprints/rest/manage/manage_assets_routes.py +++ b/source/app/blueprints/rest/manage/manage_assets_routes.py @@ -21,7 +21,7 @@ from werkzeug import Response from app.datamgmt.manage.manage_assets_db import get_filtered_assets -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel from app.schema.marshables import CaseAssetsSchema from app.blueprints.access_controls import ac_api_requires diff --git a/source/app/blueprints/rest/manage/manage_cases_routes.py b/source/app/blueprints/rest/manage/manage_cases_routes.py index ac13a5197..9c162f600 100644 --- a/source/app/blueprints/rest/manage/manage_cases_routes.py +++ b/source/app/blueprints/rest/manage/manage_cases_routes.py @@ -38,7 +38,7 @@ from app.datamgmt.manage.manage_cases_db import get_case_details_rt from app.datamgmt.manage.manage_cases_db import list_cases_dict from app.datamgmt.manage.manage_cases_db import reopen_case -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.module_handler.module_handler import configure_module_on_init from app.iris_engine.module_handler.module_handler import instantiate_module_from_name diff --git a/source/app/blueprints/rest/manage/manage_groups.py b/source/app/blueprints/rest/manage/manage_groups.py index 73c0eab45..1b2ba6945 100644 --- a/source/app/blueprints/rest/manage/manage_groups.py +++ b/source/app/blueprints/rest/manage/manage_groups.py @@ -37,10 +37,9 @@ from app.datamgmt.manage.manage_groups_db import update_group_members from app.datamgmt.manage.manage_users_db import get_user from app.iris_engine.access_control.utils import ac_ldp_group_removal -from app.iris_engine.access_control.utils import ac_flag_match_mask from app.iris_engine.access_control.utils import ac_ldp_group_update from app.iris_engine.access_control.utils import ac_recompute_effective_ac_from_users_list -from app.models.authorization import Permissions +from app.models.authorization import Permissions, ac_flag_match_mask from app.schema.marshables import AuthorizationGroupSchema from app.blueprints.access_controls import ac_api_requires from app.blueprints.access_controls import ac_api_return_access_denied diff --git a/source/app/blueprints/rest/search_routes.py b/source/app/blueprints/rest/search_routes.py index 18887a0ca..b30125c83 100644 --- a/source/app/blueprints/rest/search_routes.py +++ b/source/app/blueprints/rest/search_routes.py @@ -21,7 +21,7 @@ from sqlalchemy import and_ from app.iris_engine.utils.tracker import track_activity -from app.models.models import Comments +from app.models.comments import Comments from app.models.authorization import Permissions from app.models.cases import Cases from app.models.models import Client diff --git a/source/app/blueprints/rest/v2/alerts_routes/comments.py b/source/app/blueprints/rest/v2/alerts_routes/comments.py index caf36ef02..753edc1b0 100644 --- a/source/app/blueprints/rest/v2/alerts_routes/comments.py +++ b/source/app/blueprints/rest/v2/alerts_routes/comments.py @@ -18,14 +18,18 @@ from flask import Blueprint from flask import request +from marshmallow.exceptions import ValidationError from app.blueprints.access_controls import ac_api_requires from app.models.authorization import Permissions from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.parsing import parse_pagination_parameters from app.schema.marshables import CommentSchema from app.business.comments import comments_get_filtered_by_alert +from app.business.comments import comments_create_for_alert from app.iris_engine.access_control.iris_user import iris_current_user from app.business.errors import ObjectNotFoundError @@ -35,7 +39,7 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() - def get(self, alert_identifier): + def search(self, alert_identifier): pagination_parameters = parse_pagination_parameters(request) try: comments = comments_get_filtered_by_alert(iris_current_user, alert_identifier, pagination_parameters) @@ -43,6 +47,17 @@ def get(self, alert_identifier): except ObjectNotFoundError: return response_api_not_found() + def create(self, alert_identifier): + try: + comment = self._schema.load(request.get_json()) + comments_create_for_alert(iris_current_user, comment, alert_identifier) + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + alerts_comments_blueprint = Blueprint('alerts_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -51,4 +66,10 @@ def get(self, alert_identifier): @alerts_comments_blueprint.get('') @ac_api_requires(Permissions.alerts_read) def get_alerts_comments(alert_identifier): - return comments_operations.get(alert_identifier) + return comments_operations.search(alert_identifier) + + +@alerts_comments_blueprint.post('') +@ac_api_requires(Permissions.alerts_write) +def create_alerts_comment(alert_identifier): + return comments_operations.create(alert_identifier) diff --git a/source/app/blueprints/rest/v2/assets.py b/source/app/blueprints/rest/v2/assets.py index f7069c910..d695e6b3c 100644 --- a/source/app/blueprints/rest/v2/assets.py +++ b/source/app/blueprints/rest/v2/assets.py @@ -28,7 +28,7 @@ from app.business.assets import assets_get from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel from app.schema.marshables import CaseAssetsSchema from app.blueprints.access_controls import ac_api_return_access_denied diff --git a/source/app/blueprints/rest/v2/assets_routes/comments.py b/source/app/blueprints/rest/v2/assets_routes/comments.py index cd80375b8..42b721c53 100644 --- a/source/app/blueprints/rest/v2/assets_routes/comments.py +++ b/source/app/blueprints/rest/v2/assets_routes/comments.py @@ -18,17 +18,21 @@ from flask import Blueprint from flask import request +from marshmallow.exceptions import ValidationError from app.blueprints.access_controls import ac_api_requires from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.parsing import parse_pagination_parameters -from app.blueprints.access_controls import ac_api_return_access_denied from app.business.comments import comments_get_filtered_by_asset +from app.business.comments import comments_create_for_asset from app.business.assets import assets_get from app.business.errors import ObjectNotFoundError from app.schema.marshables import CommentSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access +from app.iris_engine.access_control.iris_user import iris_current_user from app.models.authorization import CaseAccessLevel @@ -37,20 +41,37 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() + @staticmethod + def _get_asset(asset_identifier, possible_case_access_levels): + asset = assets_get(asset_identifier) + if not ac_fast_check_current_user_has_case_access(asset.case_id, possible_case_access_levels): + raise ObjectNotFoundError() + return asset + def get(self, asset_identifier): try: - asset = assets_get(asset_identifier) - if not ac_fast_check_current_user_has_case_access(asset.case_id, - [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=asset.case_id) + asset = self._get_asset(asset_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]) pagination_parameters = parse_pagination_parameters(request) - comments = comments_get_filtered_by_asset(asset_identifier, pagination_parameters) + comments = comments_get_filtered_by_asset(asset, pagination_parameters) return response_api_paginated(self._schema, comments) except ObjectNotFoundError: return response_api_not_found() + def create(self, asset_identifier): + try: + asset = self._get_asset(asset_identifier, [CaseAccessLevel.full_access]) + comment = self._schema.load(request.get_json()) + comments_create_for_asset(iris_current_user, asset, comment) + + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + assets_comments_blueprint = Blueprint('assets_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -60,3 +81,9 @@ def get(self, asset_identifier): @ac_api_requires() def get_assets_comments(asset_identifier): return comments_operations.get(asset_identifier) + + +@assets_comments_blueprint.post('') +@ac_api_requires() +def create_assets_comment(asset_identifier): + return comments_operations.create(asset_identifier) diff --git a/source/app/blueprints/rest/v2/case_objects/assets.py b/source/app/blueprints/rest/v2/case_objects/assets.py index aa431a7fa..f55b1fffc 100644 --- a/source/app/blueprints/rest/v2/case_objects/assets.py +++ b/source/app/blueprints/rest/v2/case_objects/assets.py @@ -38,7 +38,7 @@ from app.business.assets import assets_delete from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.iris_engine.module_handler.module_handler import call_deprecated_on_preload_modules_hook from app.models.authorization import CaseAccessLevel from app.schema.marshables import CaseAssetsSchema diff --git a/source/app/blueprints/rest/v2/case_objects/events.py b/source/app/blueprints/rest/v2/case_objects/events.py index ee7ac7445..9167724e5 100644 --- a/source/app/blueprints/rest/v2/case_objects/events.py +++ b/source/app/blueprints/rest/v2/case_objects/events.py @@ -36,7 +36,7 @@ from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError from app.business.cases import cases_exists -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.iris_engine.utils.collab import notify from app.models.authorization import CaseAccessLevel from app.iris_engine.module_handler.module_handler import call_deprecated_on_preload_modules_hook diff --git a/source/app/blueprints/rest/v2/case_objects/evidences.py b/source/app/blueprints/rest/v2/case_objects/evidences.py index f66aabed1..c1e2cea6f 100644 --- a/source/app/blueprints/rest/v2/case_objects/evidences.py +++ b/source/app/blueprints/rest/v2/case_objects/evidences.py @@ -20,7 +20,7 @@ from flask import request from app.blueprints.access_controls import ac_api_requires -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError diff --git a/source/app/blueprints/rest/v2/case_objects/iocs.py b/source/app/blueprints/rest/v2/case_objects/iocs.py index 2ca210465..c18518bab 100644 --- a/source/app/blueprints/rest/v2/case_objects/iocs.py +++ b/source/app/blueprints/rest/v2/case_objects/iocs.py @@ -35,7 +35,7 @@ from app.business.iocs import iocs_delete from app.business.iocs import iocs_update from app.datamgmt.case.case_iocs_db import get_filtered_iocs -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel from app.schema.marshables import IocSchemaForAPIV2 from app.blueprints.access_controls import ac_api_return_access_denied diff --git a/source/app/blueprints/rest/v2/case_objects/notes.py b/source/app/blueprints/rest/v2/case_objects/notes.py index c27515940..afdfa207e 100644 --- a/source/app/blueprints/rest/v2/case_objects/notes.py +++ b/source/app/blueprints/rest/v2/case_objects/notes.py @@ -22,7 +22,7 @@ from app.blueprints.access_controls import ac_api_requires from app.blueprints.access_controls import ac_api_return_access_denied -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.blueprints.rest.endpoints import response_api_created from app.blueprints.rest.endpoints import response_api_success from app.blueprints.rest.endpoints import response_api_deleted diff --git a/source/app/blueprints/rest/v2/case_objects/notes_directories.py b/source/app/blueprints/rest/v2/case_objects/notes_directories.py index 2bbfc801d..cc326fcc8 100644 --- a/source/app/blueprints/rest/v2/case_objects/notes_directories.py +++ b/source/app/blueprints/rest/v2/case_objects/notes_directories.py @@ -35,7 +35,7 @@ from app.business.notes_directories import notes_directories_update from app.business.notes_directories import notes_directories_delete from app.business.cases import cases_exists -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel diff --git a/source/app/blueprints/rest/v2/case_objects/tasks.py b/source/app/blueprints/rest/v2/case_objects/tasks.py index 500fd8242..5b98a107f 100644 --- a/source/app/blueprints/rest/v2/case_objects/tasks.py +++ b/source/app/blueprints/rest/v2/case_objects/tasks.py @@ -37,7 +37,7 @@ from app.business.tasks import tasks_delete from app.business.tasks import tasks_filter from app.models.authorization import CaseAccessLevel -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access case_tasks_blueprint = Blueprint('case_tasks', __name__, diff --git a/source/app/blueprints/rest/v2/cases.py b/source/app/blueprints/rest/v2/cases.py index c896a253b..a04611a9d 100644 --- a/source/app/blueprints/rest/v2/cases.py +++ b/source/app/blueprints/rest/v2/cases.py @@ -45,7 +45,7 @@ from app.datamgmt.manage.manage_cases_db import get_filtered_cases from app.schema.marshables import CaseSchemaForAPIV2 from app.blueprints.access_controls import ac_api_requires -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.blueprints.access_controls import ac_api_return_access_denied from app.models.authorization import Permissions from app.models.authorization import CaseAccessLevel diff --git a/source/app/blueprints/rest/v2/events_routes/comments.py b/source/app/blueprints/rest/v2/events_routes/comments.py index 54a273ff4..52ad385e1 100644 --- a/source/app/blueprints/rest/v2/events_routes/comments.py +++ b/source/app/blueprints/rest/v2/events_routes/comments.py @@ -18,17 +18,22 @@ from flask import Blueprint from flask import request +from marshmallow import ValidationError from app.blueprints.access_controls import ac_api_requires from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.parsing import parse_pagination_parameters -from app.blueprints.access_controls import ac_api_return_access_denied +from app.iris_engine.access_control.iris_user import iris_current_user from app.business.comments import comments_get_filtered_by_event +from app.business.comments import comments_create_for_event from app.business.events import events_get from app.business.errors import ObjectNotFoundError +from app.models.cases import CasesEvent from app.schema.marshables import CommentSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel @@ -37,20 +42,37 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() + @staticmethod + def _get_event(event_identifier, possible_case_access_levels) -> CasesEvent: + event = events_get(event_identifier) + if not ac_fast_check_current_user_has_case_access(event.case_id, possible_case_access_levels): + raise ObjectNotFoundError() + return event + def get(self, event_identifier): try: - event = events_get(event_identifier) - if not ac_fast_check_current_user_has_case_access(event.case_id, - [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=event.case_id) + event = self._get_event(event_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]) pagination_parameters = parse_pagination_parameters(request) - comments = comments_get_filtered_by_event(event_identifier, pagination_parameters) + comments = comments_get_filtered_by_event(event, pagination_parameters) return response_api_paginated(self._schema, comments) except ObjectNotFoundError: return response_api_not_found() + def create(self, event_identifier): + try: + event = self._get_event(event_identifier, [CaseAccessLevel.full_access]) + comment = self._schema.load(request.get_json()) + comments_create_for_event(iris_current_user, event, comment) + + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + events_comments_blueprint = Blueprint('events_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -60,3 +82,9 @@ def get(self, event_identifier): @ac_api_requires() def get_event_comments(event_identifier): return comments_operations.get(event_identifier) + + +@events_comments_blueprint.post('') +@ac_api_requires() +def create_event_comment(event_identifier): + return comments_operations.create(event_identifier) diff --git a/source/app/blueprints/rest/v2/evidences_routes/comments.py b/source/app/blueprints/rest/v2/evidences_routes/comments.py index ca09bb033..4c56817ec 100644 --- a/source/app/blueprints/rest/v2/evidences_routes/comments.py +++ b/source/app/blueprints/rest/v2/evidences_routes/comments.py @@ -18,17 +18,22 @@ from flask import Blueprint from flask import request +from marshmallow import ValidationError from app.blueprints.access_controls import ac_api_requires from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error +from app.iris_engine.access_control.iris_user import iris_current_user from app.blueprints.rest.parsing import parse_pagination_parameters -from app.blueprints.access_controls import ac_api_return_access_denied from app.business.comments import comments_get_filtered_by_evidence +from app.business.comments import comments_create_for_evidence +from app.models.models import CaseReceivedFile from app.business.evidences import evidences_get from app.business.errors import ObjectNotFoundError from app.schema.marshables import CommentSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel @@ -37,20 +42,38 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() + @staticmethod + def _get_evidence(evidence_identifier, possible_case_access_levels) -> CaseReceivedFile: + evidence = evidences_get(evidence_identifier) + if not ac_fast_check_current_user_has_case_access(evidence.case_id, possible_case_access_levels): + raise ObjectNotFoundError() + return evidence + def get(self, evidence_identifier): try: - evidence = evidences_get(evidence_identifier) - if not ac_fast_check_current_user_has_case_access(evidence.case_id, - [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=evidence.case_id) + evidence = self._get_evidence(evidence_identifier, + [CaseAccessLevel.read_only, CaseAccessLevel.full_access]) pagination_parameters = parse_pagination_parameters(request) - comments = comments_get_filtered_by_evidence(evidence_identifier, pagination_parameters) + comments = comments_get_filtered_by_evidence(evidence, pagination_parameters) return response_api_paginated(self._schema, comments) except ObjectNotFoundError: return response_api_not_found() + def create(self, evidence_identifier): + try: + evidence = self._get_evidence(evidence_identifier, [CaseAccessLevel.full_access]) + comment = self._schema.load(request.get_json()) + comments_create_for_evidence(iris_current_user, evidence, comment) + + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + evidences_comments_blueprint = Blueprint('evidences_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -60,3 +83,9 @@ def get(self, evidence_identifier): @ac_api_requires() def get_evidences_comments(evidence_identifier): return comments_operations.get(evidence_identifier) + + +@evidences_comments_blueprint.post('') +@ac_api_requires() +def create_evidences_comment(evidence_identifier): + return comments_operations.create(evidence_identifier) diff --git a/source/app/blueprints/rest/v2/iocs.py b/source/app/blueprints/rest/v2/iocs.py index 52ecec084..f6998b6d8 100644 --- a/source/app/blueprints/rest/v2/iocs.py +++ b/source/app/blueprints/rest/v2/iocs.py @@ -29,7 +29,7 @@ from app.business.iocs import iocs_update from app.business.iocs import iocs_delete from app.business.iocs import iocs_get -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel from app.schema.marshables import IocSchemaForAPIV2 from app.blueprints.access_controls import ac_api_return_access_denied diff --git a/source/app/blueprints/rest/v2/iocs_routes/comments.py b/source/app/blueprints/rest/v2/iocs_routes/comments.py index 4f91bcecf..f12e3e3a5 100644 --- a/source/app/blueprints/rest/v2/iocs_routes/comments.py +++ b/source/app/blueprints/rest/v2/iocs_routes/comments.py @@ -18,17 +18,21 @@ from flask import Blueprint from flask import request +from marshmallow import ValidationError +from app.iris_engine.access_control.iris_user import iris_current_user from app.blueprints.access_controls import ac_api_requires from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.parsing import parse_pagination_parameters -from app.blueprints.access_controls import ac_api_return_access_denied from app.business.comments import comments_get_filtered_by_ioc +from app.business.comments import comments_create_for_ioc from app.business.iocs import iocs_get from app.business.errors import ObjectNotFoundError from app.schema.marshables import CommentSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel @@ -37,20 +41,37 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() + @staticmethod + def _get_ioc(ioc_identifier, possible_case_access_levels): + ioc = iocs_get(ioc_identifier) + if not ac_fast_check_current_user_has_case_access(ioc.case_id, possible_case_access_levels): + raise ObjectNotFoundError() + return ioc + def get(self, ioc_identifier): try: - ioc = iocs_get(ioc_identifier) - if not ac_fast_check_current_user_has_case_access(ioc.case_id, - [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=ioc.case_id) + ioc = self._get_ioc(ioc_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]) pagination_parameters = parse_pagination_parameters(request) - - comments = comments_get_filtered_by_ioc(ioc_identifier, pagination_parameters) + comments = comments_get_filtered_by_ioc(ioc, pagination_parameters) return response_api_paginated(self._schema, comments) except ObjectNotFoundError: return response_api_not_found() + def create(self, ioc_identifier): + try: + ioc = self._get_ioc(ioc_identifier, [CaseAccessLevel.full_access]) + + comment = self._schema.load(request.get_json()) + comments_create_for_ioc(iris_current_user, ioc, comment) + + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + iocs_comments_blueprint = Blueprint('iocs_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -60,3 +81,9 @@ def get(self, ioc_identifier): @ac_api_requires() def get_iocs_comments(ioc_identifier): return comments_operations.get(ioc_identifier) + + +@iocs_comments_blueprint.post('') +@ac_api_requires() +def create_iocs_comment(ioc_identifier): + return comments_operations.create(ioc_identifier) diff --git a/source/app/blueprints/rest/v2/manage_routes/groups.py b/source/app/blueprints/rest/v2/manage_routes/groups.py index 1ce0dfca9..6968678e8 100644 --- a/source/app/blueprints/rest/v2/manage_routes/groups.py +++ b/source/app/blueprints/rest/v2/manage_routes/groups.py @@ -31,11 +31,10 @@ from app.business.groups import groups_get from app.business.groups import groups_update from app.business.groups import groups_delete -from app.models.authorization import Permissions +from app.models.authorization import Permissions, ac_flag_match_mask from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError from app.iris_engine.access_control.iris_user import iris_current_user -from app.iris_engine.access_control.utils import ac_flag_match_mask from app.iris_engine.access_control.utils import ac_ldp_group_update diff --git a/source/app/blueprints/rest/v2/notes_routes/comments.py b/source/app/blueprints/rest/v2/notes_routes/comments.py index cccbec537..aefba9824 100644 --- a/source/app/blueprints/rest/v2/notes_routes/comments.py +++ b/source/app/blueprints/rest/v2/notes_routes/comments.py @@ -18,17 +18,21 @@ from flask import Blueprint from flask import request +from marshmallow import ValidationError +from app.iris_engine.access_control.iris_user import iris_current_user from app.blueprints.access_controls import ac_api_requires from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.parsing import parse_pagination_parameters -from app.blueprints.access_controls import ac_api_return_access_denied from app.business.comments import comments_get_filtered_by_note +from app.business.comments import comments_create_for_note from app.business.notes import notes_get from app.business.errors import ObjectNotFoundError from app.schema.marshables import CommentSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel @@ -37,20 +41,38 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() + @staticmethod + def _get_note(note_identifier, possible_case_access_levels): + note = notes_get(note_identifier) + if not ac_fast_check_current_user_has_case_access(note.note_case_id, possible_case_access_levels): + raise ObjectNotFoundError() + return note + def get(self, note_identifier): try: - note = notes_get(note_identifier) - if not ac_fast_check_current_user_has_case_access(note.note_case_id, - [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=note.note_case_id) + note = self._get_note(note_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]) pagination_parameters = parse_pagination_parameters(request) - comments = comments_get_filtered_by_note(note_identifier, pagination_parameters) + comments = comments_get_filtered_by_note(note, pagination_parameters) return response_api_paginated(self._schema, comments) except ObjectNotFoundError: return response_api_not_found() + def create(self, note_identifier): + try: + note = self._get_note(note_identifier, [CaseAccessLevel.full_access]) + + comment = self._schema.load(request.get_json()) + comments_create_for_note(iris_current_user, note, comment) + + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + notes_comments_blueprint = Blueprint('notes_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -60,3 +82,9 @@ def get(self, note_identifier): @ac_api_requires() def get_notes_comments(note_identifier): return comments_operations.get(note_identifier) + + +@notes_comments_blueprint.post('') +@ac_api_requires() +def create_notes_comment(note_identifier): + return comments_operations.create(note_identifier) diff --git a/source/app/blueprints/rest/v2/tasks.py b/source/app/blueprints/rest/v2/tasks.py index 808b0188b..cac1e8754 100644 --- a/source/app/blueprints/rest/v2/tasks.py +++ b/source/app/blueprints/rest/v2/tasks.py @@ -30,7 +30,7 @@ from app.business.errors import BusinessProcessingError from app.models.authorization import CaseAccessLevel from app.schema.marshables import CaseTaskSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.blueprints.rest.v2.tasks_routes.comments import tasks_comments_blueprint diff --git a/source/app/blueprints/rest/v2/tasks_routes/comments.py b/source/app/blueprints/rest/v2/tasks_routes/comments.py index 74deba84a..dc1c89b09 100644 --- a/source/app/blueprints/rest/v2/tasks_routes/comments.py +++ b/source/app/blueprints/rest/v2/tasks_routes/comments.py @@ -18,17 +18,21 @@ from flask import Blueprint from flask import request +from marshmallow import ValidationError +from app.iris_engine.access_control.iris_user import iris_current_user from app.blueprints.access_controls import ac_api_requires from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_not_found from app.blueprints.rest.parsing import parse_pagination_parameters -from app.blueprints.access_controls import ac_api_return_access_denied +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_error from app.business.comments import comments_get_filtered_by_task +from app.business.comments import comments_create_for_task from app.business.tasks import tasks_get from app.business.errors import ObjectNotFoundError from app.schema.marshables import CommentSchema -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.access_controls import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel @@ -37,20 +41,38 @@ class CommentsOperations: def __init__(self): self._schema = CommentSchema() + @staticmethod + def _get_task(task_identifier, possible_case_access_levels): + task = tasks_get(task_identifier) + if not ac_fast_check_current_user_has_case_access(task.task_case_id, possible_case_access_levels): + raise ObjectNotFoundError() + return task + def get(self, task_identifier): try: - task = tasks_get(task_identifier) - if not ac_fast_check_current_user_has_case_access(task.task_case_id, - [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=task.task_case_id) + task = self._get_task(task_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]) pagination_parameters = parse_pagination_parameters(request) - comments = comments_get_filtered_by_task(task_identifier, pagination_parameters) + comments = comments_get_filtered_by_task(task, pagination_parameters) return response_api_paginated(self._schema, comments) except ObjectNotFoundError: return response_api_not_found() + def create(self, task_identifier): + try: + task = self._get_task(task_identifier, [CaseAccessLevel.full_access]) + + comment = self._schema.load(request.get_json()) + comments_create_for_task(iris_current_user, task, comment) + + result = self._schema.dump(comment) + return response_api_created(result) + except ValidationError as e: + return response_api_error('Data error', data=e.normalized_messages()) + except ObjectNotFoundError: + return response_api_not_found() + tasks_comments_blueprint = Blueprint('tasks_comments', __name__, url_prefix='//comments') comments_operations = CommentsOperations() @@ -60,3 +82,9 @@ def get(self, task_identifier): @ac_api_requires() def get_tasks_comments(task_identifier): return comments_operations.get(task_identifier) + + +@tasks_comments_blueprint.post('') +@ac_api_requires() +def create_tasks_comment(task_identifier): + return comments_operations.create(task_identifier) diff --git a/source/app/business/access_controls.py b/source/app/business/access_controls.py new file mode 100644 index 000000000..6cbf8865c --- /dev/null +++ b/source/app/business/access_controls.py @@ -0,0 +1,96 @@ +# IRIS Source Code +# Copyright (C) 2025 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from app import db +from app.datamgmt.manage.manage_access_control_db import get_case_effective_access, \ + remove_duplicate_user_case_effective_accesses, set_user_case_effective_access +from app.datamgmt.manage.manage_access_control_db import check_ua_case_client +from app.iris_engine.access_control.iris_user import iris_current_user +from app.logger import logger +from app.models.authorization import UserCaseAccess +from app.models.authorization import CaseAccessLevel +from app.models.authorization import ac_flag_match_mask + + +def ac_fast_check_current_user_has_case_access(cid, access_level): + return ac_fast_check_user_has_case_access(iris_current_user.id, cid, access_level) + + +def set_user_case_access(user_id, case_id, access_level): + + uca = UserCaseAccess.query.filter( + UserCaseAccess.user_id == user_id, + UserCaseAccess.case_id == case_id + ).all() + + if len(uca) > 1: + for u in uca: + db.session.delete(u) + db.session.commit() + uca = None + + if not uca: + uca = UserCaseAccess() + uca.user_id = user_id + uca.case_id = case_id + uca.access_level = access_level + db.session.add(uca) + else: + uca[0].access_level = access_level + + db.session.commit() + + set_case_effective_access_for_user(user_id, case_id, access_level) + + +def set_case_effective_access_for_user(user_id, case_id, access_level: int): + """ + Set a case access from a user + """ + + if remove_duplicate_user_case_effective_accesses(user_id, case_id): + logger.error(f'Multiple access found for user {user_id} and case {case_id}') + + set_user_case_effective_access(access_level, case_id, user_id) + + +def ac_fast_check_user_has_case_access(user_id, cid, expected_access_levels: list[CaseAccessLevel]): + """ + Checks the user has access to the case with at least one of the access_level + if the user has access, returns the access level of the user to the case + Returns None otherwise + """ + access_level = get_case_effective_access(user_id, cid) + + if not access_level: + # The user has no direct access, check if he is part of the client + access_level = check_ua_case_client(user_id, cid) + if not access_level: + return None + set_case_effective_access_for_user(user_id, cid, access_level) + + return access_level + + if ac_flag_match_mask(access_level, CaseAccessLevel.deny_all.value): + return None + + for acl in expected_access_levels: + if ac_flag_match_mask(access_level, acl.value): + return access_level + + return None diff --git a/source/app/business/alerts.py b/source/app/business/alerts.py index 224010fdb..c5eee164e 100644 --- a/source/app/business/alerts.py +++ b/source/app/business/alerts.py @@ -61,26 +61,26 @@ def alerts_create(alert: Alert, iocs: list[Ioc], assets: list[CaseAssets]) -> Al return alert -def alerts_get(current_user, identifier) -> Alert: +def _get(current_user, identifier): alert = get_alert_by_id(identifier) - if not alert: - raise ObjectNotFoundError() + return None if not user_has_client_access(current_user.id, alert.alert_customer_id): - raise ObjectNotFoundError() + return None + return alert + +def alerts_get(current_user, identifier) -> Alert: + alert = _get(current_user, identifier) + if not alert: + raise ObjectNotFoundError() return alert def alerts_exists(current_user, identifier) -> bool: - alert = get_alert_by_id(identifier) - - if not alert: - return False - if not user_has_client_access(current_user.id, alert.alert_customer_id): - return False + alert = _get(current_user, identifier) - return True + return alert is not None def alerts_update(alert: Alert, updated_alert: Alert, activity_data) -> Alert: diff --git a/source/app/business/comments.py b/source/app/business/comments.py index eaedeee31..4789422c9 100644 --- a/source/app/business/comments.py +++ b/source/app/business/comments.py @@ -15,12 +15,14 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + from datetime import datetime from flask_sqlalchemy.pagination import Pagination from app import db from app.business.alerts import alerts_exists +from app.business.alerts import alerts_get from app.business.errors import ObjectNotFoundError from app.business.errors import BusinessProcessingError from app.datamgmt.case.case_comments import get_case_comment @@ -31,11 +33,23 @@ from app.datamgmt.comments import get_filtered_note_comments from app.datamgmt.comments import get_filtered_task_comments from app.datamgmt.comments import get_filtered_event_comments -from app.iris_engine.access_control.iris_user import iris_current_user +from app.datamgmt.case.case_assets_db import add_comment_to_asset +from app.datamgmt.case.case_rfiles_db import add_comment_to_evidence +from app.datamgmt.case.case_iocs_db import add_comment_to_ioc +from app.datamgmt.case.case_notes_db import add_comment_to_note +from app.datamgmt.case.case_tasks_db import add_comment_to_task +from app.datamgmt.case.case_events_db import add_comment_to_event from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.utils.tracker import track_activity -from app.models.models import Comments +from app.models.comments import Comments +from app.models.models import CaseAssets +from app.models.models import CaseReceivedFile +from app.models.iocs import Ioc +from app.models.models import Notes +from app.models.models import CaseTasks +from app.models.cases import CasesEvent from app.models.pagination_parameters import PaginationParameters +from app.util import add_obj_history_entry def comments_get_filtered_by_alert(current_user, alert_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: @@ -45,38 +59,37 @@ def comments_get_filtered_by_alert(current_user, alert_identifier: int, paginati return get_filtered_alert_comments(alert_identifier, pagination_parameters) -def comments_get_filtered_by_asset(asset_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: - return get_filtered_asset_comments(asset_identifier, pagination_parameters) +def comments_get_filtered_by_asset(asset: CaseAssets, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_asset_comments(asset.asset_id, pagination_parameters) -def comments_get_filtered_by_evidence(evidence_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: - return get_filtered_evidence_comments(evidence_identifier, pagination_parameters) +def comments_get_filtered_by_evidence(evidence: CaseReceivedFile, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_evidence_comments(evidence.id, pagination_parameters) -def comments_get_filtered_by_ioc(ioc_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: - return get_filtered_ioc_comments(ioc_identifier, pagination_parameters) +def comments_get_filtered_by_ioc(ioc: Ioc, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_ioc_comments(ioc.ioc_id, pagination_parameters) -def comments_get_filtered_by_note(note_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: - return get_filtered_note_comments(note_identifier, pagination_parameters) +def comments_get_filtered_by_note(note: Notes, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_note_comments(note.note_id, pagination_parameters) -def comments_get_filtered_by_task(taks_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: - return get_filtered_task_comments(taks_identifier, pagination_parameters) +def comments_get_filtered_by_task(task: CaseTasks, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_task_comments(task.id, pagination_parameters) -def comments_get_filtered_by_event(event_identifier: int, pagination_parameters: PaginationParameters) -> Pagination: - return get_filtered_event_comments(event_identifier, pagination_parameters) +def comments_get_filtered_by_event(event: CasesEvent, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_event_comments(event.event_id, pagination_parameters) -def comments_update_for_case(comment_text, comment_id, object_type, caseid) -> Comments: +def comments_update_for_case(current_user, comment_text, comment_id, object_type, caseid) -> Comments: comment = get_case_comment(comment_id, caseid) if not comment: raise BusinessProcessingError('Invalid comment ID') - if hasattr(iris_current_user, 'id') and iris_current_user.id is not None: - if comment.comment_user_id != iris_current_user.id: - raise BusinessProcessingError('Permission denied') + if comment.comment_user_id != current_user.id: + raise BusinessProcessingError('Permission denied') comment.comment_text = comment_text comment.comment_update_date = datetime.utcnow() @@ -91,3 +104,128 @@ def comments_update_for_case(comment_text, comment_id, object_type, caseid) -> C track_activity(f"comment {comment.comment_id} on {object_type} edited", caseid=caseid) return comment + + +def comments_create_for_alert(current_user, comment: Comments, alert_identifier: int): + alert = alerts_get(current_user, alert_identifier) + comment.comment_alert_id = alert_identifier + comment.comment_user_id = current_user.id + comment.comment_date = datetime.now() + comment.comment_update_date = datetime.now() + + db.session.add(comment) + + add_obj_history_entry(alert, 'commented') + db.session.commit() + + hook_data = { + 'comment': comment, + 'alert': alert + } + call_modules_hook('on_postload_alert_commented', hook_data) + track_activity(f'alert "{alert.alert_id}" commented', ctx_less=True) + + +def comments_create_for_asset(current_user, asset: CaseAssets, comment: Comments): + _create_comment(current_user, comment, asset.case_id) + + add_comment_to_asset(asset.asset_id, comment.comment_id) + + db.session.commit() + + hook_data = { + 'comment': comment, + 'asset': asset + } + call_modules_hook('on_postload_asset_commented', data=hook_data, caseid=asset.case_id) + + track_activity(f'asset "{asset.asset_name}" commented', caseid=asset.case_id) + + +def comments_create_for_evidence(current_user, evidence: CaseReceivedFile, comment: Comments): + _create_comment(current_user, comment, evidence.case_id) + + add_comment_to_evidence(evidence.id, comment.comment_id) + + db.session.commit() + + hook_data = { + 'comment': comment, + 'evidence': evidence + } + call_modules_hook('on_postload_evidence_commented', data=hook_data, caseid=evidence.case_id) + track_activity(f'evidence "{evidence.filename}" commented', caseid=evidence.case_id) + + +def comments_create_for_ioc(current_user, ioc: Ioc, comment: Comments): + _create_comment(current_user, comment, ioc.case_id) + + add_comment_to_ioc(ioc.ioc_id, comment.comment_id) + + db.session.commit() + + hook_data = { + 'comment': comment, + 'ioc': ioc + } + call_modules_hook('on_postload_ioc_commented', data=hook_data, caseid=ioc.case_id) + track_activity(f'ioc "{ioc.ioc_value}" commented', caseid=ioc.case_id) + + +def comments_create_for_note(current_user, note: Notes, comment: Comments): + _create_comment(current_user, comment, note.note_case_id) + + add_comment_to_note(note.note_id, comment.comment_id) + + db.session.commit() + + hook_data = { + 'comment': comment, + 'note': note + } + call_modules_hook('on_postload_note_commented', data=hook_data, caseid=note.note_case_id) + + track_activity(f'note "{note.note_title}" commented', caseid=note.note_case_id) + + +def comments_create_for_task(current_user, task: CaseTasks, comment: Comments): + _create_comment(current_user, comment, task.task_case_id) + + add_comment_to_task(task.id, comment.comment_id) + + db.session.commit() + + hook_data = { + 'comment': comment, + 'task': task + } + call_modules_hook('on_postload_task_commented', data=hook_data, caseid=task.task_case_id) + + track_activity(f'task "{task.task_title}" commented', caseid=task.task_case_id) + + +def comments_create_for_event(current_user, event: CasesEvent, comment: Comments): + _create_comment(current_user, comment, event.case_id) + + add_comment_to_event(event.event_id, comment.comment_id) + + add_obj_history_entry(event, 'commented') + + db.session.commit() + + hook_data = { + 'comment': comment, + 'event': event + } + call_modules_hook('on_postload_event_commented', data=hook_data, caseid=event.case_id) + + track_activity(f'event "{event.event_title}" commented', caseid=event.case_id) + + +def _create_comment(current_user, comment, case_identifier): + comment.comment_case_id = case_identifier + comment.comment_user_id = current_user.id + comment.comment_date = datetime.now() + comment.comment_update_date = datetime.now() + db.session.add(comment) + db.session.commit() diff --git a/source/app/business/evidences.py b/source/app/business/evidences.py index fc2415c7f..ba933af8a 100644 --- a/source/app/business/evidences.py +++ b/source/app/business/evidences.py @@ -47,7 +47,7 @@ def evidences_create(case_identifier, request_json) -> CaseReceivedFile: evidence = _load(request_data) - crf = add_rfile(evidence=evidence, user_id=iris_current_user.id, caseid=case_identifier) + crf = add_rfile(evidence, case_identifier, iris_current_user.id) crf = call_modules_hook('on_postload_evidence_create', data=crf, caseid=case_identifier) if not crf: diff --git a/source/app/datamgmt/alerts/alerts_db.py b/source/app/datamgmt/alerts/alerts_db.py index 458fcdb09..0a35fafa2 100644 --- a/source/app/datamgmt/alerts/alerts_db.py +++ b/source/app/datamgmt/alerts/alerts_db.py @@ -23,13 +23,22 @@ from typing import List from typing import Tuple -from sqlalchemy import desc, asc, func, tuple_, or_, not_, and_ -from sqlalchemy.orm import aliased, make_transient, selectinload +from sqlalchemy import desc +from sqlalchemy import asc +from sqlalchemy import func +from sqlalchemy import tuple_ +from sqlalchemy import or_ +from sqlalchemy import not_ +from sqlalchemy import and_ +from sqlalchemy.orm import aliased +from sqlalchemy.orm import make_transient +from sqlalchemy.orm import selectinload from flask_sqlalchemy.pagination import Pagination import app from app import db -from app.datamgmt.filtering import combine_conditions, apply_custom_conditions +from app.datamgmt.filtering import combine_conditions +from app.datamgmt.filtering import apply_custom_conditions from app.datamgmt.case.case_assets_db import create_asset from app.datamgmt.case.case_assets_db import set_ioc_links from app.datamgmt.case.case_assets_db import get_unspecified_analysis_status_id @@ -48,7 +57,7 @@ from app.models.models import EventCategory from app.models.models import Tags from app.models.models import AssetsType -from app.models.models import Comments +from app.models.comments import Comments from app.models.models import CaseAssets from app.models.models import alert_assets_association from app.models.iocs import alert_iocs_association diff --git a/source/app/datamgmt/case/case_assets_db.py b/source/app/datamgmt/case/case_assets_db.py index 549986fab..efbb60f6e 100644 --- a/source/app/datamgmt/case/case_assets_db.py +++ b/source/app/datamgmt/case/case_assets_db.py @@ -17,6 +17,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. import datetime +from typing import Optional from sqlalchemy import and_ from sqlalchemy import func @@ -28,12 +29,11 @@ from app.datamgmt.states import update_assets_state from app.models.models import AnalysisStatus from app.models.models import CaseStatus -from app.models.models import AssetComments from app.models.models import AssetsType from app.models.models import CaseAssets from app.models.models import CaseEventsAssets from app.models.cases import Cases -from app.models.models import Comments +from app.models.comments import Comments, AssetComments from app.models.models import CompromiseStatus from app.models.iocs import Ioc from app.models.models import IocAssetLink @@ -129,7 +129,7 @@ def get_assets_name(caseid): return assets_names -def get_asset(asset_id) -> CaseAssets: +def get_asset(asset_id) -> Optional[CaseAssets]: asset = CaseAssets.query.filter( CaseAssets.asset_id == asset_id, ).first() diff --git a/source/app/datamgmt/case/case_comments.py b/source/app/datamgmt/case/case_comments.py index c02c29433..fa95fdeda 100644 --- a/source/app/datamgmt/case/case_comments.py +++ b/source/app/datamgmt/case/case_comments.py @@ -15,11 +15,12 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +from typing import Optional -from app.models.models import Comments +from app.models.comments import Comments -def get_case_comment(comment_id, caseid) -> Comments: +def get_case_comment(comment_id, caseid) -> Optional[Comments]: if caseid is None: return Comments.query.filter( Comments.comment_id == comment_id diff --git a/source/app/datamgmt/case/case_events_db.py b/source/app/datamgmt/case/case_events_db.py index 9ae07bd1d..f37af08b3 100644 --- a/source/app/datamgmt/case/case_events_db.py +++ b/source/app/datamgmt/case/case_events_db.py @@ -26,9 +26,8 @@ from app.models.models import CaseEventsAssets from app.models.models import CaseEventsIoc from app.models.cases import CasesEvent -from app.models.models import Comments +from app.models.comments import Comments, EventComments from app.models.models import EventCategory -from app.models.models import EventComments from app.models.iocs import Ioc from app.models.models import IocAssetLink from app.models.models import IocType @@ -174,6 +173,19 @@ def delete_event_comment(event_id, comment_id): return True, "Comment deleted" +def delete_events_comments_in_case(case_identifier): + com_ids = EventComments.query.with_entities( + EventComments.comment_id + ).join(CasesEvent).filter( + EventComments.comment_event_id == CasesEvent.event_id, + CasesEvent.case_id == case_identifier + ).all() + + com_ids = [c.comment_id for c in com_ids] + EventComments.query.filter(EventComments.comment_id.in_(com_ids)).delete() + Comments.query.filter(Comments.comment_id.in_(com_ids)).delete() + + def add_comment_to_event(event_id, comment_id): ec = EventComments() ec.comment_event_id = event_id diff --git a/source/app/datamgmt/case/case_iocs_db.py b/source/app/datamgmt/case/case_iocs_db.py index 7861a599a..ade70045b 100644 --- a/source/app/datamgmt/case/case_iocs_db.py +++ b/source/app/datamgmt/case/case_iocs_db.py @@ -15,6 +15,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + from sqlalchemy import and_ from app import db @@ -24,16 +25,15 @@ from app.datamgmt.states import update_ioc_state from app.iris_engine.access_control.utils import ac_get_fast_user_cases_access from app.models.alerts import Alert -from app.models.cases import Cases, CasesEvent -from app.models.models import Client, CaseAssets -from app.models.models import Comments +from app.models.cases import Cases +from app.models.cases import CasesEvent +from app.models.models import Client +from app.models.models import CaseAssets +from app.models.comments import Comments, IocComments from app.models.iocs import Ioc -from app.models.models import IocComments from app.models.models import IocType from app.models.iocs import Tlp from app.models.authorization import User -from app.models.authorization import UserCaseEffectiveAccess -from app.models.authorization import CaseAccessLevel from app.models.pagination_parameters import PaginationParameters from app.util import add_obj_history_entry @@ -90,10 +90,7 @@ def delete_ioc(ioc: Ioc): com_ids = [c.comment_id for c in com_ids] IocComments.query.filter(IocComments.comment_id.in_(com_ids)).delete() - - Comments.query.filter( - Comments.comment_id.in_(com_ids) - ).delete() + Comments.query.filter(Comments.comment_id.in_(com_ids)).delete() db.session.delete(ioc) @@ -293,17 +290,6 @@ def get_ioc_by_value(ioc_value, caseid=None): return Ioc.query.filter(Ioc.ioc_value == ioc_value).first() -def user_list_cases_view(user_id): - res = UserCaseEffectiveAccess.query.with_entities( - UserCaseEffectiveAccess.case_id - ).filter(and_( - UserCaseEffectiveAccess.user_id == user_id, - UserCaseEffectiveAccess.access_level != CaseAccessLevel.deny_all.value - )).all() - - return [r.case_id for r in res] - - def get_filtered_iocs( caseid: int = None, pagination_parameters: PaginationParameters = None, diff --git a/source/app/datamgmt/case/case_notes_db.py b/source/app/datamgmt/case/case_notes_db.py index 464000ef6..e564a51e2 100644 --- a/source/app/datamgmt/case/case_notes_db.py +++ b/source/app/datamgmt/case/case_notes_db.py @@ -21,11 +21,10 @@ from app.iris_engine.access_control.iris_user import iris_current_user from app.datamgmt.manage.manage_attribute_db import get_default_custom_attributes from app.datamgmt.states import update_notes_state -from app.models.models import Comments +from app.models.comments import Comments, NotesComments from app.models.models import NoteDirectory from app.models.models import NoteRevisions from app.models.models import Notes -from app.models.models import NotesComments from app.models.models import NotesGroup from app.models.models import NotesGroupLink from app.models.authorization import User @@ -98,6 +97,19 @@ def delete_note(note_identifier, case_identifier): update_notes_state(caseid=case_identifier) +def delete_notes_comments_in_case(case_identifier): + com_ids = NotesComments.query.with_entities( + NotesComments.comment_id + ).join(Notes).filter( + NotesComments.comment_note_id == Notes.note_id, + Notes.note_case_id == case_identifier + ).all() + + com_ids = [c.comment_id for c in com_ids] + NotesComments.query.filter(NotesComments.comment_id.in_(com_ids)).delete() + Comments.query.filter(Comments.comment_id.in_(com_ids)).delete() + + def update_note(note_content, note_title, update_date, user_id, note_id, caseid): note = get_note_raw(note_id, caseid=caseid) diff --git a/source/app/datamgmt/case/case_rfiles_db.py b/source/app/datamgmt/case/case_rfiles_db.py index 03b6183b7..ed5270983 100644 --- a/source/app/datamgmt/case/case_rfiles_db.py +++ b/source/app/datamgmt/case/case_rfiles_db.py @@ -25,8 +25,7 @@ from app.datamgmt.manage.manage_attribute_db import get_default_custom_attributes from app.datamgmt.states import update_evidences_state from app.models.models import CaseReceivedFile -from app.models.models import Comments -from app.models.models import EvidencesComments +from app.models.comments import Comments, EvidencesComments from app.models.authorization import User from app.models.pagination_parameters import PaginationParameters from app.datamgmt.conversions import convert_sort_direction @@ -60,7 +59,7 @@ def get_paginated_evidences(case_identifier, pagination_parameters: PaginationPa ) -def add_rfile(evidence, caseid, user_id): +def add_rfile(evidence: CaseReceivedFile, caseid, user_id): evidence.date_added = datetime.datetime.now() evidence.case_id = caseid @@ -110,6 +109,19 @@ def delete_rfile(evidence: CaseReceivedFile): db.session.commit() +def delete_evidences_comments_in_case(case_identifier): + com_ids = EvidencesComments.query.with_entities( + EvidencesComments.comment_id + ).join(CaseReceivedFile).filter( + EvidencesComments.comment_evidence_id == CaseReceivedFile.id, + CaseReceivedFile.case_id == case_identifier + ).all() + + com_ids = [c.comment_id for c in com_ids] + EvidencesComments.query.filter(EvidencesComments.comment_id.in_(com_ids)).delete() + Comments.query.filter(Comments.comment_id.in_(com_ids)).delete() + + def get_case_evidence_comments(evidence_id): return Comments.query.filter( EvidencesComments.comment_evidence_id == evidence_id diff --git a/source/app/datamgmt/case/case_tasks_db.py b/source/app/datamgmt/case/case_tasks_db.py index 89b8ab8f9..dc2bf90de 100644 --- a/source/app/datamgmt/case/case_tasks_db.py +++ b/source/app/datamgmt/case/case_tasks_db.py @@ -29,8 +29,7 @@ from app.models.models import CaseTasks from app.models.models import TaskAssignee from app.models.cases import Cases -from app.models.models import Comments -from app.models.models import TaskComments +from app.models.comments import Comments, TaskComments from app.models.models import TaskStatus from app.models.authorization import User from app.models.pagination_parameters import PaginationParameters @@ -316,6 +315,19 @@ def delete_task_comment(task_id, comment_id): return True, "Comment deleted" +def delete_tasks_comments_in_case(case_identifier): + com_ids = TaskComments.query.with_entities( + TaskComments.comment_id + ).join(CaseTasks).filter( + TaskComments.comment_task_id == CaseTasks.id, + CaseTasks.task_case_id == case_identifier + ).all() + + com_ids = [c.comment_id for c in com_ids] + TaskComments.query.filter(TaskComments.comment_id.in_(com_ids)).delete() + Comments.query.filter(Comments.comment_id.in_(com_ids)).delete() + + def get_tasks_cases_mapping(open_cases_only=False): condition = Cases.close_date == None if open_cases_only else True diff --git a/source/app/datamgmt/comments.py b/source/app/datamgmt/comments.py index ef8d862e0..9d4cbb451 100644 --- a/source/app/datamgmt/comments.py +++ b/source/app/datamgmt/comments.py @@ -18,13 +18,8 @@ from flask_sqlalchemy.pagination import Pagination -from app.models.models import Comments -from app.models.models import AssetComments -from app.models.models import EvidencesComments -from app.models.models import IocComments -from app.models.models import NotesComments -from app.models.models import TaskComments -from app.models.models import EventComments +from app.models.comments import Comments, EventComments, TaskComments, IocComments, AssetComments, EvidencesComments, \ + NotesComments from app.models.pagination_parameters import PaginationParameters diff --git a/source/app/datamgmt/manage/manage_access_control_db.py b/source/app/datamgmt/manage/manage_access_control_db.py index a97db33a8..cb65fdc18 100644 --- a/source/app/datamgmt/manage/manage_access_control_db.py +++ b/source/app/datamgmt/manage/manage_access_control_db.py @@ -14,11 +14,14 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +from sqlalchemy import and_ from app import ac_current_user_has_permission +from app import db from app.models.cases import Cases from app.models.authorization import Group from app.models.authorization import UserClient +from app.models.authorization import UserCaseEffectiveAccess from app.models.authorization import Permissions from app.models.authorization import CaseAccessLevel from app.models.authorization import GroupCaseAccess @@ -76,7 +79,7 @@ def manage_ac_audit_users_db(): return ret -def check_ua_case_client(user_id: int, case_id: int) -> Optional[UserClient]: +def check_ua_case_client(user_id: int, case_id: int) -> Optional[int]: """Check if the user has access to the case, through the customer of the case (in other words, check that the customer of the case is assigned to the user) @@ -90,18 +93,34 @@ def check_ua_case_client(user_id: int, case_id: int) -> Optional[UserClient]: """ if ac_current_user_has_permission(Permissions.server_administrator): # Return a dummy object - uc = UserClient() - uc.access_level = CaseAccessLevel.full_access.value - return uc + return CaseAccessLevel.full_access.value result = UserClient.query.filter( UserClient.user_id == user_id, Cases.case_id == case_id - ).join(Cases, - UserClient.client_id == Cases.client_id + ).join( + Cases, + UserClient.client_id == Cases.client_id ).first() - return result + if not result: + return None + + return result.access_level + + +def get_case_effective_access(user_identifier, case_identifier) -> Optional[int]: + row = UserCaseEffectiveAccess.query.with_entities( + UserCaseEffectiveAccess.access_level + ).filter( + UserCaseEffectiveAccess.user_id == user_identifier, + UserCaseEffectiveAccess.case_id == case_identifier + ).first() + + if not row: + return None + + return row[0] def get_client_users(client_id: int) -> list: @@ -157,3 +176,29 @@ def user_has_client_access(user_id: int, client_id: int) -> bool: ).first() return result is not None + + +def remove_duplicate_user_case_effective_accesses(user_id, case_id): + uac = UserCaseEffectiveAccess.query.where(and_( + UserCaseEffectiveAccess.user_id == user_id, + UserCaseEffectiveAccess.case_id == case_id + )).all() + + if len(uac) <= 1: + return False + + for u in uac[1:]: + db.session.delete(u) + db.session.commit() + return True + + +def set_user_case_effective_access(access_level, case_id, user_id): + uac = UserCaseEffectiveAccess.query.where(and_( + UserCaseEffectiveAccess.user_id == user_id, + UserCaseEffectiveAccess.case_id == case_id + )).first() + if uac: + uac = uac[0] + uac.access_level = access_level + db.session.commit() diff --git a/source/app/datamgmt/manage/manage_case_templates_db.py b/source/app/datamgmt/manage/manage_case_templates_db.py index 4dd3fb4ff..d82975299 100644 --- a/source/app/datamgmt/manage/manage_case_templates_db.py +++ b/source/app/datamgmt/manage/manage_case_templates_db.py @@ -58,7 +58,7 @@ def get_case_templates_list() -> List[dict]: return c_cl -def get_case_template_by_id(cur_id: int) -> CaseTemplate: +def get_case_template_by_id(cur_id: int) -> Optional[CaseTemplate]: """Get a case template Args: @@ -67,8 +67,7 @@ def get_case_template_by_id(cur_id: int) -> CaseTemplate: Returns: CaseTemplate: Case template """ - case_template = CaseTemplate.query.filter_by(id=cur_id).first() - return case_template + return CaseTemplate.query.filter_by(id=cur_id).first() def delete_case_template_by_id(case_template_id: int): diff --git a/source/app/datamgmt/manage/manage_cases_db.py b/source/app/datamgmt/manage/manage_cases_db.py index dc83b83b1..f01e654ac 100644 --- a/source/app/datamgmt/manage/manage_cases_db.py +++ b/source/app/datamgmt/manage/manage_cases_db.py @@ -15,6 +15,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + from datetime import datetime from datetime import date from datetime import timedelta @@ -56,6 +57,7 @@ from app.models.models import NotesGroupLink from app.models.models import UserActivity from app.models.alerts import AlertCaseAssociation +from app.models.comments import Comments, IocComments, AssetComments from app.models.authorization import CaseAccessLevel from app.models.authorization import GroupCaseAccess from app.models.authorization import OrganisationCaseAccess @@ -67,6 +69,10 @@ from app.models.cases import CaseTags from app.models.cases import CaseState from app.models.pagination_parameters import PaginationParameters +from app.datamgmt.case.case_rfiles_db import delete_evidences_comments_in_case +from app.datamgmt.case.case_notes_db import delete_notes_comments_in_case +from app.datamgmt.case.case_tasks_db import delete_tasks_comments_in_case +from app.datamgmt.case.case_events_db import delete_events_comments_in_case def list_cases_id(): @@ -314,14 +320,98 @@ def get_case_details_rt(case_id): return res +def _delete_iocs(case_identifier): + # TODO should do this with the 2.0 SQLAlchemy API + # TODO maybe this can be performed automatically with cascades + com_ids = IocComments.query.with_entities( + IocComments.comment_id + ).join( + Ioc + ).filter( + IocComments.comment_ioc_id == Ioc.ioc_id, + Ioc.case_id == case_identifier + ).all() + + com_ids = [c.comment_id for c in com_ids] + IocComments.query.filter(IocComments.comment_id.in_(com_ids)).delete() + + Comments.query.filter( + Comments.comment_id.in_(com_ids) + ).delete() + Ioc.query.filter(Ioc.case_id == case_identifier).delete() + + +def _delete_assets(case_identifier): + com_ids = AssetComments.query.with_entities( + AssetComments.comment_id + ).join(CaseAssets).filter( + AssetComments.comment_asset_id == CaseAssets.asset_id, + CaseAssets.case_id == case_identifier + ).all() + + com_ids = [c.comment_id for c in com_ids] + AssetComments.query.filter(AssetComments.comment_id.in_(com_ids)).delete() + Comments.query.filter(Comments.comment_id.in_(com_ids)).delete() + + CaseAssetsAlias = aliased(CaseAssets) + + # Query for CaseAssets that are not referenced in alerts and match the case_id + assets_to_delete = db.session.query(CaseAssets).filter( + and_( + CaseAssets.case_id == case_identifier, + ~db.session.query(alert_assets_association).filter( + alert_assets_association.c.asset_id == CaseAssetsAlias.asset_id + ).exists() + ) + ) + # Delete the assets + assets_to_delete.delete(synchronize_session='fetch') + + +def _delete_evidences(case_identifier): + delete_evidences_comments_in_case(case_identifier) + CaseReceivedFile.query.filter(CaseReceivedFile.case_id == case_identifier).delete() + + +def _delete_notes(case_identifier): + delete_notes_comments_in_case(case_identifier) + # Legacy code + NotesGroupLink.query.filter(NotesGroupLink.case_id == case_identifier).delete() + NotesGroup.query.filter(NotesGroup.group_case_id == case_identifier).delete() + NoteRevisions.query.filter( + and_( + Notes.note_case_id == case_identifier, + NoteRevisions.note_id == Notes.note_id + ) + ).delete() + Notes.query.filter(Notes.note_case_id == case_identifier).delete() + NoteDirectory.query.filter(NoteDirectory.case_id == case_identifier).delete() + + +def _delete_tasks(case_identifier): + delete_tasks_comments_in_case(case_identifier) + tasks = CaseTasks.query.filter(CaseTasks.task_case_id == case_identifier).all() + for task in tasks: + TaskAssignee.query.filter(TaskAssignee.task_id == task.id).delete() + CaseTasks.query.filter(CaseTasks.id == task.id).delete() + + +def _delete_events(case_identifier): + delete_events_comments_in_case(case_identifier) + da = CasesEvent.query.with_entities(CasesEvent.event_id).filter(CasesEvent.case_id == case_identifier).all() + for event in da: + CaseEventCategory.query.filter(CaseEventCategory.event_id == event.event_id).delete() + CasesEvent.query.filter(CasesEvent.case_id == case_identifier).delete() + + def delete_case(case_id): if not Cases.query.filter(Cases.case_id == case_id).first(): return False delete_case_states(caseid=case_id) UserActivity.query.filter(UserActivity.case_id == case_id).delete() - CaseReceivedFile.query.filter(CaseReceivedFile.case_id == case_id).delete() - Ioc.query.filter(Ioc.case_id == case_id).delete() + _delete_evidences(case_id) + _delete_iocs(case_id) CaseTags.query.filter(CaseTags.case_id == case_id).delete() CaseProtagonist.query.filter(CaseProtagonist.case_id == case_id).delete() @@ -347,20 +437,7 @@ def delete_case(case_id): CaseEventsAssets.query.filter(CaseEventsAssets.case_id == case_id).delete() CaseEventsIoc.query.filter(CaseEventsIoc.case_id == case_id).delete() - CaseAssetsAlias = aliased(CaseAssets) - - # Query for CaseAssets that are not referenced in alerts and match the case_id - assets_to_delete = db.session.query(CaseAssets).filter( - and_( - CaseAssets.case_id == case_id, - ~db.session.query(alert_assets_association).filter( - alert_assets_association.c.asset_id == CaseAssetsAlias.asset_id - ).exists() - ) - ) - - # Delete the assets - assets_to_delete.delete(synchronize_session='fetch') + _delete_assets(case_id) # Get all alerts associated with assets in the case alerts_to_update = db.session.query(CaseAssets).filter(CaseAssets.case_id == case_id) @@ -369,30 +446,10 @@ def delete_case(case_id): alerts_to_update.update({CaseAssets.case_id: None}, synchronize_session='fetch') db.session.commit() - # Legacy code - NotesGroupLink.query.filter(NotesGroupLink.case_id == case_id).delete() - NotesGroup.query.filter(NotesGroup.group_case_id == case_id).delete() - - NoteRevisions.query.filter( - and_( - Notes.note_case_id == case_id, - NoteRevisions.note_id == Notes.note_id - ) - ).delete() - - Notes.query.filter(Notes.note_case_id == case_id).delete() - NoteDirectory.query.filter(NoteDirectory.case_id == case_id).delete() - - tasks = CaseTasks.query.filter(CaseTasks.task_case_id == case_id).all() - for task in tasks: - TaskAssignee.query.filter(TaskAssignee.task_id == task.id).delete() - CaseTasks.query.filter(CaseTasks.id == task.id).delete() - - da = CasesEvent.query.with_entities(CasesEvent.event_id).filter(CasesEvent.case_id == case_id).all() - for event in da: - CaseEventCategory.query.filter(CaseEventCategory.event_id == event.event_id).delete() + _delete_notes(case_id) + _delete_tasks(case_id) - CasesEvent.query.filter(CasesEvent.case_id == case_id).delete() + _delete_events(case_id) UserCaseAccess.query.filter(UserCaseAccess.case_id == case_id).delete() UserCaseEffectiveAccess.query.filter(UserCaseEffectiveAccess.case_id == case_id).delete() diff --git a/source/app/datamgmt/manage/manage_users_db.py b/source/app/datamgmt/manage/manage_users_db.py index dc45dac2a..2bf76d3ab 100644 --- a/source/app/datamgmt/manage/manage_users_db.py +++ b/source/app/datamgmt/manage/manage_users_db.py @@ -34,7 +34,6 @@ from app.iris_engine.access_control.utils import ac_auto_update_user_effective_access from app.iris_engine.access_control.utils import ac_get_detailed_effective_permissions_from_groups from app.iris_engine.access_control.utils import ac_remove_case_access_from_user -from app.iris_engine.access_control.utils import ac_set_case_access_for_user from app.models.cases import Cases from app.models.models import Client from app.models.models import UserActivity @@ -445,46 +444,6 @@ def remove_case_access_from_user(user_id, case_id): return True, 'Case access removed' -def set_user_case_access(user_id, case_id, access_level): - if user_id is None or type(user_id) is not int: - return False, 'Invalid user id' - - if case_id is None or type(case_id) is not int: - return False, "Invalid case id" - - if access_level is None or type(access_level) is not int: - return False, "Invalid access level" - - if CaseAccessLevel.has_value(access_level) is False: - return False, "Invalid access level" - - uca = UserCaseAccess.query.filter( - UserCaseAccess.user_id == user_id, - UserCaseAccess.case_id == case_id - ).all() - - if len(uca) > 1: - for u in uca: - db.session.delete(u) - db.session.commit() - uca = None - - if not uca: - uca = UserCaseAccess() - uca.user_id = user_id - uca.case_id = case_id - uca.access_level = access_level - db.session.add(uca) - else: - uca[0].access_level = access_level - - db.session.commit() - - ac_set_case_access_for_user(user_id, case_id, access_level) - - return True, f'Case access set to {access_level} for user {user_id}' - - def get_user_details(user_id, include_api_key=False): user = User.query.filter(User.id == user_id).first() diff --git a/source/app/datamgmt/reporter/report_db.py b/source/app/datamgmt/reporter/report_db.py index bfdc715ec..e4690a72e 100644 --- a/source/app/datamgmt/reporter/report_db.py +++ b/source/app/datamgmt/reporter/report_db.py @@ -33,7 +33,7 @@ from app.models.models import CaseTasks from app.models.cases import Cases from app.models.cases import CasesEvent -from app.models.models import Comments +from app.models.comments import Comments from app.models.models import EventCategory from app.models.iocs import Ioc from app.models.models import IocAssetLink diff --git a/source/app/iris_engine/access_control/iris_user.py b/source/app/iris_engine/access_control/iris_user.py index f3fca0f5b..a426d8b63 100644 --- a/source/app/iris_engine/access_control/iris_user.py +++ b/source/app/iris_engine/access_control/iris_user.py @@ -33,7 +33,7 @@ def __init__(self, user_data): self.is_anonymous = False -def get_current_user(): +def _get_current_user(): """ Returns a compatible user object for both session and token auth For token auth, uses data from g.auth_user @@ -47,4 +47,4 @@ def get_current_user(): return None -iris_current_user = LocalProxy(lambda: get_current_user()) +iris_current_user = LocalProxy(lambda: _get_current_user()) diff --git a/source/app/iris_engine/access_control/utils.py b/source/app/iris_engine/access_control/utils.py index 5f85f865d..9de8cc59e 100644 --- a/source/app/iris_engine/access_control/utils.py +++ b/source/app/iris_engine/access_control/utils.py @@ -1,13 +1,14 @@ from flask import session from sqlalchemy import and_ -import app from app import db +from app.business.access_controls import set_case_effective_access_for_user +from app.logger import logger from app.iris_engine.access_control.iris_user import iris_current_user -from app.datamgmt.manage.manage_access_control_db import check_ua_case_client from app.models.cases import Cases from app.models.models import Client from app.models.authorization import CaseAccessLevel +from app.models.authorization import ac_flag_match_mask from app.models.authorization import UserClient from app.models.authorization import Group from app.models.authorization import GroupCaseAccess @@ -17,12 +18,6 @@ from app.models.authorization import UserCaseEffectiveAccess from app.models.authorization import UserGroup -log = app.app.logger - - -def ac_flag_match_mask(flag, mask): - return (flag & mask) == mask - def ac_get_mask_full_permissions(): """ @@ -289,42 +284,6 @@ def ac_trace_effective_user_permissions(user_id): return perms -def ac_fast_check_user_has_case_access(user_id, cid, access_level: list[CaseAccessLevel]): - """ - Checks the user has access to the case with at least one of the access_level - if the user has access, returns the access level of the user to the case - Returns None otherwise - """ - ucea = UserCaseEffectiveAccess.query.with_entities( - UserCaseEffectiveAccess.access_level - ).filter( - UserCaseEffectiveAccess.user_id == user_id, - UserCaseEffectiveAccess.case_id == cid - ).first() - - if not ucea: - # The user has no direct access, check if he is part of the client - cuacu = check_ua_case_client(user_id, cid) - if cuacu is None: - return None - ac_set_case_access_for_user(user_id, cid, cuacu.access_level) - - return cuacu.access_level - - if ac_flag_match_mask(ucea[0], CaseAccessLevel.deny_all.value): - return None - - for acl in access_level: - if ac_flag_match_mask(ucea[0], acl.value): - return ucea[0] - - return None - - -def ac_fast_check_current_user_has_case_access(cid, access_level): - return ac_fast_check_user_has_case_access(iris_current_user.id, cid, access_level) - - def ac_recompute_effective_ac_from_users_list(users_list): """ Recompute all users effective access of users @@ -541,7 +500,7 @@ def ac_remove_case_access_from_user(user_id, case_id): )).all() if len(uac) > 1: - log.error(f'Multiple access found for user {user_id} and case {case_id}') + logger.error(f'Multiple access found for user {user_id} and case {case_id}') for u in uac: db.session.delete(u) db.session.commit() @@ -571,47 +530,15 @@ def ac_set_case_access_for_users(users, case_id, access_level): user_id = user.get('id') if user_id == iris_current_user.id: logs = "It's done, but I excluded you from the list of users to update, Dave" - ac_set_case_access_for_user(user.get('id'), case_id, access_level=CaseAccessLevel.full_access.value) + set_case_effective_access_for_user(user.get('id'), case_id, CaseAccessLevel.full_access.value) continue - ac_set_case_access_for_user(user.get('id'), case_id, access_level) + set_case_effective_access_for_user(user.get('id'), case_id, access_level) db.session.commit() return True, logs -def ac_set_case_access_for_user(user_id, case_id, access_level, commit=True): - """ - Set a case access from a user - """ - - uac = UserCaseEffectiveAccess.query.where(and_( - UserCaseEffectiveAccess.user_id == user_id, - UserCaseEffectiveAccess.case_id == case_id - )).all() - - if len(uac) > 1: - log.error(f'Multiple access found for user {user_id} and case {case_id}') - for u in uac: - db.session.delete(u) - db.session.commit() - - uac = UserCaseEffectiveAccess() - uac.user_id = user_id - uac.case_id = case_id - uac.access_level = access_level - db.session.add(uac) - - elif len(uac) == 1: - uac = uac[0] - uac.access_level = access_level - - if commit: - db.session.commit() - - return - - def ac_get_fast_user_cases_access(user_id): ucea = UserCaseEffectiveAccess.query.with_entities( UserCaseEffectiveAccess.case_id @@ -734,7 +661,6 @@ def ac_trace_user_effective_cases_access_2(user_id): UserCaseAccess.user ).all() - effective_cases_access = {} cases = Cases.query.with_entities( Cases.case_id, Cases.name @@ -755,6 +681,7 @@ def ac_trace_user_effective_cases_access_2(user_id): UserClient.user_id == user_id ).all() + effective_cases_access = {} # Organisation case access. Default access level for oca in cases: access = { @@ -775,7 +702,7 @@ def ac_trace_user_effective_cases_access_2(user_id): 'case_id': oca.case_id }, 'user_access': [], - 'user_effective_access': CaseAccessLevel.deny_all.value + 'user_effective_access': ac_access_level_to_list(CaseAccessLevel.deny_all.value) } effective_cases_access[oca.case_id]['user_access'].append(access) @@ -795,7 +722,7 @@ def ac_trace_user_effective_cases_access_2(user_id): } if gca.case_id in effective_cases_access: - effective_cases_access[gca.case_id]['user_effective_access'] = gca.access_level + effective_cases_access[gca.case_id]['user_effective_access'] = ac_access_level_to_list(gca.access_level) for kec in effective_cases_access[gca.case_id]['user_access']: kec['state'] = f'Overwritten by group {gca.group_name}' @@ -806,7 +733,7 @@ def ac_trace_user_effective_cases_access_2(user_id): 'case_id': gca.case_id }, 'user_access': [], - 'user_effective_access': gca.access_level + 'user_effective_access': ac_access_level_to_list(gca.access_level) } effective_cases_access[gca.case_id]['user_access'].append(access) @@ -826,7 +753,7 @@ def ac_trace_user_effective_cases_access_2(user_id): } if cca.case_id in effective_cases_access: - effective_cases_access[cca.case_id]['user_effective_access'] = cca.access_level + effective_cases_access[cca.case_id]['user_effective_access'] = ac_access_level_to_list(cca.access_level) for kec in effective_cases_access[cca.case_id]['user_access']: kec['state'] = f'Overwritten by customer {cca.client_name}' @@ -837,7 +764,7 @@ def ac_trace_user_effective_cases_access_2(user_id): 'case_id': cca.case_id }, 'user_access': [], - 'user_effective_access': cca.access_level + 'user_effective_access': ac_access_level_to_list(cca.access_level) } effective_cases_access[cca.case_id]['user_access'].append(access) @@ -857,7 +784,7 @@ def ac_trace_user_effective_cases_access_2(user_id): } if uca.case_id in effective_cases_access: - effective_cases_access[uca.case_id]['user_effective_access'] = uca.access_level + effective_cases_access[uca.case_id]['user_effective_access'] = ac_access_level_to_list(uca.access_level) for kec in effective_cases_access[uca.case_id]['user_access']: kec['state'] = 'Overwritten by self user access' @@ -869,15 +796,11 @@ def ac_trace_user_effective_cases_access_2(user_id): 'case_id': uca.case_id }, 'user_access': [], - 'user_effective_access': uca.access_level + 'user_effective_access': ac_access_level_to_list(uca.access_level) } effective_cases_access[uca.case_id]['user_access'].append(access) - for case_id in effective_cases_access: - effective_cases_access[case_id]['user_effective_access'] = ac_access_level_to_list( - effective_cases_access[case_id]['user_effective_access']) - return effective_cases_access diff --git a/source/app/models/authorization.py b/source/app/models/authorization.py index 1f40c2337..444cc1a27 100644 --- a/source/app/models/authorization.py +++ b/source/app/models/authorization.py @@ -253,3 +253,7 @@ def save(self): db.session.commit() return self + + +def ac_flag_match_mask(flag, mask): + return (flag & mask) == mask diff --git a/source/app/models/comments.py b/source/app/models/comments.py new file mode 100644 index 000000000..459051ed5 --- /dev/null +++ b/source/app/models/comments.py @@ -0,0 +1,114 @@ +# IRIS Source Code +# Copyright (C) 2025 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +import uuid + +from sqlalchemy import Column +from sqlalchemy import BigInteger +from sqlalchemy import UUID +from sqlalchemy import text +from sqlalchemy import Text +from sqlalchemy import DateTime +from sqlalchemy import ForeignKey +from sqlalchemy.orm import relationship + +from app import db + + +class Comments(db.Model): + __tablename__ = "comments" + + comment_id = Column(BigInteger, primary_key=True) + comment_uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, server_default=text("gen_random_uuid()"), + nullable=False) + comment_text = Column(Text) + comment_date = Column(DateTime) + comment_update_date = Column(DateTime) + comment_user_id = Column(ForeignKey('user.id')) + comment_case_id = Column(ForeignKey('cases.case_id')) + comment_alert_id = Column(ForeignKey('alerts.alert_id')) + + user = relationship('User') + case = relationship('Cases') + alert = relationship('Alert') + + +class EventComments(db.Model): + __tablename__ = "event_comments" + + id = Column(BigInteger, primary_key=True) + comment_id = Column(ForeignKey('comments.comment_id')) + comment_event_id = Column(ForeignKey('cases_events.event_id')) + + event = relationship('CasesEvent') + comment = relationship('Comments') + + +class TaskComments(db.Model): + __tablename__ = "task_comments" + + id = Column(BigInteger, primary_key=True) + comment_id = Column(ForeignKey('comments.comment_id')) + comment_task_id = Column(ForeignKey('case_tasks.id')) + + task = relationship('CaseTasks') + comment = relationship('Comments') + + +class IocComments(db.Model): + __tablename__ = "ioc_comments" + + id = Column(BigInteger, primary_key=True) + comment_id = Column(ForeignKey('comments.comment_id')) + comment_ioc_id = Column(ForeignKey('ioc.ioc_id')) + + ioc = relationship('Ioc') + comment = relationship('Comments') + + +class AssetComments(db.Model): + __tablename__ = "asset_comments" + + id = Column(BigInteger, primary_key=True) + comment_id = Column(ForeignKey('comments.comment_id')) + comment_asset_id = Column(ForeignKey('case_assets.asset_id')) + + asset = relationship('CaseAssets') + comment = relationship('Comments') + + +class EvidencesComments(db.Model): + __tablename__ = "evidence_comments" + + id = Column(BigInteger, primary_key=True) + comment_id = Column(ForeignKey('comments.comment_id')) + comment_evidence_id = Column(ForeignKey('case_received_file.id')) + + evidence = relationship('CaseReceivedFile') + comment = relationship('Comments') + + +class NotesComments(db.Model): + __tablename__ = "note_comments" + + id = Column(BigInteger, primary_key=True) + comment_id = Column(ForeignKey('comments.comment_id')) + comment_note_id = Column(ForeignKey('notes.note_id')) + + note = relationship('Notes') + comment = relationship('Comments') diff --git a/source/app/models/models.py b/source/app/models/models.py index ad3cbcb4e..a4a2f0616 100644 --- a/source/app/models/models.py +++ b/source/app/models/models.py @@ -709,90 +709,6 @@ class ServerSettings(db.Model): force_confirmation_before_delete = Column(Boolean) -class Comments(db.Model): - __tablename__ = "comments" - - comment_id = Column(BigInteger, primary_key=True) - comment_uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, server_default=text("gen_random_uuid()"), - nullable=False) - comment_text = Column(Text) - comment_date = Column(DateTime) - comment_update_date = Column(DateTime) - comment_user_id = Column(ForeignKey('user.id')) - comment_case_id = Column(ForeignKey('cases.case_id')) - comment_alert_id = Column(ForeignKey('alerts.alert_id')) - - user = relationship('User') - case = relationship('Cases') - alert = relationship('Alert') - - -class EventComments(db.Model): - __tablename__ = "event_comments" - - id = Column(BigInteger, primary_key=True) - comment_id = Column(ForeignKey('comments.comment_id')) - comment_event_id = Column(ForeignKey('cases_events.event_id')) - - event = relationship('CasesEvent') - comment = relationship('Comments') - - -class TaskComments(db.Model): - __tablename__ = "task_comments" - - id = Column(BigInteger, primary_key=True) - comment_id = Column(ForeignKey('comments.comment_id')) - comment_task_id = Column(ForeignKey('case_tasks.id')) - - task = relationship('CaseTasks') - comment = relationship('Comments') - - -class IocComments(db.Model): - __tablename__ = "ioc_comments" - - id = Column(BigInteger, primary_key=True) - comment_id = Column(ForeignKey('comments.comment_id')) - comment_ioc_id = Column(ForeignKey('ioc.ioc_id')) - - ioc = relationship('Ioc') - comment = relationship('Comments') - - -class AssetComments(db.Model): - __tablename__ = "asset_comments" - - id = Column(BigInteger, primary_key=True) - comment_id = Column(ForeignKey('comments.comment_id')) - comment_asset_id = Column(ForeignKey('case_assets.asset_id')) - - asset = relationship('CaseAssets') - comment = relationship('Comments') - - -class EvidencesComments(db.Model): - __tablename__ = "evidence_comments" - - id = Column(BigInteger, primary_key=True) - comment_id = Column(ForeignKey('comments.comment_id')) - comment_evidence_id = Column(ForeignKey('case_received_file.id')) - - evidence = relationship('CaseReceivedFile') - comment = relationship('Comments') - - -class NotesComments(db.Model): - __tablename__ = "note_comments" - - id = Column(BigInteger, primary_key=True) - comment_id = Column(ForeignKey('comments.comment_id')) - comment_note_id = Column(ForeignKey('notes.note_id')) - - note = relationship('Notes') - comment = relationship('Comments') - - class IrisModule(db.Model): __tablename__ = "iris_module" diff --git a/source/app/schema/marshables.py b/source/app/schema/marshables.py index 3ccb2df43..4d207553b 100644 --- a/source/app/schema/marshables.py +++ b/source/app/schema/marshables.py @@ -69,7 +69,7 @@ from app.models.cases import Cases from app.models.cases import CasesEvent from app.models.models import Client -from app.models.models import Comments +from app.models.comments import Comments from app.models.models import Contact from app.models.models import DataStoreFile from app.models.models import EventCategory diff --git a/tests/iris.py b/tests/iris.py index 8df8d075c..710e41296 100644 --- a/tests/iris.py +++ b/tests/iris.py @@ -16,6 +16,7 @@ # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +from time import sleep from uuid import uuid4 from pathlib import Path from docker_compose import DockerCompose @@ -37,6 +38,8 @@ IRIS_PERMISSION_ALERTS_DELETE = 0x10 IRIS_PERMISSION_CUSTOMERS_WRITE = 0x80 +IRIS_CASE_ACCESS_LEVEL_READ_ONLY = 0x2 + class Iris: @@ -129,3 +132,40 @@ def clear_database(self): def extract_logs(self, service): return self._docker_compose.extract_logs(service) + + def get_module_identifier_by_name(self, module_name): + response = self.get('/manage/modules/list').json() + module_identifier = None + for module in response['data']: + if module['module_human_name'] == module_name: + module_identifier = module['id'] + return module_identifier + + @staticmethod + def get_most_recent_object_history_entry(response): + modification_history = response['modification_history'] + current_timestamp = 0 + result = None + for timestamp_as_string, modification in modification_history.items(): + timestamp = float(timestamp_as_string) + if timestamp < current_timestamp: + continue + result = modification + current_timestamp = timestamp + return result + + def wait_for_module_task(self): + response = self.get('/dim/tasks/list/1').json() + attempts = 0 + while len(response['data']) == 0: + sleep(1) + response = self.get('/dim/tasks/list/1').json() + attempts += 1 + if attempts > 20: + logs = self.extract_logs('worker') + raise TimeoutError(f'Timed out with logs: {logs}') + return response['data'][0] + + def get_latest_activity(self): + activities = self.get('/activities/list-all').json() + return activities['data'][0] diff --git a/tests/tests_rest_assets.py b/tests/tests_rest_assets.py index 0af24affa..ff736e003 100644 --- a/tests/tests_rest_assets.py +++ b/tests/tests_rest_assets.py @@ -24,18 +24,6 @@ _CASE_ACCESS_LEVEL_FULL_ACCESS = 4 -def _get_most_recent_modification(modification_history): - current_timestamp = 0 - result = None - for timestamp_as_string, modification in modification_history.items(): - timestamp = float(timestamp_as_string) - if timestamp < current_timestamp: - continue - result = modification - current_timestamp = timestamp - return result - - class TestsRestAssets(TestCase): def setUp(self) -> None: @@ -89,7 +77,7 @@ def test_create_asset_should_update_modification_history(self): case_identifier = self._subject.create_dummy_case() body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() - modification = _get_most_recent_modification(response['modification_history']) + modification = self._subject.get_most_recent_object_history_entry(response) self.assertEqual('created', modification['action']) def test_get_asset_should_return_200(self): @@ -159,7 +147,7 @@ def test_update_asset_should_update_modification_history(self): identifier = response['asset_id'] body = {'asset_type_id': 1, 'asset_name': 'new_asset_name'} response = self._subject.update(f'/api/v2/cases/{case_identifier}/assets/{identifier}', body).json() - modification = _get_most_recent_modification(response['modification_history']) + modification = self._subject.get_most_recent_object_history_entry(response) self.assertEqual('updated', modification['action']) def test_delete_asset_should_return_204(self): diff --git a/tests/tests_rest_comments.py b/tests/tests_rest_comments.py index 2250c01b6..d0ce2155a 100644 --- a/tests/tests_rest_comments.py +++ b/tests/tests_rest_comments.py @@ -18,7 +18,9 @@ from unittest import TestCase from iris import Iris +from iris import ADMINISTRATOR_USER_IDENTIFIER from iris import IRIS_PERMISSION_ALERTS_READ +from iris import IRIS_CASE_ACCESS_LEVEL_READ_ONLY _IDENTIFIER_FOR_NONEXISTENT_OBJECT = 123456789 @@ -112,7 +114,7 @@ def test_get_assets_comments_should_return_200(self): response = self._subject.get(f'/api/v2/assets/{object_identifier}/comments') self.assertEqual(200, response.status_code) - def test_get_assets_comments_should_return_403_when_user_has_no_permission_to_access_case(self): + def test_get_assets_comments_should_return_404_when_user_has_no_permission_to_access_case(self): case_identifier = self._subject.create_dummy_case() body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() @@ -120,7 +122,7 @@ def test_get_assets_comments_should_return_403_when_user_has_no_permission_to_ac user = self._subject.create_dummy_user() response = user.get(f'/api/v2/assets/{object_identifier}/comments') - self.assertEqual(403, response.status_code) + self.assertEqual(404, response.status_code) def test_get_assets_comments_should_return_404_when_asset_is_not_found(self): response = self._subject.get(f'/api/v2/assets/{_IDENTIFIER_FOR_NONEXISTENT_OBJECT}/comments') @@ -137,8 +139,8 @@ def test_get_evidences_comments_should_return_200(self): def test_get_iocs_comments_should_return_200(self): case_identifier = self._subject.create_dummy_case() body = {'ioc_type_id': 1, 'ioc_tlp_id': 2, 'ioc_value': '8.8.8.8', 'ioc_description': 'rewrw', 'ioc_tags': ''} - test = self._subject.create(f'/api/v2/cases/{case_identifier}/iocs', body).json() - object_identifier = test['ioc_id'] + response = self._subject.create(f'/api/v2/cases/{case_identifier}/iocs', body).json() + object_identifier = response['ioc_id'] response = self._subject.get(f'/api/v2/iocs/{object_identifier}/comments') self.assertEqual(200, response.status_code) @@ -172,3 +174,318 @@ def test_get_events_comments_should_return_200(self): response = self._subject.get(f'/api/v2/events/{object_identifier}/comments') self.assertEqual(200, response.status_code) + + def test_create_alerts_comment_should_return_201(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + response = self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_create_alerts_comment_should_set_comment_text(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + body = { + 'comment_text': 'comment text' + } + response = self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', body).json() + self.assertEqual('comment text', response['comment_text']) + + def test_create_alerts_comment_should_set_comment_alert_id(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + response = self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', {}).json() + self.assertEqual(object_identifier, response['comment_alert_id']) + + def test_create_alerts_comment_should_set_comment_user_id(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + response = self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', {}).json() + self.assertEqual(ADMINISTRATOR_USER_IDENTIFIER, response['comment_user_id']) + + def test_create_alerts_comment_should_add_history_entry_on_alert(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', {}).json() + response = self._subject.get(f'/api/v2/alerts/{object_identifier}', body).json() + history_entry = self._subject.get_most_recent_object_history_entry(response) + self.assertEqual('commented', history_entry['action']) + + def test_create_alerts_comment_should_call_module_hook(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + + module_identifier = self._subject.get_module_identifier_by_name('IrisCheck') + self._subject.create(f'/manage/modules/enable/{module_identifier}', {}) + self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', {}) + self._subject.create(f'/manage/modules/disable/{module_identifier}', {}) + task = self._subject.wait_for_module_task() + self.assertEqual('iris_check_module::on_postload_alert_commented', task['module']) + + def test_create_alerts_comment_should_be_tracked(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + + self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', {}) + activity = self._subject.get_latest_activity() + self.assertEqual(f'Alert "{object_identifier}" commented', activity['activity_desc']) + + def test_create_alerts_comment_should_return_404_when_alert_is_not_found(self): + response = self._subject.create(f'/api/v2/alerts/{_IDENTIFIER_FOR_NONEXISTENT_OBJECT}/comments', {}) + self.assertEqual(404, response.status_code) + + def test_create_alerts_comment_should_return_400_when_comment_text_is_not_a_string(self): + body = { + 'alert_title': 'title', + 'alert_severity_id': 4, + 'alert_status_id': 3, + 'alert_customer_id': 1, + } + response = self._subject.create('/api/v2/alerts', body).json() + object_identifier = response['alert_id'] + body = { + 'comment_text': 1 + } + response = self._subject.create(f'/api/v2/alerts/{object_identifier}/comments', body) + self.assertEqual(400, response.status_code) + + def test_create_assets_comments_should_return_201(self): + case_identifier = self._subject.create_dummy_case() + body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() + object_identifier = response['asset_id'] + + response = self._subject.create(f'/api/v2/assets/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_create_assets_comment_should_set_comment_text(self): + case_identifier = self._subject.create_dummy_case() + body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() + object_identifier = response['asset_id'] + + body = { + 'comment_text': 'comment text' + } + response = self._subject.create(f'/api/v2/assets/{object_identifier}/comments', body).json() + self.assertEqual('comment text', response['comment_text']) + + def test_create_assets_comment_should_return_404_when_asset_is_not_found(self): + response = self._subject.create(f'/api/v2/assets/{_IDENTIFIER_FOR_NONEXISTENT_OBJECT}/comments', {}) + self.assertEqual(404, response.status_code) + + def test_create_assets_comment_should_return_404_when_user_has_only_read_only_case_access(self): + case_identifier = self._subject.create_dummy_case() + body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() + object_identifier = response['asset_id'] + + user = self._subject.create_dummy_user() + user_identifier = user.get_identifier() + body = { + 'access_level': IRIS_CASE_ACCESS_LEVEL_READ_ONLY, + 'cases_list': [case_identifier] + } + self._subject.create(f'/manage/users/{user_identifier}/cases-access/update', body) + response = user.create(f'/api/v2/assets/{object_identifier}/comments', body) + self.assertEqual(404, response.status_code) + + def test_create_assets_comment_should_return_400_when_comment_text_is_not_a_string(self): + case_identifier = self._subject.create_dummy_case() + body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() + object_identifier = response['asset_id'] + body = { + 'comment_text': 1 + } + response = self._subject.create(f'/api/v2/assets/{object_identifier}/comments', body) + self.assertEqual(400, response.status_code) + + def test_create_evidences_comment_should_return_201(self): + case_identifier = self._subject.create_dummy_case() + body = {'filename': 'filename'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/evidences', body).json() + object_identifier = response['id'] + + response = self._subject.create(f'/api/v2/evidences/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_create_evidences_comment_should_return_404_when_user_has_only_read_only_case_access(self): + case_identifier = self._subject.create_dummy_case() + body = {'filename': 'filename'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/evidences', body).json() + object_identifier = response['id'] + + user = self._subject.create_dummy_user() + user_identifier = user.get_identifier() + body = { + 'access_level': IRIS_CASE_ACCESS_LEVEL_READ_ONLY, + 'cases_list': [case_identifier] + } + self._subject.create(f'/manage/users/{user_identifier}/cases-access/update', body) + response = user.create(f'/api/v2/evidences/{object_identifier}/comments', body) + self.assertEqual(404, response.status_code) + + def test_create_iocs_comment_should_return_201(self): + case_identifier = self._subject.create_dummy_case() + body = {'ioc_type_id': 1, 'ioc_tlp_id': 2, 'ioc_value': '8.8.8.8', 'ioc_description': 'rewrw', 'ioc_tags': ''} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/iocs', body).json() + object_identifier = response['ioc_id'] + + response = self._subject.create(f'/api/v2/iocs/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_create_notes_comment_should_return_201(self): + case_identifier = self._subject.create_dummy_case() + response = self._subject.create(f'/api/v2/cases/{case_identifier}/notes-directories', + {'name': 'directory_name'}).json() + directory_identifier = response['id'] + body = {'directory_id': directory_identifier} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/notes', body).json() + object_identifier = response['note_id'] + + response = self._subject.create(f'/api/v2/notes/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_delete_case_with_ioc_comment_should_return_204(self): + case_identifier = self._subject.create_dummy_case() + body = {'ioc_type_id': 1, 'ioc_tlp_id': 2, 'ioc_value': '8.8.8.8', 'ioc_description': 'rewrw', 'ioc_tags': ''} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/iocs', body).json() + object_identifier = response['ioc_id'] + + self._subject.create(f'/api/v2/iocs/{object_identifier}/comments', {}) + response = self._subject.delete(f'/api/v2/cases/{case_identifier}') + self.assertEqual(204, response.status_code) + + def test_delete_case_with_ioc_comment_should_not_delete_comments_in_another_case(self): + case_identifier = self._subject.create_dummy_case() + body = {'ioc_type_id': 1, 'ioc_tlp_id': 2, 'ioc_value': '8.8.8.8', 'ioc_description': 'rewrw', 'ioc_tags': ''} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/iocs', body).json() + object_identifier = response['ioc_id'] + self._subject.create(f'/api/v2/iocs/{object_identifier}/comments', {}) + + case_identifier2 = self._subject.create_dummy_case() + body = {'ioc_type_id': 1, 'ioc_tlp_id': 2, 'ioc_value': '8.8.8.8', 'ioc_description': 'rewrw', 'ioc_tags': ''} + response = self._subject.create(f'/api/v2/cases/{case_identifier2}/iocs', body).json() + object_identifier2 = response['ioc_id'] + self._subject.create(f'/api/v2/iocs/{object_identifier2}/comments', {}) + self._subject.delete(f'/api/v2/cases/{case_identifier2}') + + response = self._subject.get(f'/api/v2/iocs/{object_identifier}/comments').json() + self.assertEqual(1, response['total']) + + def test_delete_case_with_asset_comment_should_return_204(self): + case_identifier = self._subject.create_dummy_case() + body = {'asset_type_id': 1, 'asset_name': 'admin_laptop_test'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/assets', body).json() + object_identifier = response['asset_id'] + + self._subject.create(f'/api/v2/assets/{object_identifier}/comments', {}) + response = self._subject.delete(f'/api/v2/cases/{case_identifier}') + self.assertEqual(204, response.status_code) + + def test_delete_case_with_evidences_comment_should_return_204(self): + case_identifier = self._subject.create_dummy_case() + body = {'filename': 'filename'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/evidences', body).json() + object_identifier = response['id'] + + self._subject.create(f'/api/v2/evidences/{object_identifier}/comments', {}) + response = self._subject.delete(f'/api/v2/cases/{case_identifier}') + self.assertEqual(204, response.status_code) + + def test_delete_case_with_notes_comment_should_return_204(self): + case_identifier = self._subject.create_dummy_case() + response = self._subject.create(f'/api/v2/cases/{case_identifier}/notes-directories', + {'name': 'directory_name'}).json() + directory_identifier = response['id'] + body = {'directory_id': directory_identifier} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/notes', body).json() + object_identifier = response['note_id'] + + self._subject.create(f'/api/v2/notes/{object_identifier}/comments', {}) + response = self._subject.delete(f'/api/v2/cases/{case_identifier}') + self.assertEqual(204, response.status_code) + + def test_create_tasks_comment_should_return_201(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'dummy title'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + object_identifier = response['id'] + + response = self._subject.create(f'/api/v2/tasks/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_delete_case_with_tasks_comment_should_return_204(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'dummy title'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + object_identifier = response['id'] + + self._subject.create(f'/api/v2/tasks/{object_identifier}/comments', {}) + response = self._subject.delete(f'/api/v2/cases/{case_identifier}') + self.assertEqual(204, response.status_code) + + def test_create_events_comment_should_return_201(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + object_identifier = response['event_id'] + + response = self._subject.create(f'/api/v2/events/{object_identifier}/comments', {}) + self.assertEqual(201, response.status_code) + + def test_delete_case_with_events_comment_should_return_204(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + object_identifier = response['event_id'] + + self._subject.create(f'/api/v2/events/{object_identifier}/comments', {}) + response = self._subject.delete(f'/api/v2/cases/{case_identifier}') + self.assertEqual(204, response.status_code) diff --git a/tests/tests_rest_miscellaneous.py b/tests/tests_rest_miscellaneous.py index db7696ff9..389c3ebb2 100644 --- a/tests/tests_rest_miscellaneous.py +++ b/tests/tests_rest_miscellaneous.py @@ -19,8 +19,6 @@ from unittest import TestCase from iris import Iris -from time import sleep - class TestsRestMiscellaneous(TestCase): @@ -72,25 +70,11 @@ def test_get_timeline_state_should_return_200(self): # since, by then, the case has already been removed from database, on the identifier and the fields with a server_default are filled # in particulier, client_id is None, and the code fails during the commit def test_delete_case_should_set_module_state_to_success(self): - response = self._subject.get('/manage/modules/list').json() - module_identifier = None - for module in response['data']: - if module['module_human_name'] == 'IrisCheck': - module_identifier = module['id'] + module_identifier = self._subject.get_module_identifier_by_name('IrisCheck') self._subject.create(f'/manage/modules/enable/{module_identifier}', {}) case_identifier = self._subject.create_dummy_case() self._subject.delete(f'/api/v2/cases/{case_identifier}') self._subject.create(f'/manage/modules/disable/{module_identifier}', {}) - response = self._subject.get('/dim/tasks/list/1').json() - attempts = 0 - while len(response['data']) == 0: - sleep(1) - response = self._subject.get('/dim/tasks/list/1').json() - attempts += 1 - if attempts > 20: - logs = self._subject.extract_logs('worker') - self.fail(f'Timed out with logs: {logs}') - module = response['data'][0] - - self.assertEqual('success', module['state']) + task = self._subject.wait_for_module_task() + self.assertEqual('success', task['state']) diff --git a/tests/tests_rest_notes_directories.py b/tests/tests_rest_notes_directories.py index 77563920f..0547060dc 100644 --- a/tests/tests_rest_notes_directories.py +++ b/tests/tests_rest_notes_directories.py @@ -157,9 +157,8 @@ def test_update_note_directory_should_add_an_activity(self): body = {'name': 'new name'} self._subject.update(f'/api/v2/cases/{case_identifier}/notes-directories/{identifier}', body) - activities = self._subject.get('/case/activities/list', {'cid': case_identifier}).json() - last_activity = activities['data'][0]['activity_desc'] - self.assertEqual('Modified directory "new name"', last_activity) + last_activity = self._subject.get_latest_activity() + self.assertEqual('Modified directory "new name"', last_activity['activity_desc']) def test_update_note_directory_should_return_400_when_field_name_is_not_a_string(self): case_identifier = self._subject.create_dummy_case()