LangChain Integration #60

Open · wants to merge 11 commits into main
2 changes: 1 addition & 1 deletion libs/manubot_ai_editor/env_vars.py
@@ -16,7 +16,7 @@
OPENAI_API_KEY = "OPENAI_API_KEY"

# Language model to use. For example, "text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", etc
-# The tool currently supports the "chat/completions", "completions", and "edits" endpoints, and you can check
# The tool currently supports the "chat/completions" and "completions" endpoints, and you can check
# compatible models here: https://platform.openai.com/docs/models/model-endpoint-compatibility
LANGUAGE_MODEL = "AI_EDITOR_LANGUAGE_MODEL"
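As a quick orientation, here is a minimal sketch of how a caller might read these settings; the fallback model name is an illustrative assumption, not a default defined by this module.

```python
import os

from manubot_ai_editor import env_vars

# the API key is read from the environment when not passed explicitly
api_key = os.environ.get(env_vars.OPENAI_API_KEY)

# the model is likewise configured via an environment variable;
# "gpt-3.5-turbo" is only an illustrative fallback, not the package default
model_engine = os.environ.get(env_vars.LANGUAGE_MODEL, "gpt-3.5-turbo")

print(f"Using model: {model_engine}")
```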

134 changes: 85 additions & 49 deletions libs/manubot_ai_editor/models.py
@@ -5,7 +5,8 @@
import time
import json

-import openai
from langchain_openai import OpenAI, ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

from manubot_ai_editor import env_vars

@@ -141,12 +142,13 @@ def __init__(
super().__init__()

# make sure the OpenAI API key is set
-openai.api_key = openai_api_key
if openai_api_key is None:
# attempt to get the OpenAI API key from the environment, since one
# wasn't specified as an argument
openai_api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)

-if openai.api_key is None:
-openai.api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)

-if openai.api_key is None or openai.api_key.strip() == "":
# if it's *still* not set, bail
if openai_api_key is None or openai_api_key.strip() == "":
raise ValueError(
f"OpenAI API key not found. Please provide it as a parameter "
f"or set it as the environment variable "
@@ -221,17 +223,14 @@ def __init__(
self.title = title
self.keywords = keywords if keywords is not None else []

-# adjust options if edits or chat endpoint was selected
# adjust options if chat endpoint was selected
self.endpoint = "chat"

if model_engine.startswith(
("text-davinci-", "text-curie-", "text-babbage-", "text-ada-")
):
self.endpoint = "completions"

if "-edit-" in model_engine:
self.endpoint = "edits"

print(f"Language model: {model_engine}")
print(f"Model endpoint used: {self.endpoint}")

@@ -253,6 +252,18 @@ def __init__(

self.several_spaces_pattern = re.compile(r"\s+")

if self.endpoint == "chat":
client_cls = ChatOpenAI
else:
client_cls = OpenAI

# construct the OpenAI client after all the rest of
# the settings above have been processed
self.client = client_cls(
api_key=openai_api_key,
**self.model_parameters,
)
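To see the new construction path end to end, here is a self-contained sketch of the pattern this hunk adopts: derive the endpoint from the model name, pick the matching langchain client class, and construct it once. The model name, key, and temperature below are placeholders, not values from the PR.

```python
from langchain_openai import ChatOpenAI, OpenAI

model_engine = "gpt-3.5-turbo"  # placeholder; normally read from AI_EDITOR_LANGUAGE_MODEL

# legacy completion models keep using the "completions" endpoint;
# everything else is routed to "chat" (mirrors the logic above)
if model_engine.startswith(
    ("text-davinci-", "text-curie-", "text-babbage-", "text-ada-")
):
    endpoint = "completions"
else:
    endpoint = "chat"

client_cls = ChatOpenAI if endpoint == "chat" else OpenAI
client = client_cls(
    api_key="sk-...",    # placeholder key
    model=model_engine,  # both classes accept a model name
    temperature=0.5,     # illustrative sampling parameter
)
```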

def get_prompt(
self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None
) -> str | tuple[str, str]:
@@ -268,13 +279,9 @@ def get_prompt(
resolved_prompt: prompt resolved via ai-revision config, if available

Returns:
If self.endpoint != "edits", then returns a string with the prompt to be used by the model for the revision of the paragraph.
A string with the prompt to be used by the model for the revision of the paragraph.
It contains two paragraphs of text: the command for the model
("Revise...") and the paragraph to revise.

If self.endpoint == "edits", then returns a tuple with two strings:
1) the instructions to be used by the model for the revision of the paragraph,
2) the paragraph to revise.
"""

# prompts are resolved in the following order, with the first satisfied
@@ -310,8 +317,6 @@ def get_prompt(
f"Using custom prompt from environment variable '{env_vars.CUSTOM_PROMPT}'"
)

-# FIXME: if {paragraph_text} is in the prompt, this won't work for the edits endpoint
-# a simple workaround is to remove {paragraph_text} from the prompt
prompt = custom_prompt.format(**placeholders)
elif resolved_prompt:
# use the resolved prompt from the ai-revision config files, if available
@@ -384,14 +389,10 @@ def get_prompt(
if custom_prompt is None:
prompt = self.several_spaces_pattern.sub(" ", prompt).strip()

if self.endpoint != "edits":
if custom_prompt is not None and "{paragraph_text}" in custom_prompt:
return prompt
if custom_prompt is not None and "{paragraph_text}" in custom_prompt:
return prompt

return f"{prompt}.\n\n{paragraph_text.strip()}"
else:
prompt = prompt.replace("the following paragraph", "this paragraph")
return f"{prompt}.", paragraph_text.strip()
return f"{prompt}.\n\n{paragraph_text.strip()}"

def get_max_tokens(self, paragraph_text: str, fraction: float = 2.0) -> int:
"""
@@ -465,21 +466,30 @@ def get_max_tokens_from_error_message(error_message: str) -> dict[str, int] | None:
}

def get_params(self, paragraph_text, section_name, resolved_prompt=None):
"""
Given the paragraph text and section name, produces parameters that are
used when invoking an LLM via an API.

The specific parameters vary depending on the endpoint being used, which
is determined by the model that was chosen when GPT3CompletionModel was
instantiated.

Args:
paragraph_text: The text of the paragraph to be revised.
section_name: The name of the section the paragraph belongs to.
resolved_prompt: The prompt resolved via ai-revision config files, if available.

Returns:
A dictionary of parameters to be used when invoking an LLM API.
"""
max_tokens = self.get_max_tokens(paragraph_text)
prompt = self.get_prompt(paragraph_text, section_name, resolved_prompt)

params = {
"n": 1,
}

if self.endpoint == "edits":
params.update(
{
"instruction": prompt[0],
"input": prompt[1],
}
)
elif self.endpoint == "chat":
if self.endpoint == "chat":
params.update(
{
"messages": [
@@ -502,19 +512,23 @@ def get_params(self, paragraph_text, section_name, resolved_prompt=None):

return params
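For reference, the dictionary returned for a chat model has roughly the following shape; the message wording, token count, and the presence of a system message are illustrative assumptions, not captured output.

```python
# illustrative shape of get_params() output for the chat endpoint
params = {
    "n": 1,
    "messages": [
        # these role/content pairs are converted to langchain
        # SystemMessage/HumanMessage objects in revise_paragraph()
        {"role": "system", "content": "You are an editor of scientific manuscripts."},
        {"role": "user", "content": "Revise the following paragraph.\n\n<paragraph text>"},
    ],
    "max_tokens": 512,  # computed by get_max_tokens() in practice
    "stop": None,
}
```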

-def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolved_prompt=None):
def revise_paragraph(
self, paragraph_text: str, section_name: str = None, resolved_prompt=None
):
"""
It revises a paragraph using a GPT-3 completion model.

Arguments:
paragraph_text (str): Paragraph text to revise.
-section_name (str): Section name of the paragraph.
-throw_error (bool): If True, it throws an error if the API call fails.
-If False, it returns the original paragraph text.
section_name (str): Section name of the paragraph.
resolved_prompt (str): Prompt resolved via ai-revision config files, if available.

Returns:
Revised paragraph text.
"""

# based on the paragraph text to revise and the section to which it
# belongs, constructs parameters that we'll use to query the LLM's API
params = self.get_params(paragraph_text, section_name, resolved_prompt)

retry_count = 0
@@ -526,17 +540,39 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolved_prompt=None):
flush=True,
)

if self.endpoint == "edits":
completions = openai.Edit.create(**params)
elif self.endpoint == "chat":
completions = openai.ChatCompletion.create(**params)
else:
completions = openai.Completion.create(**params)
# map the prompt to langchain's prompt types, based on what
# kind of endpoint we're using
if "messages" in params:
Review comment (Collaborator): A bit outside the PR scope, but adding it here since this is a fresh read of the code and I'm less familiar with how params are used. I noticed the docstring doesn't match the method parameters. Consider updating this when there's a chance.

Reply (Collaborator, Author): I'm thinking we'll do a comprehensive review of the docstrings for the PR that addresses issue #68, but in this PR I've attempted to add some documentation to the GPT3CompletionModel.get_params() method to address this gap.
# map the messages to langchain's message types
# based on the 'role' field
prompt = [
(
HumanMessage(content=msg["content"])
if msg["role"] == "user"
else SystemMessage(content=msg["content"])
)
for msg in params["messages"]
]
elif "prompt" in params:
prompt = [HumanMessage(content=params["prompt"])]

response = self.client.invoke(
input=prompt,
max_tokens=params.get("max_tokens"),
stop=params.get("stop"),
)

if self.endpoint == "chat":
message = completions.choices[0].message.content.strip()
if isinstance(response, BaseMessage):
message = response.content.strip()
else:
-message = completions.choices[0].text.strip()
message = response.strip()

# FIXME: the prior code retrieved the first of the 'choices'
# responses from the openai client. Now we only get one
# response from the langchain client, but I should check
# whether that's really how langchain works or if there is a way
# to get multiple 'choices' back from the backend.

except Exception as e:
error_message = str(e)
print(f"Error: {error_message}")
@@ -583,10 +619,10 @@ class DebuggingManuscriptRevisionModel(GPT3CompletionModel):
"""

def __init__(self, *args, **kwargs):
-if 'title' not in kwargs or kwargs['title'] is None:
-kwargs['title'] = "Debugging Title"
-if 'keywords' not in kwargs or kwargs['keywords'] is None:
-kwargs['keywords'] = ["debugging", "keywords"]
if "title" not in kwargs or kwargs["title"] is None:
kwargs["title"] = "Debugging Title"
if "keywords" not in kwargs or kwargs["keywords"] is None:
kwargs["keywords"] = ["debugging", "keywords"]

super().__init__(*args, **kwargs)

20 changes: 11 additions & 9 deletions libs/manubot_ai_editor/prompt_config.py
@@ -47,9 +47,9 @@ def __init__(self, config_dir: str | Path, title: str, keywords: str) -> None:
# specify filename-to-prompt mappings; if both are present, we use
# self.config.files, but warn the user that they should only use one
if (
-self.prompts_files is not None and
-self.config is not None and
-self.config.get('files', {}).get('matchings') is not None
self.prompts_files is not None
and self.config is not None
and self.config.get("files", {}).get("matchings") is not None
):
print(
"WARNING: Both 'ai-revision-config.yaml' and 'ai-revision-prompts.yaml' specify filename-to-prompt mappings. "
@@ -93,7 +93,7 @@ def _load_custom_prompts(self) -> tuple[dict, dict]:
# same as _load_config: if no config folder was specified, we just return (None, None)
if self.config_dir is None:
return (None, None)

prompt_file_path = os.path.join(self.config_dir, "ai-revision-prompts.yaml")

try:
@@ -150,7 +150,7 @@ def get_prompt_for_filename(
# ai-revision-prompts.yaml specifies prompts_files, then files.matchings
# takes precedence.
# (the user is notified of this in a validation warning in __init__)

# then, consult ai-revision-config.yaml's 'matchings' collection; if a
# match is found, use the corresponding prompt from ai-revision-prompts.yaml
for entry in get_obj_path(self.config, ("files", "matchings"), missing=[]):
@@ -169,7 +169,10 @@ def get_prompt_for_filename(
if resolved_prompt is not None:
resolved_prompt = resolved_prompt.strip()

-return ( resolved_prompt, m, )
return (
resolved_prompt,
m,
)

# since we haven't found a match yet, consult ai-revision-prompts.yaml's
# 'prompts_files' collection
@@ -185,11 +188,10 @@ def get_prompt_for_filename(
resolved_default_prompt = None
if use_default and self.prompts is not None:
resolved_default_prompt = self.prompts.get(
-get_obj_path(self.config, ("files", "default_prompt")),
-None
get_obj_path(self.config, ("files", "default_prompt")), None
)

if resolved_default_prompt is not None:
resolved_default_prompt = resolved_default_prompt.strip()

return (resolved_default_prompt, None)
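For orientation, the two YAML files this module reconciles parse into structures roughly like the following. The key paths 'files.matchings' and 'files.default_prompt' come from the accessors in this diff; the entry shapes, prompt names, and prompt text are illustrative assumptions.

```python
# parsed ai-revision-config.yaml (illustrative; only the accessed keys are grounded)
config = {
    "files": {
        "matchings": [
            {"files": ["04.methods.md"], "prompt": "methods"},  # assumed entry shape
        ],
        "default_prompt": "default",
    },
}

# parsed ai-revision-prompts.yaml (illustrative prompt names and text)
prompts = {
    "methods": "Revise this methods paragraph for clarity...",
    "default": "Revise the following paragraph...",
}

# resolution order implemented above: a files.matchings entry wins first,
# then prompts_files mappings, then files.default_prompt as the fallback
default_prompt = prompts.get(config["files"]["default_prompt"])
```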
5 changes: 3 additions & 2 deletions setup.py
@@ -11,7 +11,7 @@

setuptools.setup(
name="manubot-ai-editor",
version="0.5.2",
version="0.5.3",
author="Milton Pividori",
author_email="[email protected]",
description="A Manubot plugin to revise a manuscript using GPT-3",
@@ -25,7 +25,8 @@
],
python_requires=">=3.10",
install_requires=[
"openai==0.28",
"langchain-core~=0.3.6",
"langchain-openai~=0.2.0",
"pyyaml",
],
classifiers=[
20 changes: 9 additions & 11 deletions tests/test_editor.py
@@ -610,9 +610,7 @@ def test_revise_methods_with_equation_that_was_alrady_revised(
# GPT3CompletionModel(None, None),
],
)
-def test_revise_methods_mutator_epistasis_paper(
-tmp_path, model, filename
-):
def test_revise_methods_mutator_epistasis_paper(tmp_path, model, filename):
"""
This paper has several test cases:
- it ends with multiple blank lines
@@ -635,7 +633,7 @@ def test_revise_methods_mutator_epistasis_paper(tmp_path, model, filename):
)

assert (
r"""
r"""
%%% PARAGRAPH START %%%
Briefly, we identified private single-nucleotide mutations in each BXD that were absent from all other BXDs, as well as from the C57BL/6J and DBA/2J parents.
We required each private variant to be meet the following criteria:
Expand All @@ -651,11 +649,11 @@ def test_revise_methods_mutator_epistasis_paper(
* must occur on a parental haplotype that was inherited by at least one other BXD at the same locus; these other BXDs must be homozygous for the reference allele at the variant site
%%% PARAGRAPH END %%%
""".strip()
in open(tmp_path / filename).read()
)

assert (
r"""
r"""
### Extracting mutation signatures

We used SigProfilerExtractor (v.1.1.21) [@PMID:30371878] to extract mutation signatures from the BXD mutation data.
Expand All @@ -678,11 +676,11 @@ def test_revise_methods_mutator_epistasis_paper(

### Comparing mutation spectra between Mouse Genomes Project strains
""".strip()
in open(tmp_path / filename).read()
)

assert (
r"""
r"""
%%% PARAGRAPH START %%%
We investigated the region implicated by our aggregate mutation spectrum distance approach on chromosome 6 by subsetting the joint-genotyped BXD VCF file (European Nucleotide Archive accession PRJEB45429 [@url:https://www.ebi.ac.uk/ena/browser/view/PRJEB45429]) using `bcftools` [@PMID:33590861].
We defined the candidate interval surrounding the cosine distance peak on chromosome 6 as the 90% bootstrap confidence interval (extending from approximately 95 Mbp to 114 Mbp).
Expand All @@ -693,7 +691,7 @@ def test_revise_methods_mutator_epistasis_paper(
java -Xmx16g -jar /path/to/snpeff/jarfile GRCm38.75 /path/to/bxd/vcf > /path/to/uncompressed/output/vcf
```
""".strip()
in open(tmp_path / filename).read()
)

