Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pulling in latest fhaviary, mypy, ruff #647

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions .github/renovate.json5
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
prHourlyLimit: 4,
timezone: "America/Los_Angeles",
rangeStrategy: "widen",
lockFileMaintenance: {
enabled: true,
},
"pre-commit": {
enabled: true,
},
lockFileMaintenance: { enabled: true },
"pre-commit": { enabled: true },
packageRules: [
{
// Allow 'widen' range strategy while matching aviary_internal pyproject.toml
matchPackageNames: ["openai"],
allowedVersions: "<1.47",
},
{
// TODO: remove after fhaviary supports Python 3.13
matchPackageNames: ["python"],
allowedVersions: "<=3.12",
},
],
}
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down Expand Up @@ -55,36 +55,37 @@ repos:
hooks:
- id: check-mailmap
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.20.2
rev: v0.22
hooks:
- id: validate-pyproject
additional_dependencies:
- "validate-pyproject-schema-store[all]>=2024.06.24" # Pin for Ruff's FURB154
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.4.21
rev: 0.4.27
hooks:
- id: uv-lock
- repo: https://github.com/renovatebot/pre-commit-hooks
rev: 38.122.0
rev: 38.131.1
hooks:
- id: renovate-config-validator
args: [--strict]
- repo: https://github.com/adamchainz/blacken-docs
rev: 1.19.0
rev: 1.19.1
hooks:
- id: blacken-docs
- repo: https://github.com/jsh9/markdown-toc-creator
rev: 0.0.8
hooks:
- id: markdown-toc-creator
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
args: [--pretty, --ignore-missing-imports]
additional_dependencies:
- aiohttp
- coredis
- fhaviary[llm]>=0.6 # Match pyproject.toml
- fhaviary[llm]>=0.8.2 # Match pyproject.toml
- ldp>=0.9 # Match pyproject.toml
- html2text
- httpx
Expand Down
11 changes: 8 additions & 3 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from copy import deepcopy
from typing import Any, Self, cast

from aviary.env import Environment, Frame
from aviary.message import Message
from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage
from aviary.core import (
Environment,
Frame,
Message,
Tool,
ToolRequestMessage,
ToolResponseMessage,
)

from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
Expand Down
10 changes: 6 additions & 4 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any

from aviary.message import MalformedMessageError, Message
from aviary.tools import (
from aviary.core import (
MalformedMessageError,
Message,
Tool,
ToolCall,
ToolRequestMessage,
Expand Down Expand Up @@ -40,7 +41,7 @@ class Callback: # type: ignore[no-redef]
from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch

if TYPE_CHECKING:
from aviary.env import Environment
from aviary.core import Environment
from ldp.agent import Agent, SimpleAgentState
from ldp.graph.ops import OpResult

Expand Down Expand Up @@ -234,7 +235,8 @@ async def run_aviary_agent(
while not done:
if max_timesteps is not None and timestep >= max_timesteps:
logger.warning(
f"Agent didn't finish within {max_timesteps} timesteps, just answering."
f"Agent didn't finish within {max_timesteps} timesteps, just"
" answering."
)
generate_answer_tool = next(
filter(
Expand Down
21 changes: 9 additions & 12 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ async def maybe_get_manifest(
}
if not file_loc_to_records:
raise ValueError( # noqa: TRY301
f"No mapping of file location to details extracted from manifest"
"No mapping of file location to details extracted from manifest"
f" file {filename}."
)
logger.debug(
Expand Down Expand Up @@ -593,11 +593,9 @@ async def get_directory_index( # noqa: PLR0912
index_settings = _settings.agent.index
if index_name:
warnings.warn(
(
f"The index_name argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6."
),
"The index_name argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand All @@ -620,11 +618,9 @@ async def get_directory_index( # noqa: PLR0912

if not sync_index_w_directory:
warnings.warn(
(
f"The sync_index_w_directory argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6."
),
"The sync_index_w_directory argument has been moved to"
f" {type(_settings.agent.index).__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -686,7 +682,8 @@ async def get_directory_index( # noqa: PLR0912
)
else:
logger.debug(
f"File {rel_file_path} found in paper directory {paper_directory}."
f"File {rel_file_path} found in paper directory"
f" {paper_directory}."
)

if search_index.changed:
Expand Down
12 changes: 9 additions & 3 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Self, assert_never

from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage
from aviary.core import (
TASK_DATASET_REGISTRY,
Frame,
Message,
TaskDataset,
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY

from paperqa.types import DocDetails

Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/semantic_scholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ async def s2_title_search(
# need to check if nested under a 'data' key or not (depends on filtering)
if (
strings_similarity(
data.get("title") if "data" not in data else data["data"][0]["title"],
data.get("title", "") if "data" not in data else data["data"][0]["title"],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

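@mskarlin FYI on this change: `data.get("title")` now falls back to `""` instead of `None` when the "title" key is missing.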

title,
)
< title_similarity_threshold
Expand Down
3 changes: 2 additions & 1 deletion paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def __init__(self, **kwargs):
from sentence_transformers import SentenceTransformer
except ImportError as exc:
raise ImportError(
"Please install paper-qa[local] to use SentenceTransformerEmbeddingModel."
"Please install paper-qa[local] to use"
" SentenceTransformerEmbeddingModel."
) from exc

self._model = SentenceTransformer(self.name)
Expand Down
6 changes: 2 additions & 4 deletions paperqa/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,6 @@ async def try_acquire(
raise_impossible_limits (:obj:`bool`, optional): flag will raise a
ValueError for weights that exceed the rate.

Returns:
None if the rate limit is satisfied.

Raises:
TimeoutError: if the acquire_timeout is exceeded.
ValueError: if the weight exceeds the rate limit and raise_impossible_limits is True.
Expand All @@ -352,7 +349,8 @@ async def try_acquire(

if rate_limit.amount < weight and raise_impossible_limits:
raise ValueError(
f"Weight ({weight}) > RateLimit ({rate_limit}), cannot satisfy rate limit."
f"Weight ({weight}) > RateLimit ({rate_limit}), cannot satisfy rate"
" limit."
)
while True:
elapsed = 0.0
Expand Down
12 changes: 9 additions & 3 deletions paperqa/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from math import ceil
from pathlib import Path
from typing import Literal, overload
from typing import Literal, cast, overload

import pymupdf
import tiktoken
Expand Down Expand Up @@ -172,7 +172,9 @@ def chunk_text(
f"ParsedText.content must be a `str`, not {type(parsed_text.content)}."
)

content = parsed_text.content if not use_tiktoken else parsed_text.encode_content()
content: str | list[int] = (
parsed_text.content if not use_tiktoken else parsed_text.encode_content()
)
if not content: # Avoid div0 in token calculations
raise ImpossibleParsingError(
f"No text was parsed from the document named {doc.docname!r} with ID"
Expand All @@ -195,7 +197,11 @@ def chunk_text(
]
texts.append(
Text(
text=enc.decode(split) if use_tiktoken else split,
text=(
enc.decode(cast(list[int], split))
if use_tiktoken
else cast(str, split)
),
name=f"{doc.docname} chunk {i + 1}",
doc=doc,
)
Expand Down
35 changes: 16 additions & 19 deletions paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Self, assert_never, cast

import anyio
from aviary.tools import ToolSelector
from aviary.core import ToolSelector
from pydantic import (
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -99,11 +99,9 @@ def _deprecated_field(self) -> Self:
# default is True, so we only warn if it's False
if not self.evidence_detailed_citations:
warnings.warn(
(
"The 'evidence_detailed_citations' field is deprecated and will be"
" removed in version 6. Adjust 'PromptSettings.context_inner' to remove"
" detailed citations."
),
"The 'evidence_detailed_citations' field is deprecated and will be"
" removed in version 6. Adjust 'PromptSettings.context_inner' to remove"
" detailed citations.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -259,8 +257,10 @@ class PromptSettings(BaseModel):
)
context_inner: str = Field(
default=CONTEXT_INNER_PROMPT,
description="Prompt for how to format a single context in generate answer. "
"This should at least contain key and name.",
description=(
"Prompt for how to format a single context in generate answer. "
"This should at least contain key and name."
),
)

@field_validator("summary")
Expand Down Expand Up @@ -380,7 +380,8 @@ class IndexSettings(BaseModel):
default=True,
description=(
"Whether to sync the index with the paper directory when loading an index."
" Setting to True will add or delete index files to match the source paper directory."
" Setting to True will add or delete index files to match the source paper"
" directory."
),
)

Expand Down Expand Up @@ -537,11 +538,9 @@ def _deprecated_field(self) -> Self:
value = getattr(self, deprecated_field_name)
if value != type(self).model_fields[deprecated_field_name].default:
warnings.warn(
(
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6."
),
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -667,11 +666,9 @@ def _deprecated_field(self) -> Self:
value = getattr(self, deprecated_field_name)
if value != type(self).model_fields[deprecated_field_name].default:
warnings.warn(
(
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6."
),
f"The {deprecated_field_name!r} field has been moved to"
f" {AgentSettings.__name__},"
" this deprecation will conclude in version 6.",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down
12 changes: 7 additions & 5 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import litellm # for cost
import tiktoken
from aviary.message import Message
from aviary.core import Message
from pybtex.database import BibliographyData, Entry, Person
from pybtex.database.input.bibtex import Parser
from pybtex.scanner import PybtexSyntaxError
Expand Down Expand Up @@ -360,8 +360,10 @@ class DocDetails(Doc):
file_location: str | os.PathLike | None = None
license: str | None = Field(
default=None,
description="string indicating license."
" Should refer specifically to pdf_url (since that could be preprint). None means unknown/unset.",
description=(
"string indicating license. Should refer specifically to pdf_url (since"
" that could be preprint). None means unknown/unset."
),
)
pdf_url: str | None = None
other: dict[str, Any] = Field(
Expand Down Expand Up @@ -612,8 +614,8 @@ def formatted_citation(self) -> str:

if self.source_quality_message:
return (
f"{self.citation} This article has {self.citation_count} citations and is"
f" from a {self.source_quality_message}."
f"{self.citation} This article has {self.citation_count} citations and"
f" is from a {self.source_quality_message}."
)
return f"{self.citation} This article has {self.citation_count} citations."

Expand Down
Loading
Loading