Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Cohere notebooks to use API V2 #3473

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions sdk/python/foundation-models/cohere/cohere-aisearch-rag.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"outputs": [],
"source": [
"# Set up the embedding model to be used in the vector index\n",
"co_embed = cohere.Client(\n",
"co_embed = cohere.ClientV2(\n",
" base_url=azure_cohere_embed_endpoint, api_key=azure_cohere_embed_key\n",
")"
]
Expand Down Expand Up @@ -212,11 +212,13 @@
"outputs": [],
"source": [
"# embed each of the descriptions\n",
"# you will notice that Cohere has a field called \"input_type\" which can be set to \"search_document\", \"search_query\", \"classification\", or \"clustering\" depedning on the text you are embedding\n",
"for doc in docs_to_index:\n",
" doc[\"descriptionVector\"] = co_embed.embed(\n",
" texts=[doc[\"description\"]], input_type=\"search_document\"\n",
" ).embeddings[0]"
" model=\"azureai\",\n",
" texts=[doc[\"description\"]],\n",
" input_type=\"search_document\", # the type of content being embedded. Can be one of \"search_document, \"search_query\", \"classification\", \"clustering\", or \"image\"\n",
" embedding_types=[\"float\"], # the format of the embeddings. Can be one or more of \"float\", \"int8\", \"uint8\", \"binary\"\n",
" ).embeddings[\"float\"][0]"
]
},
{
Expand Down Expand Up @@ -329,7 +331,10 @@
" list: A list of search results.\n",
" \"\"\"\n",
" query_embedding = co_embed.embed(\n",
" texts=[query], input_type=\"search_query\"\n",
" model=\"azureai\",\n",
" texts=[query],\n",
" input_type=\"search_query\", # the type of content being embedded. Can be one of \"search_document, \"search_query\", \"classification\", \"clustering\", or \"image\"\n",
" embedding_types=[\"float\"], # the format of the embeddings. Can be one or more of \"float\", \"int8\", \"uint8\", \"binary\"\n",
" ).embeddings[0]\n",
"\n",
" # Azure AI search requires a vector query\n",
Expand Down Expand Up @@ -375,7 +380,7 @@
"metadata": {},
"outputs": [],
"source": [
"co_chat = cohere.Client(\n",
"co_chat = cohere.ClientV2(\n",
" base_url=azure_cohere_command_endpoint, api_key=azure_cohere_command_key\n",
")"
]
Expand All @@ -401,14 +406,24 @@
" # select category, description, and hotelName from the search results\n",
" documents = [\n",
" {\n",
" \"category\": result[\"category\"],\n",
" \"description\": result[\"description\"],\n",
" \"hotelName\": result[\"hotelName\"],\n",
" \"id\": f\"{index}\", # we set the id to the document index\n",
" \"data\": {\n",
" \"category\": result[\"category\"],\n",
" \"description\": result[\"description\"],\n",
" \"hotelName\": result[\"hotelName\"],\n",
" }\n",
" }\n",
" for result in search_results\n",
" for index, result in enumerate(search_results)\n",
" ]\n",
"\n",
" response = co_chat.chat(message=question, documents=documents)\n",
" response = co_chat.chat(\n",
" model=\"azureai\",\n",
" messages=[{\n",
" \"role\": \"user\",\n",
" \"content\": question\n",
" }],\n",
" documents=documents,\n",
" )\n",
"\n",
" return response"
]
Expand All @@ -428,7 +443,7 @@
"metadata": {},
"source": [
"## Clean the results\n",
"We can also pull the citations and text response from the response"
"We can also pull the citations and text answer from the response"
]
},
{
Expand All @@ -438,18 +453,21 @@
"outputs": [],
"source": [
"def pretty_text(text, citations):\n",
" # Sort citations by start position to avoid issues when altering text indices\n",
" sorted_citations = sorted(citations, key=lambda x: x.start, reverse=True)\n",
"\n",
" # Process each citation in reverse order to prevent index shifting\n",
" for citation in sorted_citations:\n",
" doc_ids_str = \", \".join(citation.document_ids)\n",
" citation_text = text[citation.start : citation.end]\n",
" text_with_citations = \"\"\n",
" text_start_index = 0\n",
" for citation in citations:\n",
" doc_ids_str = \", \".join([source.id for source in citation.sources])\n",
" citated_text = text[citation.start : citation.end]\n",
" # Bold the citation text and add document ids as superscript\n",
" new_text = f\"**{citation_text}**^({doc_ids_str})\"\n",
" text = text[: citation.start] + new_text + text[citation.end :]\n",
" cited_text_with_ids = f\"**{citated_text}**^({doc_ids_str})\"\n",
" text_with_citations = text[text_start_index : citation.start] + cited_text_with_ids\n",
" text_start_index = citation.end\n",
"\n",
" text_with_citations += text[text_start_index:]\n",
"\n",
" return text"
" return text_with_citations"
]
},
{
Expand All @@ -458,7 +476,7 @@
"metadata": {},
"outputs": [],
"source": [
"pretty_text_output = pretty_text(res.text, res.citations)"
"pretty_text_output = pretty_text(res.message.content[0].text, res.message.citations)"
]
},
{
Expand Down
12 changes: 8 additions & 4 deletions sdk/python/foundation-models/cohere/cohere-cmdR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
},
"outputs": [],
"source": [
"co = cohere.Client(\n",
"co = cohere.ClientV2(\n",
" base_url=\"https://<endpoint>.<region>.inference.ai.azure.com/v1\", api_key=\"<key>\"\n",
")"
]
Expand All @@ -99,8 +99,12 @@
},
"outputs": [],
"source": [
"chat_response = co.chat(\n",
" message=\"Who is the most renowned French painter? Provide a short answer.\"\n",
"res = co.chat(\n",
" model=\"azureai\"\n",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing comma here.

" messages=[{\n",
" \"role\": \"user\",\n",
" \"content\": \"Who is the most renowned French painter? Provide a short answer.\"\n",
" }]\n",
")"
]
},
Expand All @@ -119,7 +123,7 @@
},
"outputs": [],
"source": [
"print(chat_response.text)"
"print(res.message.content[0].text)"
]
},
{
Expand Down
13 changes: 10 additions & 3 deletions sdk/python/foundation-models/cohere/cohere-embed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
},
"outputs": [],
"source": [
"co = cohere.Client(\n",
"co = cohere.ClientV2(\n",
" base_url=\"https://<endpoint>.<region>.inference.ai.azure.com/v1\", api_key=\"<key>\"\n",
")"
]
Expand All @@ -100,8 +100,10 @@
"outputs": [],
"source": [
"response = co.embed(\n",
" model=\"azureai\",\n",
" texts=[\"Who is the most renowned French painter? Provide a short answer.\"],\n",
" input_type=\"query\",\n",
" input_type=\"search_query\", # the type of content being embedded. Can be one of \"search_document, \"search_query\", \"classification\", \"clustering\", or \"image\"\n",
" embedding_types=[\"float\"], # the format of the embeddings. Can be one or more of \"float\", \"int8\", \"uint8\", \"binary\"\n",
")"
]
},
Expand Down Expand Up @@ -131,7 +133,12 @@
" # Convert the base64 bytes to a string\n",
" base64_string = base64_encoded_data.decode(\"utf-8\")\n",
"\n",
"co.embed(images=[base64_string], input_type=\"image\")"
"co.embed(\n",
" model=\"azureai\",\n",
" images=[base64_string],\n",
" input_type=\"image\", # the type of content being embedded. Can be one of \"search_document, \"search_query\", \"classification\", \"clustering\", or \"image\"\n",
" embedding_types=[\"float\"], # the format of the embeddings. Can be one or more of \"float\", \"int8\", \"uint8\", \"binary\"\n",
")"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
},
"outputs": [],
"source": [
"co = cohere.Client(\n",
"co = cohere.ClientV2(\n",
" base_url=\"https://<endpoint>.<region>.inference.ai.azure.com/v1\", api_key=\"<key>\"\n",
")"
]
Expand All @@ -99,6 +99,8 @@
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"documents = [\n",
" {\n",
" \"Title\": \"Incorrect Password\",\n",
Expand Down Expand Up @@ -139,9 +141,9 @@
"]\n",
"\n",
"response = co.rerank(\n",
" documents=documents,\n",
" model=\"azureai\"\n",
" documents=[yaml.dump(doc, sort_keys=False) for doc in documents],\n",
" query=\"What emails have been about returning items?\",\n",
" rank_fields=[\"Title\", \"Content\"],\n",
" top_n=5,\n",
")"
]
Expand Down
29 changes: 5 additions & 24 deletions sdk/python/foundation-models/cohere/rerank-webrequests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,6 @@
"* The payload or data, which is your prompt detail and model hyper parameters."
]
},
{
"source": [
"!curl --request POST \\\n",
" --url https://your-endpoint.inference.ai/azure/com/v1/rerank \\\n",
" --header 'Authorization: Bearer your-auth-key' \\\n",
" --header 'Cohere-Version: 2022-12-06' \\\n",
" --header 'Content-Type: application/json' \\\n",
" --data '{\"query\": \"What is the capital of the United States?\", \"model\":\"rerank-english-v3.0\",\"return_documents\": true, \"documents\" : [\"Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.\",\"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.\",\"Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.\",\"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.\",\"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.\"],\"top_n\": 3}'"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -39,40 +29,31 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'{\"id\":\"3750ff1f-a1d8-4824-89e4-3cd7b1eb6447\",\"results\":[{\"document\":{\"text\":\"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.\"},\"index\":3,\"relevance_score\":0.9990564},{\"document\":{\"text\":\"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.\"},\"index\":4,\"relevance_score\":0.7516481},{\"document\":{\"text\":\"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.\"},\"index\":1,\"relevance_score\":0.08882029}],\"meta\":{\"api_version\":{\"version\":\"unspecified\",\"is_deprecated\":true},\"warnings\":[\"Please set an API version, for more information please refer to https://docs.cohere.com/versioning-reference\",\"Version is deprecated, for more information please refer to https://docs.cohere.com/versioning-reference\"],\"billed_units\":{\"search_units\":1}}}'\n"
]
}
],
"outputs": [],
"source": [
"import urllib.request\n",
"import json\n",
"\n",
"# Configure payload data sending to API endpoint\n",
"data = {\n",
" \"model\": \"azureai\",\n",
" \"query\": \"What is the capital of the United States?\",\n",
" \"model\": \"rerank-english-v3.0\",\n",
" \"return_documents\": True,\n",
" \"documents\": [\n",
" \"Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.\",\n",
" \"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.\",\n",
" \"Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.\",\n",
" \"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.\",\n",
" \"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.\",\n",
" \"Capital punishment has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.\",\n",
" ],\n",
" \"top_n\": 3,\n",
"}\n",
"\n",
"body = str.encode(json.dumps(data))\n",
"\n",
"# Replace the url with your API endpoint\n",
"url = \"https://your-endpoint.inference.ai/azure/com/v1/rerank\"\n",
"url = \"https://your-endpoint.inference.ai/azure/com/v2/rerank\"\n",
"\n",
"# Replace this with the key for the endpoint\n",
"api_key = \"bearer <your-api-key>\"\n",
Expand Down