diff --git a/h/storage.py b/h/storage.py index db1ca4ccbb3..d3b49b9f982 100644 --- a/h/storage.py +++ b/h/storage.py @@ -20,9 +20,9 @@ from pyramid import i18n from h import models, schemas -from h.db import types from h.models.document import update_document_metadata from h.security import Permission +from h.services.annotation_read import AnnotationReadService from h.traversal.group import GroupContext from h.util.group_scope import url_in_scope from h.util.uri import normalize as normalize_uri @@ -30,25 +30,6 @@ _ = i18n.TranslationStringFactory(__package__) -def fetch_annotation(session, id_): - """ - Fetch the annotation with the given id. - - :param session: the database session - :type session: sqlalchemy.orm.session.Session - - :param id_: the annotation ID - :type id_: str - - :returns: the annotation, if found, or None. - :rtype: h.models.Annotation, NoneType - """ - try: - return session.query(models.Annotation).get(id_) - except types.InvalidUUID: - return None - - def fetch_ordered_annotations(session, ids, query_processor=None): """ Fetch all annotations with the given ids and order them based on the list of ids. @@ -96,11 +77,13 @@ def create_annotation(request, data): """ document_data = data.pop("document", {}) + annotation_read: AnnotationReadService = request.find_service(AnnotationReadService) + # Replies must have the same group as their parent. if data["references"]: root_annotation_id = data["references"][0] - if root_annotation := fetch_annotation(request.db, root_annotation_id): + if root_annotation := annotation_read.get_annotation_by_id(root_annotation_id): data["groupid"] = root_annotation.groupid else: raise schemas.ValidationError( diff --git a/tests/h/storage_test.py b/tests/h/storage_test.py index 115b2e2b478..f1dacb3e9ea 100644 --- a/tests/h/storage_test.py +++ b/tests/h/storage_test.py @@ -16,17 +16,6 @@ pytestmark = pytest.mark.usefixtures("search_index") -class TestFetchAnnotation: - def test_it_fetches_and_returns_the_annotation(self, db_session, factories): - annotation = factories.Annotation() - - actual = storage.fetch_annotation(db_session, annotation.id) - assert annotation == actual - - def test_it_does_not_crash_if_id_is_invalid(self, db_session): - assert storage.fetch_annotation(db_session, "foo") is None - - class TestFetchOrderedAnnotations: def test_it_returns_annotations_for_ids_in_the_same_order( self, db_session, factories @@ -130,6 +119,7 @@ def test_expand_uri_document_uris(self, db_session, normalized, expected_uris): assert uris == expected_uris +@pytest.mark.usefixtures("annotation_read_service") class TestCreateAnnotation: def test_it(self, pyramid_request, annotation_data, datetime): annotation = storage.create_annotation(pyramid_request, annotation_data) @@ -181,10 +171,16 @@ def test_it_queues_the_search_index( ) def test_it_sets_the_group_to_match_the_parent_for_replies( - self, pyramid_request, annotation_data, factories, other_group + self, + pyramid_request, + annotation_data, + factories, + other_group, + annotation_read_service, ): parent_annotation = factories.Annotation(group=other_group) annotation_data["references"] = [parent_annotation.id] + annotation_read_service.get_annotation_by_id.return_value = parent_annotation annotation = storage.create_annotation(pyramid_request, annotation_data) @@ -192,8 +188,10 @@ def test_it_sets_the_group_to_match_the_parent_for_replies( assert annotation.group == parent_annotation.group def test_it_raises_if_parent_annotation_does_not_exist( - self, pyramid_request, annotation_data + self, pyramid_request, annotation_data, annotation_read_service ): + annotation_read_service.get_annotation_by_id.return_value = None + annotation_data["references"] = ["MISSING_ID"] with pytest.raises(ValidationError):