LangChain Integration #60
base: main
```diff
@@ -5,7 +5,8 @@
 import time
 import json

-import openai
+from langchain_openai import OpenAI, ChatOpenAI
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

 from manubot_ai_editor import env_vars
```
```diff
@@ -141,12 +142,13 @@ def __init__(
         super().__init__()

-        # make sure the OpenAI API key is set
-        openai.api_key = openai_api_key
-
-        if openai.api_key is None:
-            openai.api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)
-
-        if openai.api_key is None or openai.api_key.strip() == "":
+        if openai_api_key is None:
+            # attempt to get the OpenAI API key from the environment, since one
+            # wasn't specified as an argument
+            openai_api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)
+
+        # if it's *still* not set, bail
+        if openai_api_key is None or openai_api_key.strip() == "":
             raise ValueError(
                 f"OpenAI API key not found. Please provide it as parameter "
                 f"or set it as an the environment variable "
```
```diff
@@ -253,6 +255,22 @@ def __init__(

         self.several_spaces_pattern = re.compile(r"\s+")

+        if self.endpoint == "edits":
+            # FIXME: what's the "edits" equivalent in langchain?
+            client_cls = OpenAI
+        elif self.endpoint == "chat":
+            client_cls = ChatOpenAI
+        else:
+            client_cls = OpenAI
+
+        # construct the OpenAI client after all the rest of
+        # the settings above have been processed
+        self.client = client_cls(
+            api_key=openai_api_key,
+            **self.model_parameters,
+        )
+
     def get_prompt(
         self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None
     ) -> str | tuple[str, str]:
```

Review comments on this hunk:

- Reviewer (on the `FIXME` line): Consider moving this […]
- Author: Good point, and fair that adding "FIXME"s at all runs the risk of them being introduced into merged code. My intent here was to get this FIXME figured out within the scope of this PR, which is why I didn't create an issue for it, but I'll think more on not adding FIXMEs and instead communicating questions some other way (review comments, perhaps?)
- Reviewer (on the `edits` branch): I don't think we need to take care of this anymore. Before, there were "completion" and "edits" endpoints, but now we only have a "chat" endpoint, I believe. Let's research a little bit, but I think we only need the […]
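The endpoint-to-client-class selection above is small enough to isolate. The sketch below uses empty stand-in classes in place of `langchain_openai.OpenAI` and `langchain_openai.ChatOpenAI` so the logic can be shown without the dependency; the fallback for `"edits"` mirrors the diff, where that legacy endpoint has no direct langchain equivalent yet.

```python
# Stand-ins for langchain_openai.OpenAI and langchain_openai.ChatOpenAI,
# used here only so the selection logic runs without the real dependency.
class OpenAI:
    pass


class ChatOpenAI:
    pass


def select_client_cls(endpoint: str) -> type:
    """Pick the client class for an endpoint name.

    "chat" maps to the chat-model client; anything else (including the
    legacy "edits" endpoint) falls back to the plain completion client.
    """
    if endpoint == "chat":
        return ChatOpenAI
    return OpenAI
```

Collapsing the original three-branch `if`/`elif`/`else` into two branches preserves behavior, since the `"edits"` and `else` branches both chose `OpenAI`.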
```diff
@@ -526,17 +544,48 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv
                 flush=True,
             )

-            if self.endpoint == "edits":
-                completions = openai.Edit.create(**params)
-            elif self.endpoint == "chat":
-                completions = openai.ChatCompletion.create(**params)
-            else:
-                completions = openai.Completion.create(**params)
+            # FIXME: 'params' contains a lot of fields that we're not
+            # currently passing to the langchain client. i need to figure
+            # out where they're supposed to be given, e.g. in the client
+            # init or with each request.
+
+            # map the prompt to langchain's prompt types, based on what
+            # kind of endpoint we're using
+            if "messages" in params:
+                # map the messages to langchain's message types
+                # based on the 'role' field
+                prompts = [
+                    HumanMessage(content=msg["content"])
+                    if msg["role"] == "user" else
+                    SystemMessage(content=msg["content"])
+                    for msg in params["messages"]
+                ]
+            elif "instruction" in params:
+                # since we don't know how to use the edits endpoint, we'll just
+                # concatenate the instruction and input and use the regular
+                # completion endpoint
+                # FIXME: there's probably a langchain equivalent for
+                # "edits", so we should change this to use that
+                prompts = [
+                    HumanMessage(content=params["instruction"]),
+                    HumanMessage(content=params["input"]),
+                ]
+            elif "prompt" in params:
+                prompts = [HumanMessage(content=params["prompt"])]
+
+            response = self.client.invoke(prompts)
+
+            if isinstance(response, BaseMessage):
+                message = response.content.strip()
+            else:
+                message = response.strip()
+
+            # FIXME: the prior code retrieved the first of the 'choices'
+            # response from the openai client. now, we only get one
+            # response from the langchain client, but i should check
+            # if that's really how langchain works or if there is a way
+            # to get multiple 'choices' back from the backend.

-            if self.endpoint == "chat":
-                message = completions.choices[0].message.content.strip()
-            else:
-                message = completions.choices[0].text.strip()
         except Exception as e:
             error_message = str(e)
             print(f"Error: {error_message}")
```

Review comments on this hunk:

- Reviewer (on the first `FIXME`): What are those fields in […]
- Author: Looking at it again, "a lot" is an overstatement, sorry. On top of the […] Correct me if I'm wrong, but since […] I'll go ahead and make the other changes, though.
- Reviewer: If I didn't forget what the code does, the only field that should go in each request/invoke (instead of using them to initialize the client) is […]
- Author: Right, after I made the comment above I discovered that […]
- Reviewer (on the `"messages" in params` branch): A bit outside the PR scope, but adding this as it's a fresh read of the code and I'm less familiar with how […]
- Author: I'm thinking we'll do a comprehensive review of the docstrings for the PR that addresses issue #68, but in this PR I've attempted to add some documentation to the […]
- Reviewer (on the list comprehension): This might need formatting corrections applied via Black (I tested using the existing […]
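The role-based mapping in the `"messages"` branch above can be exercised on its own. This sketch substitutes minimal dataclasses for `langchain_core.messages.HumanMessage` and `SystemMessage` so it runs without the langchain dependency; the comprehension itself matches the diff's logic.

```python
from dataclasses import dataclass


# Stand-ins for langchain_core.messages.HumanMessage / SystemMessage, so the
# role-mapping logic can be run without installing langchain.
@dataclass
class HumanMessage:
    content: str


@dataclass
class SystemMessage:
    content: str


def to_langchain_messages(messages: list[dict]) -> list:
    """Map OpenAI-style {'role': ..., 'content': ...} dicts to message objects.

    A 'user' role becomes a HumanMessage; any other role (e.g. 'system')
    becomes a SystemMessage, mirroring the two-way mapping in the PR.
    """
    return [
        HumanMessage(content=m["content"])
        if m["role"] == "user"
        else SystemMessage(content=m["content"])
        for m in messages
    ]
```

One caveat the two-way mapping glosses over: an `'assistant'` role would also be mapped to `SystemMessage` here, which may or may not be the intended behavior.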
```diff
@@ -11,7 +11,7 @@
 setuptools.setup(
     name="manubot-ai-editor",
-    version="0.5.2",
+    version="0.5.3",
     author="Milton Pividori",
     author_email="[email protected]",
     description="A Manubot plugin to revise a manuscript using GPT-3",
```

```diff
@@ -25,7 +25,7 @@
     ],
     python_requires=">=3.10",
     install_requires=[
-        "openai==0.28",
+        "langchain-openai==0.2.0",
         "pyyaml",
     ],
     classifiers=[
```

(falquaddoomi marked a conversation on the dependency change as resolved.)
Review comments:

- Reviewer: Consider documenting class attributes in the docstring for the class to help identify what functionality they're associated with. As I read through this I wondered "what does `self.endpoint` do; how might it matter later?" and couldn't find much human-readable material on this topic. It could be that I'm missing fundamental common knowledge about how this works; if so, please don't hesitate to link to the appropriate location.
- Author: Hey, thanks for pointing this out; I've created an issue to address filling these gaps in the documentation, #68.
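The reviewer's suggestion above (documenting class attributes in the class docstring) could look something like this. The class name and attribute descriptions here are hypothetical illustrations, not the project's actual code:

```python
class RevisionModel:
    """Revises manuscript paragraphs via an LLM backend.

    Attributes:
        endpoint: Which API style to use; "chat" selects the chat-completions
            client, while anything else falls back to the plain completion
            client. (Hypothetical description for illustration.)
        client: The constructed langchain client used to issue each request.
        several_spaces_pattern: Compiled regex used to collapse runs of
            whitespace in revised text.
    """

    def __init__(self, endpoint: str = "chat"):
        self.endpoint = endpoint
```

Keeping the attribute documentation in the class docstring (Google style, as sketched here) lets tools like Sphinx's napoleon extension render it, and answers "what does `self.endpoint` do?" at the point of definition.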