Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes gather_with_concurrency typing #714

Merged
merged 16 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
18 changes: 11 additions & 7 deletions paperqa/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy
import logging
from collections.abc import Collection, Coroutine, Sequence
from collections.abc import Awaitable, Collection, Coroutine, Sequence
from typing import Any, cast

import aiohttp
Expand Down Expand Up @@ -161,19 +161,22 @@ async def query(self, **kwargs) -> DocDetails | None:

# then process and re-aggregate the results
if doc_details and task.processors:
doc_details = sum(
await gather_with_concurrency(
len(task.processors),
task.processor_queries(doc_details, session),
doc_details = (
sum(
await gather_with_concurrency(
len(task.processors),
task.processor_queries(doc_details, session),
)
)
or None
)

if doc_details:

# abuse int handling in __add__ for empty all_doc_details, None types won't work
all_doc_details = doc_details + (all_doc_details or 0)

if not cast(DocDetails, all_doc_details).is_hydration_needed(
if not all_doc_details.is_hydration_needed(
inclusion=kwargs.get("fields", [])
):
logger.debug(
Expand All @@ -191,7 +194,8 @@ async def bulk_query(
self, queries: Collection[dict[str, Any]], concurrency: int = 10
) -> list[DocDetails]:
return await gather_with_concurrency(
concurrency, [self.query(**kwargs) for kwargs in queries]
concurrency,
[cast(Awaitable[DocDetails], self.query(**kwargs)) for kwargs in queries],
)

async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ async def aget_evidence(
for _, llm_result in results:
session.add_tokens(llm_result)

session.contexts += [r for r, _ in results if r is not None]
session.contexts += [r for r, _ in results]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is being discussed in #687. However, the typing change didn't fix the mypy error saying that `r is not None` is an always-true statement

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can do this — I believe you can still get a `None` back.

return session

def query(
Expand Down
2 changes: 1 addition & 1 deletion paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class LLMResult(BaseModel):

id: UUID = Field(default_factory=uuid4)
session_id: UUID | None = Field(
default_factory=cvar_session_id.get,
maykcaldas marked this conversation as resolved.
Show resolved Hide resolved
default_factory=cvar_session_id.get, # type: ignore[arg-type]
description="A persistent ID to associate a group of LLMResults",
alias="answer_id",
)
Expand Down
8 changes: 5 additions & 3 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import re
import string
import unicodedata
from collections.abc import Collection, Coroutine, Iterable, Iterator
from collections.abc import Awaitable, Collection, Iterable, Iterator
from datetime import datetime
from functools import reduce
from http import HTTPStatus
from pathlib import Path
from typing import Any, BinaryIO, ClassVar
from typing import Any, BinaryIO, ClassVar, TypeVar
from uuid import UUID

import aiohttp
Expand All @@ -36,6 +36,8 @@

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ImpossibleParsingError(Exception):
"""Error to throw when a parsing is impossible."""
Expand Down Expand Up @@ -103,7 +105,7 @@ def md5sum(file_path: str | os.PathLike) -> str:
return hexdigest(f.read())


async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]:
async def gather_with_concurrency(n: int, coros: Iterable[Awaitable[T]]) -> list[T]:
# https://stackoverflow.com/a/61478547/2392535
semaphore = asyncio.Semaphore(n)

Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ ignore = [
"S311", # Ok to use python random
"SLF001", # Overly pedantic
"T201", # Overly pedantic
"TCH001", # TCH001, TCH002, TCH003: don't care to enforce type checking blocks
"TCH002",
"TCH003",
"TC001", # TCH001, TCH002, TCH003: don't care to enforce type checking blocks
"TC002",
"TC003",
"TD002", # Don't care for TODO author
"TD003", # Don't care for TODO links
"TRY003", # Overly pedantic
Expand All @@ -391,7 +391,7 @@ unfixable = [
"ERA001", # While debugging, temporarily commenting code can be useful
"F401", # While debugging, unused imports can be useful
"F841", # While debugging, unused locals can be useful
"TCH004", # While debugging, it can be nice to keep TYPE_CHECKING in tact
"TC004", # While debugging, it can be nice to keep TYPE_CHECKING in tact
]

[tool.ruff.lint.flake8-annotations]
Expand Down
Loading