Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes gather_with_concurrency typing #714

Merged
merged 16 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ repos:
- pybtex
- numpy
- pandas-stubs
- pydantic~=2.0 # Match pyproject.toml
maykcaldas marked this conversation as resolved.
Show resolved Hide resolved
- pydantic~=2.0
- pydantic-settings
- rich
- tantivy
Expand Down
18 changes: 11 additions & 7 deletions paperqa/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy
import logging
from collections.abc import Collection, Coroutine, Sequence
from collections.abc import Awaitable, Collection, Coroutine, Sequence
from typing import Any, cast

import aiohttp
Expand Down Expand Up @@ -161,19 +161,22 @@ async def query(self, **kwargs) -> DocDetails | None:

# then process and re-aggregate the results
if doc_details and task.processors:
doc_details = sum(
await gather_with_concurrency(
len(task.processors),
task.processor_queries(doc_details, session),
doc_details = (
sum(
await gather_with_concurrency(
len(task.processors),
task.processor_queries(doc_details, session),
)
)
or None
)

if doc_details:

# abuse int handling in __add__ for empty all_doc_details, None types won't work
all_doc_details = doc_details + (all_doc_details or 0)

if not cast(DocDetails, all_doc_details).is_hydration_needed(
if not all_doc_details.is_hydration_needed(
inclusion=kwargs.get("fields", [])
):
logger.debug(
Expand All @@ -191,7 +194,8 @@ async def bulk_query(
self, queries: Collection[dict[str, Any]], concurrency: int = 10
) -> list[DocDetails]:
return await gather_with_concurrency(
concurrency, [self.query(**kwargs) for kwargs in queries]
concurrency,
[cast(Awaitable[DocDetails], self.query(**kwargs)) for kwargs in queries],
)

async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ async def aget_evidence(
for _, llm_result in results:
session.add_tokens(llm_result)

session.contexts += [r for r, _ in results if r is not None]
session.contexts += [r for r, _ in results]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is being discussed in #687. However, the typing change didn't fix the mypy error saying that r is not None is a always True statement

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot do this I think. You can get a None back I believe

return session

def query(
Expand Down
6 changes: 5 additions & 1 deletion paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
cvar_session_id = contextvars.ContextVar[UUID | None]("session_id", default=None)


def get_session_id() -> UUID | None:
return cvar_session_id.get()


@contextmanager
def set_llm_session_ids(session_id: UUID):
token = cvar_session_id.set(session_id)
Expand Down Expand Up @@ -73,7 +77,7 @@ class LLMResult(BaseModel):

id: UUID = Field(default_factory=uuid4)
session_id: UUID | None = Field(
default_factory=cvar_session_id.get,
maykcaldas marked this conversation as resolved.
Show resolved Hide resolved
default_factory=get_session_id,
description="A persistent ID to associate a group of LLMResults",
alias="answer_id",
)
Expand Down
8 changes: 5 additions & 3 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import re
import string
import unicodedata
from collections.abc import Collection, Coroutine, Iterable, Iterator
from collections.abc import Awaitable, Collection, Iterable, Iterator
from datetime import datetime
from functools import reduce
from http import HTTPStatus
from pathlib import Path
from typing import Any, BinaryIO, ClassVar
from typing import Any, BinaryIO, ClassVar, TypeVar
from uuid import UUID

import aiohttp
Expand All @@ -36,6 +36,8 @@

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ImpossibleParsingError(Exception):
"""Error to throw when a parsing is impossible."""
Expand Down Expand Up @@ -103,7 +105,7 @@ def md5sum(file_path: str | os.PathLike) -> str:
return hexdigest(f.read())


async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]:
async def gather_with_concurrency(n: int, coros: Iterable[Awaitable[T]]) -> list[T]:
# https://stackoverflow.com/a/61478547/2392535
semaphore = asyncio.Semaphore(n)

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.