Refactor to enable message parts #704

Draft · wants to merge 6 commits into main
64 changes: 26 additions & 38 deletions paperqa/docs.py
@@ -34,6 +34,7 @@
from paperqa.readers import read_doc
from paperqa.settings import MaybeSettings, get_settings
from paperqa.types import (
Context,
Doc,
DocDetails,
DocKey,
@@ -705,7 +706,11 @@ async def aquery( # noqa: PLR0912
)
contexts = session.contexts
pre_str = None
if prompt_config.pre is not None:
# we check if we have a pre-context
# to avoid recreating it
if prompt_config.pre is not None and not any(
c.text.name == "Extra background information" for c in contexts
):
with set_llm_session_ids(session.id):
pre = await llm_model.run_prompt(
prompt=prompt_config.pre,
@@ -717,45 +722,28 @@
session.add_tokens(pre)
pre_str = pre.text

# sort by first score, then name
filtered_contexts = sorted(
contexts,
key=lambda x: (-x.score, x.text.name),
)[: answer_config.answer_max_sources]
# remove any contexts with a score of 0
filtered_contexts = [c for c in filtered_contexts if c.score > 0]

# shim deprecated flag
# TODO: remove in v6
context_inner_prompt = prompt_config.context_inner
if (
not answer_config.evidence_detailed_citations
and "\nFrom {citation}" in context_inner_prompt
):
# Only keep "\nFrom {citation}" if we are showing detailed citations
context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")

inner_context_strs = [
context_inner_prompt.format(
name=c.text.name,
text=c.context,
citation=c.text.doc.formatted_citation,
**(c.model_extra or {}),
)
for c in filtered_contexts
]
if pre_str:
inner_context_strs += (
[f"Extra background information: {pre_str}"] if pre_str else []
# make a context to include this
pre_context = Context(
text=Text(
name="Extra background information",
text=pre_str,
doc=Doc.empty("Extra background information"),
),
context=pre_str,
score=10,
)
contexts.append(pre_context)

context_str = prompt_config.context_outer.format(
context_str="\n\n".join(inner_context_strs),
valid_keys=", ".join([c.text.name for c in filtered_contexts]),
context_parts = session.get_context_parts(
answer_config.answer_max_sources,
answer_config.evidence_cache,
prompt_config.context_inner,
prompt_config.context_outer,
)

bib = {}
if len(context_str) < 10: # noqa: PLR2004
# check if no contexts made it past the filter
if len(context_parts) == 0:
answer_text = (
"I cannot answer this question due to insufficient information."
)
@@ -764,21 +752,21 @@ async def aquery( # noqa: PLR0912
answer_result = await llm_model.run_prompt(
prompt=prompt_config.qa,
data={
"context": context_str,
"answer_length": answer_config.answer_length,
"question": session.question,
"example_citation": prompt_config.EXAMPLE_CITATION,
},
callbacks=callbacks,
name="answer",
system_prompt=prompt_config.system,
msg_parts=context_parts,
)
answer_text = answer_result.text
session.add_tokens(answer_result)
# it still happens
if prompt_config.EXAMPLE_CITATION in answer_text:
answer_text = answer_text.replace(prompt_config.EXAMPLE_CITATION, "")
for c in filtered_contexts:
for c in session.contexts:
name = c.text.name
citation = c.text.doc.formatted_citation
# do check for whole key (so we don't catch Callahan2019a with Callahan2019)
@@ -819,6 +807,6 @@ async def aquery( # noqa: PLR0912
session.formatted_answer = formatted_answer
session.references = bib_str
session.contexts = contexts
session.context = context_str
session.context = json.dumps(context_parts, indent=2)

return session
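
For orientation: this hunk stops assembling a monolithic context string in `aquery`. The pre-answer background is wrapped in an ordinary `Context` (with a sentinel name and a high score) and appended to `session.contexts`, and the final formatting is delegated to the new `session.get_context_parts`, whose output is handed to `run_prompt` as `msg_parts`. A minimal sketch of the wrapping step, assuming this PR branch's `paperqa.types` API (`Doc.empty` is added later in this diff); the background text is an invented example:

```python
# Sketch of the new pre-context wrapping, assuming this PR branch's paperqa.types
# (Doc.empty is added later in this diff). The background text is an invented example.
from paperqa.types import Context, Doc, Text

pre_str = "Prior work has largely focused on smaller benchmark datasets."
pre_context = Context(
    text=Text(
        name="Extra background information",
        text=pre_str,
        doc=Doc.empty("Extra background information"),
    ),
    context=pre_str,
    score=10,  # high score so it survives the top-N source filter
)
```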
70 changes: 55 additions & 15 deletions paperqa/llms.py
@@ -302,12 +302,12 @@ async def acomplete_iter(self, prompt: str) -> AsyncIterable[Chunk]: # noqa: ARG002
if False: # type: ignore[unreachable] # pylint: disable=using-constant-test
yield # Trick mypy: https://github.com/python/mypy/issues/5070#issuecomment-1050834495

async def achat(self, messages: Iterable[dict[str, str]]) -> Chunk:
async def achat(self, messages: Iterable[dict[str, str | Sequence[dict]]]) -> Chunk:
"""Return the completion as string and the number of tokens in the prompt and completion."""
raise NotImplementedError

async def achat_iter(
self, messages: Iterable[dict[str, str]] # noqa: ARG002
self, messages: Iterable[dict[str, str | Sequence[dict]]] # noqa: ARG002
) -> AsyncIterable[Chunk]:
"""Return an async generator that yields chunks of the completion.

@@ -320,7 +320,9 @@ async def achat_iter(
def infer_llm_type(self) -> str:
return "completion"

def count_tokens(self, text: str) -> int:
def count_tokens(self, text: str | Sequence[dict]) -> int:
if isinstance(text, list):
return sum(self.count_tokens(m["text"]) for m in text)
return len(text) // 4 # gross approximation

async def run_prompt(
@@ -330,12 +332,18 @@ async def run_prompt(
callbacks: list[Callable] | None = None,
name: str | None = None,
system_prompt: str | None = default_system_prompt,
msg_parts: list[dict] | None = None,
) -> LLMResult:
if self.llm_type is None:
self.llm_type = self.infer_llm_type()
if self.llm_type == "chat":
return await self._run_chat(prompt, data, callbacks, name, system_prompt)
return await self._run_chat(
prompt, data, callbacks, name, system_prompt, msg_parts
)
if self.llm_type == "completion":
if msg_parts:
# compress them into a single message
prompt = "".join([m["text"] for m in msg_parts] + [prompt])
return await self._run_completion(
prompt, data, callbacks, name, system_prompt
)
@@ -348,6 +356,7 @@ async def _run_chat(
callbacks: list[Callable] | None = None,
name: str | None = None,
system_prompt: str | None = default_system_prompt,
msg_parts: list[dict] | None = None,
) -> LLMResult:
"""Run a chat prompt.

@@ -356,20 +365,49 @@
data: Keys for the input variables that will be formatted into prompt.
callbacks: Optional functions to call with each chunk of the completion.
name: Optional name for the result.
skip_system: Set True to skip the system prompt.
system_prompt: System prompt to use, or None/empty string to not use one.
msg_parts: Additional message parts to be inserted prior to the user message.

Returns:
Result of the chat.
"""
human_message_prompt = {"role": "user", "content": prompt}
messages = [
{"role": m["role"], "content": m["content"].format(**data)}
for m in (
[{"role": "system", "content": system_prompt}, human_message_prompt]
if system_prompt
else [human_message_prompt]
# build up a multipart message and insert the system prompt
msg_parts = msg_parts or []
formatted_prompt = prompt.format(**data)
system_content: str | list[dict] | None = system_prompt
human_content: list[dict] = []
# we will pack the msg_parts in with the system prompt
# (because that is how anthropic/gemini examples show it)
# find first text content to insert prompt
insert_point = next(
(i for i, m in enumerate(msg_parts) if m["type"] == "text"), None
)
if insert_point is None:
system_content = [{"type": "text", "text": system_prompt}, *msg_parts]
else:
# will modify to insert the system prompt before the first text
# we do this to preserve any cache control headers
system_content = msg_parts
msg_parts[insert_point]["text"] = (
system_prompt + "\n\n" + msg_parts[insert_point]["text"]
)
]
human_content = [{"type": "text", "text": formatted_prompt}]
system_message = {"role": "system", "content": system_content}
human_message = {"role": "user", "content": human_content}
# look for any non-zero content in the system message
# TODO: This is fraught because we could have, for example, an empty text message
# but enabled cache-control headers. We'll just assume callers are doing the right thing.
# the main thing this checks is that `count_tokens` below will not just die
has_content_in_system = any(
any(v for k, v in c.items() if k != "type") for c in system_content
)
messages = (
[system_message, human_message]
if has_content_in_system
else [human_message]
)
print(messages)
result = LLMResult(
model=self.name,
name=name,
@@ -711,7 +749,7 @@ async def acomplete_iter( # type: ignore[override]

@rate_limited
async def achat( # type: ignore[override]
self, messages: Iterable[dict[str, str]]
self, messages: Iterable[dict[str, str | Sequence[dict]]]
) -> Chunk:
response = await self.router.acompletion(self.name, list(messages))
return Chunk(
@@ -722,7 +760,7 @@ async def achat( # type: ignore[override]

@rate_limited
async def achat_iter( # type: ignore[override]
self, messages: Iterable[dict[str, str]]
self, messages: Iterable[dict[str, str | Sequence[dict]]]
) -> AsyncIterable[Chunk]:
completion = await self.router.acompletion(
self.name,
@@ -751,7 +789,9 @@ def infer_llm_type(self) -> str:
return "completion"
return "chat"

def count_tokens(self, text: str) -> int:
def count_tokens(self, text: str | Sequence[dict]) -> int:
if isinstance(text, list):
return sum(self.count_tokens(m["text"]) for m in text)
return litellm.token_counter(model=self.name, text=text)

async def select_tool(
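
The `_run_chat` change is the core of the refactor: when `msg_parts` are supplied, the system prompt is prepended to the first text part (so any cache-control metadata on that part survives) and the bundle is sent as the system message, while the formatted prompt becomes a single-part user message. A rough sketch of the resulting `messages` payload; all string values are invented placeholders, and actual provider handling stays with the litellm router:

```python
# Illustrative only: the approximate shape of `messages` that _run_chat now builds
# when msg_parts are supplied. All string values below are invented placeholders.
messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                # system prompt prepended to the first text part, preserving any
                # cache-control metadata already attached to that part
                "text": (
                    "Answer in a direct and concise tone.\n\n"
                    "Smith2023: example summary of the evidence\n"
                    "From Smith (2023) Example Journal\n\n"
                    "Valid Keys: Smith2023"
                ),
            },
        ],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "Answer the question below with the context.\n\nQuestion: ...",
            },
        ],
    },
]
```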
6 changes: 2 additions & 4 deletions paperqa/prompts.py
@@ -15,7 +15,6 @@

qa_prompt = (
"Answer the question below with the context.\n\n"
"Context (with relevance scores):\n\n{context}\n\n----\n\n"
"Question: {question}\n\n"
"Write an answer based on the context. "
"If the context provides insufficient information reply "
@@ -100,6 +99,5 @@
"\n\nSingle Letter Answer:"
)

CONTEXT_OUTER_PROMPT = "{context_str}\n\nValid Keys: {valid_keys}"
CONTEXT_INNER_PROMPT_NOT_DETAILED = "{name}: {text}"
CONTEXT_INNER_PROMPT = f"{CONTEXT_INNER_PROMPT_NOT_DETAILED}\nFrom {{citation}}"
CONTEXT_OUTER_PROMPT = "\n\nValid Keys: {valid_keys}"
CONTEXT_INNER_PROMPT = "{name}: {text}\nFrom {citation}"
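
With contexts travelling as message parts, `qa_prompt` no longer interpolates `{context}` and `CONTEXT_OUTER_PROMPT` keeps only the `Valid Keys` trailer. A quick rendering check of the retained templates, using invented values:

```python
# Quick rendering check with invented values; the templates match the diff above.
CONTEXT_INNER_PROMPT = "{name}: {text}\nFrom {citation}"
CONTEXT_OUTER_PROMPT = "\n\nValid Keys: {valid_keys}"

print(
    CONTEXT_INNER_PROMPT.format(
        name="Smith2023 pages 3-4",
        text="The study reports a 12% improvement on the benchmark.",
        citation="Smith, J. (2023). Example Paper. Example Journal.",
    )
)
print(CONTEXT_OUTER_PROMPT.format(valid_keys="Smith2023 pages 3-4"))
```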
12 changes: 7 additions & 5 deletions paperqa/settings.py
@@ -82,6 +82,10 @@ class AnswerSettings(BaseModel):
evidence_skip_summary: bool = Field(
default=False, description="Whether to skip summarization"
)
evidence_cache: bool = Field(
default=False,
description="Whether to cache evidence for reuse in future questions.",
)
answer_max_sources: int = Field(
default=5, description="Max number of sources to use for an answer"
)
@@ -114,12 +118,10 @@ class AnswerSettings(BaseModel):
def _deprecated_field(self) -> Self:
# default is True, so we only warn if it's False
if not self.evidence_detailed_citations:
warnings.warn(
"The 'evidence_detailed_citations' field is deprecated and will be"
" removed in version 6. Adjust 'PromptSettings.context_inner' to remove"
raise DeprecationWarning(
"The 'evidence_detailed_citations' field is deprecated."
" Adjust 'PromptSettings.context_inner' to remove"
" detailed citations.",
category=DeprecationWarning,
stacklevel=2,
)
return self

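
The new `evidence_cache` flag is what `aquery` forwards into `get_context_parts` to tag context parts for prompt caching. A short opt-in sketch from user code; the import path and `from_name` helper are assumptions based on the existing paperqa `Settings` API:

```python
# Opt-in sketch from user code, assuming the existing paperqa Settings API.
from paperqa import Settings

settings = Settings.from_name("fast")
settings.answer.evidence_cache = True   # tag context parts for prompt caching
settings.answer.answer_max_sources = 3  # existing knob, shown for contrast
```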
60 changes: 60 additions & 0 deletions paperqa/types.py
@@ -129,6 +129,10 @@ class Doc(Embeddable):
def __hash__(self) -> int:
return hash((self.docname, self.dockey))

@classmethod
def empty(cls, docname: str) -> Doc:
return cls(docname=docname, citation="", dockey=encode_id(docname))

@computed_field # type: ignore[prop-decorator]
@property
def formatted_citation(self) -> str:
@@ -153,6 +157,18 @@ class Context(BaseModel):
text: Text
score: int = 5

def get_part(self, prompt: str | None = None) -> dict:
"""Return the context formatted as a message part."""
formatted_text = self.context
if prompt:
formatted_text = prompt.format(
name=self.text.name,
text=self.context,
citation=self.text.doc.citation,
**(self.model_extra or {}),
)
return {"type": "text", "text": formatted_text}

def __str__(self) -> str:
"""Return the context as a string."""
return self.context
@@ -203,6 +219,50 @@ def used_contexts(self) -> set[str]:
"""Return the used contexts."""
return get_citenames(self.formatted_answer)

def get_context_parts(
self,
count: int,
cache: bool = False,
inner_prompt: str | None = None,
outer_prompt: str | None = None,
) -> list[dict]:
"""Return the context formatted as a message."""
# sort by first score, then name
filtered_contexts = sorted(
self.contexts,
key=lambda x: (-x.score, x.text.name),
)[:count]
# remove any contexts with a score of 0
filtered_contexts = [c for c in filtered_contexts if c.score > 0]
names = [c.text.name for c in filtered_contexts]
parts = [c.get_part(inner_prompt) for c in filtered_contexts]
if outer_prompt:
parts.append(
{
"type": "text",
"text": outer_prompt.format(valid_keys=", ".join(names)),
}
)

# now merge parts to make caching easier
collapsed_parts: list[dict] = []
for part in parts:
if part["type"] == "text":
if collapsed_parts and collapsed_parts[-1]["type"] == "text":
collapsed_parts[-1]["text"] += "\n\n" + part["text"]
else:
collapsed_parts.append(part)
else:
collapsed_parts.append(part)

# add caching
if cache:
collapsed_parts = [
{**p, "cache-control": "ephemeral"} for p in collapsed_parts
]

return collapsed_parts

def get_citation(self, name: str) -> str:
"""Return the formatted citation for the given docname."""
try:
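
The heavy lifting moves here: `Context.get_part` renders one context with the inner prompt, and the session's `get_context_parts` filters and sorts contexts, appends the `Valid Keys` trailer, merges adjacent text parts, and optionally tags them for caching. A standalone sketch of just the merge-and-tag step, written with plain dicts so it runs without paperqa installed; the part shape follows the diff, while the helper name and sample texts are invented for illustration:

```python
# Standalone sketch of the merge-and-tag step in get_context_parts. The part shape
# ({"type": "text", "text": ...}) follows the diff; everything else is illustrative.
def collapse_and_tag(parts: list[dict], cache: bool = False) -> list[dict]:
    collapsed: list[dict] = []
    for part in parts:
        if part["type"] == "text" and collapsed and collapsed[-1]["type"] == "text":
            # merge runs of adjacent text parts into one block
            collapsed[-1] = {**collapsed[-1], "text": collapsed[-1]["text"] + "\n\n" + part["text"]}
        else:
            collapsed.append(dict(part))
    if cache:
        # key and value follow the PR as written ("cache-control": "ephemeral")
        collapsed = [{**p, "cache-control": "ephemeral"} for p in collapsed]
    return collapsed


parts = [
    {"type": "text", "text": "Smith2023: example summary\nFrom Smith (2023)"},
    {"type": "text", "text": "Valid Keys: Smith2023"},
]
print(collapse_and_tag(parts, cache=True))
# -> a single merged text part carrying the cache-control marker
```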
11 changes: 0 additions & 11 deletions tests/test_paperqa.py
@@ -1127,17 +1127,6 @@ def test_context_inner_outer_prompt(stub_data_dir: Path) -> None:
assert "Valid Keys" not in response.context


def test_evidence_detailed_citations_shim(stub_data_dir: Path) -> None:
# TODO: delete this test in v6
settings = Settings.from_name("fast")
# NOTE: this bypasses DeprecationWarning, as the warning is done on construction
settings.answer.evidence_detailed_citations = False
docs = Docs()
docs.add(stub_data_dir / "bates.txt", "WikiMedia Foundation, 2023, Accessed now")
response = docs.query("What country is Bates from?", settings=settings)
assert "WikiMedia Foundation, 2023, Accessed now" not in response.context


def test_case_insensitive_matching():
assert strings_similarity("my test sentence", "My test sentence") == 1.0
assert strings_similarity("a b c d e", "a b c f") == 0.5