Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move fetch_ordered_annotations into a new annotation service #7927

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions h/activity/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import newrelic.agent
from pyramid.httpexceptions import HTTPFound
from sqlalchemy.orm import subqueryload

from h import links, presenters, storage
from h import links, presenters
from h.activity import bucketing
from h.models import Annotation, Group
from h.search import (
Expand All @@ -15,6 +14,7 @@
UsersAggregation,
parser,
)
from h.services import AnnotationService


class ActivityResults(
Expand Down Expand Up @@ -115,7 +115,7 @@ def execute(request, query, page_size):

# Load all referenced annotations from the database, bucket them, and add
# the buckets to result.timeframes.
anns = fetch_annotations(request.db, search_result.annotation_ids)
anns = _fetch_annotations(request, search_result.annotation_ids)
result.timeframes.extend(bucketing.bucket(anns))

# Fetch all groups
Expand Down Expand Up @@ -155,16 +155,11 @@ def aggregations_for(query):


@newrelic.agent.function_trace()
def fetch_annotations(session, ids):
def load_documents(query):
return query.options(subqueryload(Annotation.document))

annotations = storage.fetch_ordered_annotations(
session, ids, query_processor=load_documents
def _fetch_annotations(request, ids):
return request.find_service(AnnotationService).get_annotations_by_id(
ids=ids, eager_load=[Annotation.document]
)

return annotations


@newrelic.agent.function_trace()
def _execute_search(request, query, page_size):
Expand Down
4 changes: 4 additions & 0 deletions h/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Service definitions that handle business logic."""
from h.services.annotation import AnnotationService
from h.services.auth_cookie import AuthCookieService
from h.services.bulk_annotation import BulkAnnotationService
from h.services.subscription import SubscriptionService


def includeme(config): # pragma: no cover
config.register_service_factory(
"h.services.annotation.service_factory", iface=AnnotationService
)
config.register_service_factory(".annotation_json.factory", name="annotation_json")
config.register_service_factory(
".annotation_moderation.annotation_moderation_service_factory",
Expand Down
40 changes: 40 additions & 0 deletions h/services/annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Iterable, List, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session, subqueryload

from h.models import Annotation


class AnnotationService:
"""A service for storing and retrieving annotations."""

def __init__(self, db_session: Session):
self._db = db_session

def get_annotations_by_id(
self, ids: List[str], eager_load: Optional[List] = None
) -> Iterable[Annotation]:
"""
Get annotations in the same order as the provided ids.

:param ids: the list of annotation ids
:param eager_load: A list of annotation relationships to eager load
like `Annotation.document`
"""

if not ids:
return []

query = select(Annotation).where(Annotation.id.in_(ids))
if eager_load:
query = query.options(subqueryload(*eager_load))

annotations = self._db.execute(query).scalars()
return sorted(annotations, key=lambda annotation: ids.index(annotation.id))


def service_factory(_context, request) -> AnnotationService:
"""Get an annotation service instance."""

return AnnotationService(db_session=request.db)
56 changes: 27 additions & 29 deletions h/services/annotation_json.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from copy import deepcopy

from sqlalchemy.orm import subqueryload

from h import storage
from h.models import Annotation, User
from h.security import Identity, identity_permits
from h.security.permissions import Permission
from h.services.annotation import AnnotationService
from h.services.flag import FlagService
from h.services.links import LinksService
from h.services.user import UserService
from h.session import user_info
from h.traversal import AnnotationContext
from h.util.datetime import utc_iso8601
Expand All @@ -14,16 +15,23 @@
class AnnotationJSONService:
"""A service for generating API compatible JSON for annotations."""

def __init__(self, session, links_service, flag_service, user_service):
# pylint: disable=too-many-arguments
def __init__(
self,
annotation_service: AnnotationService,
links_service: LinksService,
flag_service: FlagService,
user_service: UserService,
):
"""
Instantiate the service.

:param session: DB session
:param annotation_service: AnnotationService instance
:param links_service: LinksService instance
:param flag_service: FlagService instance
:param user_service: UserService instance
"""
self._session = session
self._annotation_service = annotation_service
self._links_service = links_service
self._flag_service = flag_service
self._user_service = user_service
Expand Down Expand Up @@ -136,34 +144,25 @@ def present_all_for_user(self, annotation_ids, user: User):
self._flag_service.all_flagged(user, annotation_ids)
self._flag_service.flag_counts(annotation_ids)

annotations = storage.fetch_ordered_annotations(
self._session,
annotation_ids,
query_processor=self._eager_load_related_items,
annotations = self._annotation_service.get_annotations_by_id(
ids=annotation_ids,
eager_load=[
# Optimise access to the document
Annotation.document,
# Optimise the check used for "hidden" above
Annotation.moderation,
# Optimise the permissions check for MODERATE permissions,
# which ultimately depends on group permissions, causing a
# group lookup for every annotation without this
Annotation.group,
],
)

# Optimise the user service `fetch()` call
self._user_service.fetch_all([annotation.userid for annotation in annotations])

return [self.present_for_user(annotation, user) for annotation in annotations]

@staticmethod
def _eager_load_related_items(query):
# Ensure that accessing `annotation.document` or `.moderation`
# doesn't trigger any more queries by pre-loading these

return query.options(
# Optimise access to the document which is called in
# `AnnotationJSONPresenter`
subqueryload(Annotation.document),
# Optimise the check used for "hidden" above
subqueryload(Annotation.moderation),
# Optimise the permissions check for MODERATE permissions,
# which ultimately depends on group permissions, causing a
# group lookup for every annotation without this
subqueryload(Annotation.group),
)

@classmethod
def _get_read_permission(cls, annotation):
if not annotation.shared:
Expand All @@ -185,8 +184,7 @@ def _get_read_permission(cls, annotation):

def factory(_context, request):
return AnnotationJSONService(
session=request.db,
# Services
annotation_service=request.find_service(AnnotationService),
links_service=request.find_service(name="links"),
flag_service=request.find_service(name="flag"),
user_service=request.find_service(name="user"),
Expand Down
35 changes: 0 additions & 35 deletions h/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,41 +49,6 @@ def fetch_annotation(session, id_):
return None


def fetch_ordered_annotations(session, ids, query_processor=None):
"""
Fetch all annotations with the given ids and order them based on the list of ids.

The optional `query_processor` parameter allows for passing in a function
that can change the query before it is run, especially useful for
eager-loading certain data. The function will get the query as an argument
and has to return a query object again.

:param session: the database session
:type session: sqlalchemy.orm.session.Session

:param ids: the list of annotation ids
:type ids: list

:param query_processor: an optional function that takes the query and
returns an updated query
:type query_processor: callable

:returns: the annotation, if found, or None.
:rtype: h.models.Annotation, NoneType
"""
if not ids:
return []

ordering = {x: i for i, x in enumerate(ids)}

query = session.query(models.Annotation).filter(models.Annotation.id.in_(ids))
if query_processor:
query = query_processor(query)

anns = sorted(query, key=lambda a: ordering.get(a.id))
return anns


def create_annotation(request, data):
"""
Create an annotation from already-validated data.
Expand Down
19 changes: 11 additions & 8 deletions h/views/feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,13 @@
from pyramid.view import view_config
from webob.multidict import MultiDict

from h import search
from h.feeds import render_atom, render_rss
from h.storage import fetch_ordered_annotations
from h.search import Search
from h.services import AnnotationService

_ = i18n.TranslationStringFactory(__package__)


def _annotations(request):
"""Return the annotations from the search API."""
result = search.Search(request).run(MultiDict(request.params))
return fetch_ordered_annotations(request.db, result.annotation_ids)


@view_config(route_name="stream_atom")
def stream_atom(request):
"""Get an Atom feed of the /stream page."""
Expand All @@ -40,3 +34,12 @@ def stream_rss(request):
description=request.registry.settings.get("h.feed.description")
or _("The Web. Annotated"),
)


def _annotations(request):
"""Return the annotations from the search API."""
result = Search(request).run(MultiDict(request.params))

return request.find_service(AnnotationService).get_annotations_by_id(
ids=result.annotation_ids
)
8 changes: 7 additions & 1 deletion tests/common/fixtures/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from h.services import BulkAnnotationService
from h.services import AnnotationService, BulkAnnotationService
from h.services.annotation_delete import AnnotationDeleteService
from h.services.annotation_json import AnnotationJSONService
from h.services.annotation_moderation import AnnotationModerationService
Expand Down Expand Up @@ -34,6 +34,7 @@
"mock_service",
"annotation_delete_service",
"annotation_json_service",
"annotation_service",
"auth_cookie_service",
"auth_token_service",
"bulk_annotation_service",
Expand Down Expand Up @@ -88,6 +89,11 @@ def annotation_json_service(mock_service):
return mock_service(AnnotationJSONService, name="annotation_json")


@pytest.fixture
def annotation_service(mock_service):
return mock_service(AnnotationService)


@pytest.fixture
def auth_cookie_service(mock_service):
return mock_service(AuthCookieService)
Expand Down
31 changes: 10 additions & 21 deletions tests/h/activity/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pyramid.httpexceptions import HTTPFound
from webob.multidict import MultiDict

from h.activity.query import check_url, execute, extract, fetch_annotations
from h.activity.query import check_url, execute, extract
from h.models import Annotation


class TestExtract:
Expand Down Expand Up @@ -190,7 +191,7 @@ def unparse(self):


@pytest.mark.usefixtures(
"fetch_annotations",
"annotation_service",
"_fetch_groups",
"bucketing",
"presenters",
Expand Down Expand Up @@ -353,20 +354,22 @@ def test_it_returns_the_search_result_if_there_are_no_matches(
assert result.timeframes == []

def test_it_fetches_the_annotations_from_the_database(
self, fetch_annotations, pyramid_request, search
self, annotation_service, pyramid_request, search
):
execute(pyramid_request, MultiDict(), self.PAGE_SIZE)

fetch_annotations.assert_called_once_with(
pyramid_request.db, search.run.return_value.annotation_ids
annotation_service.get_annotations_by_id.assert_called_once_with(
ids=search.run.return_value.annotation_ids, eager_load=[Annotation.document]
)

def test_it_buckets_the_annotations(
self, fetch_annotations, bucketing, pyramid_request
self, annotation_service, bucketing, pyramid_request
):
result = execute(pyramid_request, MultiDict(), self.PAGE_SIZE)

bucketing.bucket.assert_called_once_with(fetch_annotations.return_value)
bucketing.bucket.assert_called_once_with(
annotation_service.get_annotations_by_id.return_value
)
assert result.timeframes == bucketing.bucket.return_value

def test_it_fetches_the_groups_from_the_database(
Expand Down Expand Up @@ -460,10 +463,6 @@ def test_it_returns_the_aggregations(self, pyramid_request):

assert result.aggregations == mock.sentinel.aggregations

@pytest.fixture
def fetch_annotations(self, patch):
return patch("h.activity.query.fetch_annotations")

@pytest.fixture
def _fetch_groups(self, group_pubids, patch):
_fetch_groups = patch("h.activity.query._fetch_groups")
Expand Down Expand Up @@ -607,16 +606,6 @@ def pyramid_request(self, pyramid_request):
return pyramid_request


class TestFetchAnnotations:
def test_it_returns_annotations_by_ids(self, db_session, factories):
annotations = factories.Annotation.create_batch(3)
ids = [a.id for a in annotations]

result = fetch_annotations(db_session, ids)

assert annotations == result


@pytest.fixture
def pyramid_request(pyramid_request):
class DummyRoute:
Expand Down
Loading