# feat(opentrons-ai-server): anthropic integration (#16881)
# Overview

Closes AUTH-1050.

This PR adds the Anthropic model `claude-3.5-sonnet` as an alternative to OpenAI.

## Test Plan and Hands on Testing

Please interact with the UI and verify that protocols are generated as expected.

## Changelog

- Integrated the Anthropic model and set it as the default.

## Review requests

**Review scripts:**

- `opentrons-ai-server/api/domain/anthropic_predict.py`
- `opentrons-ai-server/api/domain/config_anthropic.py`
- `opentrons-ai-server/api/handler/fast.py`

**UI**

- Please interact with the UI and create protocols.
- Checking `storage/docs` is not required, unless you have time.

## Risk assessment

Low
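For reviewers who want to exercise the integration outside the UI, a minimal sketch along these lines should work (`AnthropicPredict`, `Settings`, and `predict` come from the diff below; the exact module path and the environment variables backing `Settings` are assumptions — credentials for Anthropic and Hugging Face must be configured however `api/settings.py` expects them):

```python
from api.domain.anthropic_predict import AnthropicPredict
from api.settings import Settings

# Settings pulls API keys from the environment; AnthropicPredict seeds the
# conversation with the reference docs and exposes a single predict() call.
llm = AnthropicPredict(Settings())
print(llm.predict("Transfer 50 uL from well A1 to well B1 on a Flex."))
```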
Showing 16 changed files with 7,595 additions and 92 deletions.
`opentrons-ai-server/api/domain/anthropic_predict.py` (new file):

```python
import uuid
from pathlib import Path
from typing import Any, Dict, List

import requests
import structlog
from anthropic import Anthropic
from anthropic.types import Message, MessageParam
from ddtrace import tracer

from api.domain.config_anthropic import DOCUMENTS, PROMPT, SYSTEM_PROMPT
from api.settings import Settings

settings: Settings = Settings()
logger = structlog.stdlib.get_logger(settings.logger_name)
ROOT_PATH: Path = Path(__file__).parent.parent.parent

class AnthropicPredict:
    def __init__(self, settings: Settings) -> None:
        self.settings: Settings = settings
        self.client: Anthropic = Anthropic(api_key=settings.anthropic_api_key.get_secret_value())
        self.model_name: str = settings.anthropic_model_name
        self.system_prompt: str = SYSTEM_PROMPT
        self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
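        # Seed the conversation with the reference docs as a single user turn;
        # the "ephemeral" cache_control marker lets the prompt-caching beta
        # reuse this large block across requests instead of re-sending it.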
        self._messages: List[MessageParam] = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}}  # type: ignore
                ],
            }
        ]
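        # Single tool exposed to the model: a remote protocol simulator that
        # reports whether the generated Python protocol runs successfully.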
        self.tools: List[Dict[str, Any]] = [
            {
                "name": "simulate_protocol",
                "description": "Simulates the Python protocol from user input. Returns text indicating whether the protocol simulated successfully.",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "protocol": {"type": "string", "description": "Python protocol to simulate"},
                    },
                    "required": ["protocol"],
                },
            }
        ]

    @tracer.wrap()
    def get_docs(self) -> str:
        """
        Processes documents from a directory and returns their content wrapped in XML tags.
        Each document is wrapped in <document> tags with metadata subtags.

        Returns:
            str: XML-formatted string containing all documents and their metadata
        """
        logger.info("Getting docs", extra={"path": str(self.path_docs)})
        xml_output = ["<documents>"]
        for file_path in self.path_docs.iterdir():
            try:
                content = file_path.read_text(encoding="utf-8")
                document_xml = [
                    "<document>",
                    f"  <source>{file_path.name}</source>",
                    "  <document_content>",
                    f"  {content}",
                    "  </document_content>",
                    "</document>",
                ]
                xml_output.extend(document_xml)
            except Exception as e:
                logger.error("Error processing file", extra={"file": file_path.name, "error": str(e)})
                continue

        xml_output.append("</documents>")
        return "\n".join(xml_output)

    @tracer.wrap()
    def generate_message(self, max_tokens: int = 4096) -> Message:
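        # The prompt-caching beta header below lets the API reuse the docs
        # turn marked with cache_control, avoiding re-billing those input tokens.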
        response = self.client.messages.create(
            model=self.model_name,
            system=self.system_prompt,
            max_tokens=max_tokens,
            messages=self._messages,
            tools=self.tools,  # type: ignore
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
        )

        logger.info(
            "Token usage",
            extra={
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
                "cache_read": getattr(response.usage, "cache_read_input_tokens", "---"),
                "cache_create": getattr(response.usage, "cache_creation_input_tokens", "---"),
            },
        )
        return response

    @tracer.wrap()
    def predict(self, prompt: str) -> str | None:
        try:
            self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
            response = self.generate_message()
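            # A trailing tool_use block means the model wants the protocol
            # simulated; run the tool, return its result, then request a
            # follow-up completion that incorporates the simulation outcome.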
            if response.content[-1].type == "tool_use":
                tool_use = response.content[-1]
                self._messages.append({"role": "assistant", "content": response.content})
                result = self.handle_tool_use(tool_use.name, tool_use.input)  # type: ignore
                self._messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": tool_use.id,
                                "content": result,
                            }
                        ],
                    }
                )
                follow_up = self.generate_message()
                response_text = follow_up.content[0].text  # type: ignore
                self._messages.append({"role": "assistant", "content": response_text})
                return response_text

            elif response.content[0].type == "text":
                response_text = response.content[0].text
                self._messages.append({"role": "assistant", "content": response_text})
                return response_text

            logger.error("Unexpected response type")
            return None
        except IndexError as e:
            logger.error("Invalid response format", extra={"error": str(e)})
            return None
        except Exception as e:
            logger.error("Error in predict method", extra={"error": str(e)})
            return None

    @tracer.wrap()
    def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
        if func_name == "simulate_protocol":
            results = self.simulate_protocol(**func_params)
            return results

        logger.error("Unknown tool", extra={"tool": func_name})
        raise ValueError(f"Unknown tool: {func_name}")

    @tracer.wrap()
    def reset(self) -> None:
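        # Drop the running conversation and re-seed it with the cached docs turn.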
        self._messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}}  # type: ignore
                ],
            }
        ]

    @tracer.wrap()
    def simulate_protocol(self, protocol: str) -> str:
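        # POST the protocol to the hosted simulator under a random file name
        # and translate its JSON reply into a plain-text result for the model.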
        url = "https://Opentrons-simulator.hf.space/protocol"
        protocol_name = str(uuid.uuid4()) + ".py"
        data = {"name": protocol_name, "content": protocol}
        hf_token: str = settings.huggingface_api_key.get_secret_value()
        headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(hf_token)}
        response = requests.post(url, json=data, headers=headers)

        if response.status_code != 200:
            logger.error("Simulation request failed", extra={"status": response.status_code, "error": response.text})
            return f"Error: {response.text}"

        response_data = response.json()
        if "error_message" in response_data:
            logger.error("Simulation error", extra={"error": response_data["error_message"]})
            return str(response_data["error_message"])
        elif "protocol_name" in response_data:
            return str(response_data["run_status"])
        else:
            logger.error("Unexpected response", extra={"response": response_data})
            return "Unexpected response"

def main() -> None:
    """Intended for testing this class locally."""
    import sys
    from pathlib import Path

    # Add project root to Python path
    root_dir = Path(__file__).parent.parent.parent
    sys.path.insert(0, str(root_dir))

    from rich import print
    from rich.prompt import Prompt

    settings = Settings()
    llm = AnthropicPredict(settings)
    prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
    completion = llm.predict(prompt)
    print(completion)

if __name__ == "__main__":
    main()
```