Skip to content

Commit

Permalink
feat(opentrons-ai-server): anthropic integration (#16881)
Browse files Browse the repository at this point in the history
<!--
Thanks for taking the time to open a Pull Request (PR)! Please make sure
you've read the "Opening Pull Requests" section of our Contributing
Guide:


https://github.com/Opentrons/opentrons/blob/edge/CONTRIBUTING.md#opening-pull-requests

GitHub provides robust markdown to format your PR. Links, diagrams,
pictures, and videos along with text formatting make it possible to
create a rich and informative PR. For more information on GitHub
markdown, see:


https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax

To ensure your code is reviewed quickly and thoroughly, please fill out
the sections below to the best of your ability!
-->

# Overview
closes AUTH-1050 
This PR introduces the Anthropic model `claude-3.5-sonnet` in
addition to OpenAI.
<!--
Describe your PR at a high level. State acceptance criteria and how this
PR fits into other work. Link issues, PRs, and other relevant resources.
-->

## Test Plan and Hands on Testing
Please interact with the UI.
<!--
Describe your testing of the PR. Emphasize testing not reflected in the
code. Attach protocols, logs, screenshots and any other assets that
support your testing.
-->

## Changelog
- integrated anthropic model and set it as default
<!--
List changes introduced by this PR considering future developers and the
end user. Give careful thought and clear documentation to breaking
changes.
-->

## Review requests
**Review scripts:** 
- `opentrons-ai-server/api/domain/anthropic_predict.py`
- `opentrons-ai-server/api/domain/config_anthropic.py`
- `opentrons-ai-server/api/handler/fast.py`

**UI**
- Please interact with UI and create protocols.

- not required to check `storage/docs` unless you have time

<!--
- What do you need from reviewers to feel confident this PR is ready to
merge?
- Ask questions.
-->

## Risk assessment
Low
<!--
- Indicate the level of attention this PR needs.
- Provide context to guide reviewers.
- Discuss trade-offs, coupling, and side effects.
- Look for the possibility, even if you think it's small, that your
change may affect some other part of the system.
- For instance, changing return tip behavior may also change the
behavior of labware calibration.
- How do your unit tests and on hands on testing mitigate this PR's
risks and the risk of future regressions?
- Especially in high risk PRs, explain how you know your testing is
enough.
-->
  • Loading branch information
Elyorcv authored Nov 19, 2024
1 parent 1c4385c commit b7a5540
Show file tree
Hide file tree
Showing 16 changed files with 7,595 additions and 92 deletions.
1 change: 1 addition & 0 deletions opentrons-ai-server/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ asgi-correlation-id = "==4.3.3"
gspread = "==6.1.4"
google-auth = "==2.36.0"
google-auth-oauthlib = "==1.2.1"
anthropic = "*"

[dev-packages]
docker = "==7.1.0"
Expand Down
256 changes: 172 additions & 84 deletions opentrons-ai-server/Pipfile.lock

Large diffs are not rendered by default.

206 changes: 206 additions & 0 deletions opentrons-ai-server/api/domain/anthropic_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import uuid
from pathlib import Path
from typing import Any, Dict, List

import requests
import structlog
from anthropic import Anthropic
from anthropic.types import Message, MessageParam
from ddtrace import tracer

from api.domain.config_anthropic import DOCUMENTS, PROMPT, SYSTEM_PROMPT
from api.settings import Settings

settings: Settings = Settings()
logger = structlog.stdlib.get_logger(settings.logger_name)
ROOT_PATH: Path = Path(Path(__file__)).parent.parent.parent


class AnthropicPredict:
    """Generates Opentrons protocols using the Anthropic Messages API.

    Maintains a running conversation in ``self._messages``, seeded with the
    reference documents as a single cached user turn (prompt caching). The
    model may call the ``simulate_protocol`` tool, which POSTs the generated
    protocol to a remote simulator for validation.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings: Settings = settings
        self.client: Anthropic = Anthropic(api_key=settings.anthropic_api_key.get_secret_value())
        self.model_name: str = settings.anthropic_model_name
        self.system_prompt: str = SYSTEM_PROMPT
        self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
        # Seed the conversation once; reset() rebuilds this same initial state.
        self._messages: List[MessageParam] = self._initial_messages()
        self.tools: List[Dict[str, Any]] = [
            {
                "name": "simulate_protocol",
                "description": "Simulates the python protocol on user input. Returned value is text indicating if protocol is successful.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "protocol": {"type": "string", "description": "protocol in python for simulation"},
                    },
                    "required": ["protocol"],
                },
            }
        ]

    def _initial_messages(self) -> List[MessageParam]:
        """Build the seed conversation: the reference docs as one cached user turn."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}}  # type: ignore
                ],
            }
        ]

    @tracer.wrap()
    def get_docs(self) -> str:
        """
        Processes documents from a directory and returns their content wrapped in XML tags.
        Each document is wrapped in <document> tags with metadata subtags.
        Returns:
            str: XML-formatted string containing all documents and their metadata
        """
        logger.info("Getting docs", extra={"path": str(self.path_docs)})
        xml_output = ["<documents>"]
        # Sort for a deterministic document order: the seed message is cached
        # with cache_control "ephemeral", and a cache hit requires the content
        # to be byte-identical across runs.
        for file_path in sorted(self.path_docs.iterdir()):
            if not file_path.is_file():
                # Skip subdirectories and other non-file entries.
                continue
            try:
                content = file_path.read_text(encoding="utf-8")
                document_xml = [
                    "<document>",
                    f" <source>{file_path.name}</source>",
                    " <document_content>",
                    f" {content}",
                    " </document_content>",
                    "</document>",
                ]
                xml_output.extend(document_xml)

            except Exception as e:
                logger.error("Error processing file", extra={"file": file_path.name, "error": str(e)})
                continue

        xml_output.append("</documents>")
        return "\n".join(xml_output)

    @tracer.wrap()
    def generate_message(self, max_tokens: int = 4096) -> Message:
        """Send the accumulated conversation to the model and return its reply.

        Args:
            max_tokens: Upper bound on generated tokens for this call.
        Returns:
            The raw Anthropic ``Message`` response.
        """
        response = self.client.messages.create(
            model=self.model_name,
            system=self.system_prompt,
            max_tokens=max_tokens,
            messages=self._messages,
            tools=self.tools,  # type: ignore
            # Opt in to the prompt-caching beta so the seed docs turn is cached.
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
        )

        logger.info(
            "Token usage",
            extra={
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
                # Cache fields only exist when the caching beta is active.
                "cache_read": getattr(response.usage, "cache_read_input_tokens", "---"),
                "cache_create": getattr(response.usage, "cache_creation_input_tokens", "---"),
            },
        )
        return response

    @tracer.wrap()
    def predict(self, prompt: str) -> str | None:
        """Run one user turn, handling at most one round of tool use.

        Appends the user prompt to the conversation, lets the model respond,
        and — if the model requested ``simulate_protocol`` — executes the tool,
        feeds the result back, and returns the follow-up text.
        Returns:
            The model's text reply, or None on any error / unexpected response.
        """
        try:
            self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
            response = self.generate_message()
            if response.content[-1].type == "tool_use":
                tool_use = response.content[-1]
                self._messages.append({"role": "assistant", "content": response.content})
                result = self.handle_tool_use(tool_use.name, tool_use.input)  # type: ignore
                self._messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": tool_use.id,
                                "content": result,
                            }
                        ],
                    }
                )
                follow_up = self.generate_message()
                response_text = follow_up.content[0].text  # type: ignore
                self._messages.append({"role": "assistant", "content": response_text})
                return response_text

            elif response.content[0].type == "text":
                response_text = response.content[0].text
                self._messages.append({"role": "assistant", "content": response_text})
                return response_text

            logger.error("Unexpected response type")
            return None
        except IndexError as e:
            # response.content was empty or shaped unexpectedly.
            logger.error("Invalid response format", extra={"error": str(e)})
            return None
        except Exception as e:
            logger.error("Error in predict method", extra={"error": str(e)})
            return None

    @tracer.wrap()
    def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
        """Dispatch a model tool call to its implementation.

        Raises:
            ValueError: If the model requested a tool this class does not define.
        """
        if func_name == "simulate_protocol":
            results = self.simulate_protocol(**func_params)
            return results

        logger.error("Unknown tool", extra={"tool": func_name})
        raise ValueError(f"Unknown tool: {func_name}")

    @tracer.wrap()
    def reset(self) -> None:
        """Discard the conversation and restore the initial docs-only state."""
        self._messages = self._initial_messages()

    @tracer.wrap()
    def simulate_protocol(self, protocol: str) -> str:
        """POST the protocol to the remote simulator and return its verdict as text.

        Args:
            protocol: Python protocol source code to simulate.
        Returns:
            The simulator's run status, its error message, or an error string.
        """
        url = "https://Opentrons-simulator.hf.space/protocol"
        protocol_name = str(uuid.uuid4()) + ".py"
        data = {"name": protocol_name, "content": protocol}
        # Use the instance settings (was the module-level global, which would
        # ignore any Settings object injected via the constructor).
        hf_token: str = self.settings.huggingface_api_key.get_secret_value()
        headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(hf_token)}
        # Timeout so a stuck simulator can't hang the request forever;
        # simulation itself can be slow, hence the generous read timeout.
        response = requests.post(url, json=data, headers=headers, timeout=(10, 300))

        if response.status_code != 200:
            logger.error("Simulation request failed", extra={"status": response.status_code, "error": response.text})
            return f"Error: {response.text}"

        response_data = response.json()
        if "error_message" in response_data:
            logger.error("Simulation error", extra={"error": response_data["error_message"]})
            return str(response_data["error_message"])
        elif "protocol_name" in response_data:
            return str(response_data["run_status"])
        else:
            logger.error("Unexpected response", extra={"response": response_data})
            return "Unexpected response"


def main() -> None:
    """Intended for testing this class locally."""
    import sys
    from pathlib import Path

    # Make the project root importable when this module is run as a script.
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

    from rich import print
    from rich.prompt import Prompt

    llm = AnthropicPredict(Settings())
    user_prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
    print(llm.predict(user_prompt))


if __name__ == "__main__":
    main()
Loading

0 comments on commit b7a5540

Please sign in to comment.