diff --git a/paperqa/docs.py b/paperqa/docs.py
index 7f68457f..6917e272 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -34,6 +34,7 @@ from paperqa.readers import read_doc
 from paperqa.settings import MaybeSettings, get_settings
 from paperqa.types import (
+    Context,
     Doc,
     DocDetails,
     DocKey,
@@ -705,7 +706,11 @@ async def aquery(  # noqa: PLR0912
             )
         contexts = session.contexts
         pre_str = None
-        if prompt_config.pre is not None:
+        # check whether we already have a pre-context
+        # to avoid recreating it
+        if prompt_config.pre is not None and not any(
+            c.text.name == "Extra background information" for c in contexts
+        ):
             with set_llm_session_ids(session.id):
                 pre = await llm_model.run_prompt(
                     prompt=prompt_config.pre,
@@ -717,45 +722,28 @@
             session.add_tokens(pre)
             pre_str = pre.text
 
-        # sort by first score, then name
-        filtered_contexts = sorted(
-            contexts,
-            key=lambda x: (-x.score, x.text.name),
-        )[: answer_config.answer_max_sources]
-        # remove any contexts with a score of 0
-        filtered_contexts = [c for c in filtered_contexts if c.score > 0]
-
-        # shim deprecated flag
-        # TODO: remove in v6
-        context_inner_prompt = prompt_config.context_inner
-        if (
-            not answer_config.evidence_detailed_citations
-            and "\nFrom {citation}" in context_inner_prompt
-        ):
-            # Only keep "\nFrom {citation}" if we are showing detailed citations
-            context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")
-
-        inner_context_strs = [
-            context_inner_prompt.format(
-                name=c.text.name,
-                text=c.context,
-                citation=c.text.doc.formatted_citation,
-                **(c.model_extra or {}),
-            )
-            for c in filtered_contexts
-        ]
-        if pre_str:
-            inner_context_strs += (
-                [f"Extra background information: {pre_str}"] if pre_str else []
+            # make a context to include this pre-text
+            pre_context = Context(
+                text=Text(
+                    name="Extra background information",
+                    text=pre_str,
+                    doc=Doc.empty("Extra background information"),
+                ),
+                context=pre_str,
+                score=10,
             )
+            contexts.append(pre_context)
 
-        context_str = prompt_config.context_outer.format(
-            context_str="\n\n".join(inner_context_strs),
-            valid_keys=", ".join([c.text.name for c in filtered_contexts]),
+        context_parts = session.get_context_parts(
+            answer_config.answer_max_sources,
+            answer_config.evidence_cache,
+            prompt_config.context_inner,
+            prompt_config.context_outer,
         )
 
         bib = {}
-        if len(context_str) < 10:  # noqa: PLR2004
+        # check if no contexts made it past the filter
+        if len(context_parts) == 0:
             answer_text = (
                 "I cannot answer this question due to insufficient information."
             )
@@ -764,7 +752,6 @@ async def aquery(  # noqa: PLR0912
             answer_result = await llm_model.run_prompt(
                 prompt=prompt_config.qa,
                 data={
-                    "context": context_str,
                     "answer_length": answer_config.answer_length,
                     "question": session.question,
                     "example_citation": prompt_config.EXAMPLE_CITATION,
@@ -772,13 +759,14 @@
                 callbacks=callbacks,
                 name="answer",
                 system_prompt=prompt_config.system,
+                msg_parts=context_parts,
             )
             answer_text = answer_result.text
             session.add_tokens(answer_result)
         # it still happens
         if prompt_config.EXAMPLE_CITATION in answer_text:
             answer_text = answer_text.replace(prompt_config.EXAMPLE_CITATION, "")
-        for c in filtered_contexts:
+        for c in session.contexts:
             name = c.text.name
             citation = c.text.doc.formatted_citation
             # do check for whole key (so we don't catch Callahan2019a with Callahan2019)
@@ -819,6 +807,6 @@ async def aquery(  # noqa: PLR0912
         session.formatted_answer = formatted_answer
         session.references = bib_str
         session.contexts = contexts
-        session.context = context_str
+        session.context = json.dumps(context_parts, indent=2)
 
         return session
diff --git a/paperqa/llms.py b/paperqa/llms.py
index 3ff1d577..b2408ca5 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -302,12 +302,12 @@ async def acomplete_iter(self, prompt: str) -> AsyncIterable[Chunk]:  # noqa: AR
         if False:  # type: ignore[unreachable]  # pylint: disable=using-constant-test
             yield  # Trick mypy: https://github.com/python/mypy/issues/5070#issuecomment-1050834495
 
-    async def achat(self, messages: Iterable[dict[str, str]]) -> Chunk:
+    async def achat(self, messages: Iterable[dict[str, str | Sequence[dict]]]) -> Chunk:
         """Return the completion as string and the number of tokens in the prompt and completion."""
         raise NotImplementedError
 
     async def achat_iter(
-        self, messages: Iterable[dict[str, str]]  # noqa: ARG002
+        self, messages: Iterable[dict[str, str | Sequence[dict]]]  # noqa: ARG002
     ) -> AsyncIterable[Chunk]:
         """Return an async generator that yields chunks of the completion.
 
@@ -320,7 +320,9 @@ async def achat_iter(
     def infer_llm_type(self) -> str:
         return "completion"
 
-    def count_tokens(self, text: str) -> int:
+    def count_tokens(self, text: str | Sequence[dict]) -> int:
+        if isinstance(text, list):
+            return sum(self.count_tokens(m["text"]) for m in text)
         return len(text) // 4  # gross approximation
 
     async def run_prompt(
@@ -330,12 +332,18 @@
         callbacks: list[Callable] | None = None,
         name: str | None = None,
         system_prompt: str | None = default_system_prompt,
+        msg_parts: list[dict] | None = None,
     ) -> LLMResult:
         if self.llm_type is None:
             self.llm_type = self.infer_llm_type()
         if self.llm_type == "chat":
-            return await self._run_chat(prompt, data, callbacks, name, system_prompt)
+            return await self._run_chat(
+                prompt, data, callbacks, name, system_prompt, msg_parts
+            )
         if self.llm_type == "completion":
+            if msg_parts:
+                # compress them into a single message
+                prompt = "".join([m["text"] for m in msg_parts] + [prompt])
             return await self._run_completion(
                 prompt, data, callbacks, name, system_prompt
             )
@@ -348,6 +356,7 @@ async def _run_chat(
         callbacks: list[Callable] | None = None,
         name: str | None = None,
         system_prompt: str | None = default_system_prompt,
+        msg_parts: list[dict] | None = None,
     ) -> LLMResult:
         """Run a chat prompt.
 
@@ -356,20 +365,47 @@ async def _run_chat(
             data: Keys for the input variables that will be formatted into prompt.
             callbacks: Optional functions to call with each chunk of the completion.
             name: Optional name for the result.
             system_prompt: System prompt to use, or None/empty string to not use one.
+            msg_parts: Additional message parts to be inserted prior to the user message.
 
         Returns:
             Result of the chat.
         """
-        human_message_prompt = {"role": "user", "content": prompt}
-        messages = [
-            {"role": m["role"], "content": m["content"].format(**data)}
-            for m in (
-                [{"role": "system", "content": system_prompt}, human_message_prompt]
-                if system_prompt
-                else [human_message_prompt]
+        # build up a multipart message and insert the system prompt
+        msg_parts = msg_parts or []
+        formatted_prompt = prompt.format(**data)
+        system_content: str | list[dict] | None = system_prompt
+        human_content: list[dict] = []
+        # we will pack the msg_parts in with the system prompt
+        # (because that is how anthropic/gemini examples show it)
+        # find the first text part to insert the system prompt into
+        insert_point = next(
+            (i for i, m in enumerate(msg_parts) if m["type"] == "text"), None
+        )
+        if insert_point is None:
+            system_content = [{"type": "text", "text": system_prompt}, *msg_parts]
+        else:
+            # will modify to insert the system prompt before the first text
+            # we do this to preserve any cache control headers
+            system_content = msg_parts
+            msg_parts[insert_point]["text"] = (
+                system_prompt + "\n\n" + msg_parts[insert_point]["text"]
             )
-        ]
+        human_content = [{"type": "text", "text": formatted_prompt}]
+        system_message = {"role": "system", "content": system_content}
+        human_message = {"role": "user", "content": human_content}
+        # look for any non-zero content in the system message
+        # TODO: This is fraught because we could have, for example, an empty text message
+        # but with cache-control headers enabled. We'll just assume callers are doing the right thing.
+        # the main thing this checks is that `count_tokens` below will not just die
+        has_content_in_system = any(
+            any(v for k, v in c.items() if k != "type") for c in system_content
+        )
+        messages = (
+            [system_message, human_message]
+            if has_content_in_system
+            else [human_message]
+        )
         result = LLMResult(
             model=self.name,
             name=name,
@@ -711,7 +747,7 @@ async def acomplete_iter(  # type: ignore[override]
 
     @rate_limited
     async def achat(  # type: ignore[override]
-        self, messages: Iterable[dict[str, str]]
+        self, messages: Iterable[dict[str, str | Sequence[dict]]]
    ) -> Chunk:
         response = await self.router.acompletion(self.name, list(messages))
         return Chunk(
@@ -722,7 +758,7 @@ async def achat(  # type: ignore[override]
 
     @rate_limited
     async def achat_iter(  # type: ignore[override]
-        self, messages: Iterable[dict[str, str]]
+        self, messages: Iterable[dict[str, str | Sequence[dict]]]
     ) -> AsyncIterable[Chunk]:
         completion = await self.router.acompletion(
             self.name,
@@ -751,7 +787,9 @@ def infer_llm_type(self) -> str:
             return "completion"
         return "chat"
 
-    def count_tokens(self, text: str) -> int:
+    def count_tokens(self, text: str | Sequence[dict]) -> int:
+        if isinstance(text, list):
+            return sum(self.count_tokens(m["text"]) for m in text)
         return litellm.token_counter(model=self.name, text=text)
 
     async def select_tool(
diff --git a/paperqa/prompts.py b/paperqa/prompts.py
index abb262d1..e312d7d3 100644
--- a/paperqa/prompts.py
+++ b/paperqa/prompts.py
@@ -15,7 +15,6 @@
 
 qa_prompt = (
     "Answer the question below with the context.\n\n"
-    "Context (with relevance scores):\n\n{context}\n\n----\n\n"
     "Question: {question}\n\n"
     "Write an answer based on the context. "
     "If the context provides insufficient information reply "
@@ -100,6 +99,5 @@
     "\n\nSingle Letter Answer:"
 )
 
-CONTEXT_OUTER_PROMPT = "{context_str}\n\nValid Keys: {valid_keys}"
-CONTEXT_INNER_PROMPT_NOT_DETAILED = "{name}: {text}"
-CONTEXT_INNER_PROMPT = f"{CONTEXT_INNER_PROMPT_NOT_DETAILED}\nFrom {{citation}}"
+CONTEXT_OUTER_PROMPT = "\n\nValid Keys: {valid_keys}"
+CONTEXT_INNER_PROMPT = "{name}: {text}\nFrom {citation}"
diff --git a/paperqa/settings.py b/paperqa/settings.py
index 2a0b457a..68c5a6f4 100644
--- a/paperqa/settings.py
+++ b/paperqa/settings.py
@@ -82,6 +82,10 @@ class AnswerSettings(BaseModel):
     evidence_skip_summary: bool = Field(
         default=False, description="Whether to summarization"
     )
+    evidence_cache: bool = Field(
+        default=False,
+        description="Whether to cache evidence for reuse in future questions.",
+    )
     answer_max_sources: int = Field(
         default=5, description="Max number of sources to use for an answer"
     )
@@ -114,12 +118,10 @@
     def _deprecated_field(self) -> Self:
         # default is True, so we only warn if it's False
         if not self.evidence_detailed_citations:
-            warnings.warn(
-                "The 'evidence_detailed_citations' field is deprecated and will be"
-                " removed in version 6. Adjust 'PromptSettings.context_inner' to remove"
+            raise DeprecationWarning(
+                "The 'evidence_detailed_citations' field is deprecated."
+                " Adjust 'PromptSettings.context_inner' to remove"
                 " detailed citations.",
-                category=DeprecationWarning,
-                stacklevel=2,
             )
 
         return self
diff --git a/paperqa/types.py b/paperqa/types.py
index 3726c192..bd04438a 100644
--- a/paperqa/types.py
+++ b/paperqa/types.py
@@ -129,6 +129,10 @@ class Doc(Embeddable):
     def __hash__(self) -> int:
         return hash((self.docname, self.dockey))
 
+    @classmethod
+    def empty(cls, docname: str) -> Doc:
+        return cls(docname=docname, citation="", dockey=encode_id(docname))
+
     @computed_field  # type: ignore[prop-decorator]
     @property
     def formatted_citation(self) -> str:
@@ -153,6 +157,18 @@ class Context(BaseModel):
     text: Text
     score: int = 5
 
+    def get_part(self, prompt: str | None = None) -> dict:
+        """Return the context formatted as a message part."""
+        formatted_text = self.context
+        if prompt:
+            formatted_text = prompt.format(
+                name=self.text.name,
+                text=self.context,
+                citation=self.text.doc.citation,
+                **(self.model_extra or {}),
+            )
+        return {"type": "text", "text": formatted_text}
+
     def __str__(self) -> str:
         """Return the context as a string."""
         return self.context
@@ -203,6 +219,50 @@ def used_contexts(self) -> set[str]:
         """Return the used contexts."""
         return get_citenames(self.formatted_answer)
 
+    def get_context_parts(
+        self,
+        count: int,
+        cache: bool = False,
+        inner_prompt: str | None = None,
+        outer_prompt: str | None = None,
+    ) -> list[dict]:
+        """Return the contexts formatted as message parts."""
+        # sort by first score, then name
+        filtered_contexts = sorted(
+            self.contexts,
+            key=lambda x: (-x.score, x.text.name),
+        )[:count]
+        # remove any contexts with a score of 0
+        filtered_contexts = [c for c in filtered_contexts if c.score > 0]
+        names = [c.text.name for c in filtered_contexts]
+        parts = [c.get_part(inner_prompt) for c in filtered_contexts]
+        if outer_prompt:
+            parts.append(
+                {
+                    "type": "text",
+                    "text": outer_prompt.format(valid_keys=", ".join(names)),
+                }
+            )
+
+        # now merge adjacent text parts to make caching easier
+        collapsed_parts: list[dict] = []
+        for part in parts:
+            if part["type"] == "text":
+                if collapsed_parts and collapsed_parts[-1]["type"] == "text":
+                    collapsed_parts[-1]["text"] += "\n\n" + part["text"]
+                else:
+                    collapsed_parts.append(part)
+            else:
+                collapsed_parts.append(part)
+
+        # add caching
+        if cache:
+            collapsed_parts = [
+                {**p, "cache-control": "ephemeral"} for p in collapsed_parts
+            ]
+
+        return collapsed_parts
+
     def get_citation(self, name: str) -> str:
         """Return the formatted citation for the given docname."""
         try:
diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py
index 9dc8dcf3..ebb22276 100644
--- a/tests/test_paperqa.py
+++ b/tests/test_paperqa.py
@@ -1127,17 +1127,6 @@ def test_context_inner_outer_prompt(stub_data_dir: Path) -> None:
     assert "Valid Keys" not in response.context
 
 
-def test_evidence_detailed_citations_shim(stub_data_dir: Path) -> None:
-    # TODO: delete this test in v6
-    settings = Settings.from_name("fast")
-    # NOTE: this bypasses DeprecationWarning, as the warning is done on construction
-    settings.answer.evidence_detailed_citations = False
-    docs = Docs()
-    docs.add(stub_data_dir / "bates.txt", "WikiMedia Foundation, 2023, Accessed now")
-    response = docs.query("What country is Bates from?", settings=settings)
-    assert "WikiMedia Foundation, 2023, Accessed now" not in response.context
-
-
 def test_case_insensitive_matching():
     assert strings_similarity("my test sentence", "My test sentence") == 1.0
     assert strings_similarity("a b c d e", "a b c f") == 0.5
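Illustrative usage of the new multipart-context flow, as a minimal sketch rather than part of the patch: `AnswerSettings.evidence_cache` tags the collapsed context parts with a `cache-control` entry, and the parts reach the model through `msg_parts` instead of the removed `{context}` placeholder. The document path, citation, and question below are placeholders, and a configured LLM API key is assumed.

```python
# Hypothetical sketch; the paper path, citation, and question are placeholders.
from paperqa import Docs, Settings

settings = Settings.from_name("fast")
# Tag the merged context parts with {"cache-control": "ephemeral"}
settings.answer.evidence_cache = True

docs = Docs()
docs.add("example_paper.pdf", "Example et al., 2024, Accessed now")  # placeholder document
session = docs.query("What is the main finding of this paper?", settings=settings)

# With this change, session.context holds a JSON dump of the message parts
# passed via msg_parts rather than a single formatted context string.
print(session.context)
```

Because `get_context_parts` merges adjacent text parts before tagging, the evidence usually collapses into a single cacheable part that precedes the question prompt.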