[NOID] Fixes #3634: Updated ML procs for Azure OpenAI services (#3850)
vga91 committed Nov 27, 2024
1 parent e571e6f commit 89e6571
Showing 4 changed files with 303 additions and 1 deletion.
64 changes: 63 additions & 1 deletion docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
@@ -4,6 +4,60 @@

NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incur costs on your OpenAI account. You can set the API key globally by defining the `apoc.openai.key` configuration in `apoc.conf`.



All of the following procedures accept the APOC settings below, defined in `apoc.conf` or via Docker environment variables.
.APOC configuration
|===
|key | description | default
| apoc.ml.openai.type | "AZURE" or "OPENAI"; indicates whether the Azure or the OpenAI API is called | "OPENAI"
| apoc.ml.openai.url | the OpenAI endpoint base URL | https://api.openai.com/v1
(or empty string if `apoc.ml.openai.type=AZURE`)
| apoc.ml.azure.api.version | if `apoc.ml.openai.type=AZURE`, the value passed in the `?api-version=` query parameter of the URL | ""
|===


In addition, the following keys can be passed in the configuration map, which is the last parameter of each procedure.
If present, they take precedence over the analogous APOC configs.

.Common configuration parameters
|===
| key | description
| apiType | analogous to `apoc.ml.openai.type` APOC config
| endpoint | analogous to `apoc.ml.openai.url` APOC config
| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config
|===
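
The precedence rule above can be illustrated with a minimal sketch. The class `ConfigPrecedenceSketch` and its `resolve` method are hypothetical names introduced only for this illustration, and the `apocConf` map is a stand-in for the settings read from `apoc.conf`; this is not APOC's implementation.

```java
import java.util.Map;

// Hypothetical helper, not part of APOC: resolves one setting with the
// precedence described above (procedure config, then apoc.conf, then default).
final class ConfigPrecedenceSketch {

    static String resolve(
            Map<String, Object> procConfig,
            String procKey,
            Map<String, String> apocConf,
            String apocKey,
            String defaultValue) {
        Object fromProc = procConfig.get(procKey);
        if (fromProc != null) {
            // A key in the procedure's configuration map wins.
            return fromProc.toString();
        }
        // Otherwise fall back to the apoc.conf stand-in, then to the default.
        return apocConf.getOrDefault(apocKey, defaultValue);
    }

    public static void main(String[] args) {
        Map<String, String> apocConf = Map.of("apoc.ml.openai.type", "AZURE");

        // No procedure-level entry: the apoc.conf value is used.
        System.out.println(resolve(Map.of(), "apiType", apocConf, "apoc.ml.openai.type", "OPENAI"));

        // Procedure-level entry present: it takes precedence.
        System.out.println(
                resolve(Map.of("apiType", "OPENAI"), "apiType", apocConf, "apoc.ml.openai.type", "OPENAI"));
    }
}
```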


Therefore, we can use the following procedures with the OpenAI services provided by Azure,
pointing to the correct endpoints https://learn.microsoft.com/it-it/azure/ai-services/openai/reference[as explained in the documentation].

For example, to call an endpoint such as `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/embeddings?api-version=my-api-version`,
we can pass the following configuration parameter:
```
{endpoint: "https://my-resource.openai.azure.com/openai/deployments/my-deployment-id",
apiVersion: "my-api-version",
apiType: 'AZURE'
}
```

The `/embeddings` portion is appended under the hood.
Similarly, to call an endpoint such as `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/completions?api-version=my-api-version` with `apoc.ml.openai.completion`,
we can pass the same configuration parameter as above,
and the `/completions` portion is appended.

With `apoc.ml.openai.chat` and the same configuration, the `/chat/completions` portion is appended.

Alternatively, we can set the same values in `apoc.conf`:
```
apoc.ml.openai.url=https://my-resource.openai.azure.com/openai/deployments/my-deployment-id
apoc.ml.azure.api.version=my-api-version
apoc.ml.openai.type=AZURE
```



== Generate Embeddings API

The `apoc.ml.openai.embedding` procedure takes a list of text strings and returns one row per string, with the embedding data as a 1536-element vector.
@@ -30,7 +84,15 @@ CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, {}) yield index, text, emb
|name | description
| texts | List of text strings
| apiKey | OpenAI API key
| configuration | optional map for entries like model and other request parameters
| configuration | optional map for entries like model and other request parameters.

We can also pass a custom `endpoint: <MyEndpointKey>` entry (it takes precedence over the `apoc.ml.openai.url` config).
The `<MyEndpointKey>` can be the complete endpoint (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/chat/completions?api-version=my-api-version`),
or contain a `%s` placeholder (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/%s?api-version=my-api-version`), which is replaced with `embeddings`, `chat/completions` or `completions`
when calling `apoc.ml.openai.embedding`, `apoc.ml.openai.chat` or `apoc.ml.openai.completion` respectively.

We can also pass an `authType: AUTH_TYPE` entry, which can be `authType: "BEARER"` (the default), to pass the apiKey as an `Authorization: Bearer $apiKey` header,
or `authType: "API_KEY"`, to pass the apiKey as an `api-key: $apiKey` header entry.
|===
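
As a rough illustration of the `endpoint` and `authType` entries described in the table above, the sketch below fills a `%s` placeholder and picks the authentication header. The class `EndpointAuthSketch` and both method names are hypothetical and exist only for this example; the handler class added in this commit selects the header by API type rather than by an `authType` key, so this should be read as an assumption about the documented behaviour, not as APOC's code.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical illustration only; these names are not APOC's.
final class EndpointAuthSketch {

    // Fills a "%s" placeholder with the procedure-specific path.
    // An endpoint without a placeholder is treated as complete and returned unchanged.
    static String resolveEndpoint(String endpoint, String path) {
        return endpoint.replace("%s", path);
    }

    // Builds the authentication header for the two documented authType values.
    static Map<String, String> authHeaders(String authType, String apiKey) {
        Map<String, String> headers = new LinkedHashMap<>();
        switch (authType) {
            case "BEARER":
                headers.put("Authorization", "Bearer " + apiKey);
                break;
            case "API_KEY":
                headers.put("api-key", apiKey);
                break;
            default:
                throw new IllegalArgumentException("Unsupported authType: " + authType);
        }
        return headers;
    }

    public static void main(String[] args) {
        String template =
                "https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/%s?api-version=my-api-version";
        System.out.println(resolveEndpoint(template, "embeddings"));
        System.out.println(authHeaders("API_KEY", "my-key"));
    }
}
```

`String.replace` is used instead of `String.format` so that other percent sequences in a URL (for example `%20`) are left untouched.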


92 changes: 92 additions & 0 deletions full/src/main/java/apoc/ml/OpenAIRequestHandler.java
@@ -0,0 +1,92 @@
package apoc.ml;

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;

import apoc.ApocConfig;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;

abstract class OpenAIRequestHandler {

    private final String defaultUrl;

    public OpenAIRequestHandler(String defaultUrl) {
        this.defaultUrl = defaultUrl;
    }

    public String getDefaultUrl() {
        return defaultUrl;
    }

    public abstract String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig);

    public abstract void addApiKey(Map<String, Object> headers, String apiKey);

    public String getEndpoint(Map<String, Object> procConfig, ApocConfig apocConfig) {
        return (String) procConfig.getOrDefault(
                ENDPOINT_CONF_KEY,
                apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl())));
    }

    public String getFullUrl(String method, Map<String, Object> procConfig, ApocConfig apocConfig) {
        return Stream.of(getEndpoint(procConfig, apocConfig), method, getApiVersion(procConfig, apocConfig))
                .filter(StringUtils::isNotBlank)
                .collect(Collectors.joining("/"));
    }

    enum Type {
        AZURE(new Azure(null)),
        OPENAI(new OpenAi("https://api.openai.com/v1"));

        private final OpenAIRequestHandler handler;

        Type(OpenAIRequestHandler handler) {
            this.handler = handler;
        }

        public OpenAIRequestHandler get() {
            return handler;
        }
    }

    static class Azure extends OpenAIRequestHandler {

        public Azure(String defaultUrl) {
            super(defaultUrl);
        }

        @Override
        public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) {
            return "?api-version="
                    + configuration.getOrDefault(
                            API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION));
        }

        @Override
        public void addApiKey(Map<String, Object> headers, String apiKey) {
            headers.put("api-key", apiKey);
        }
    }

    static class OpenAi extends OpenAIRequestHandler {

        public OpenAi(String defaultUrl) {
            super(defaultUrl);
        }

        @Override
        public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) {
            return "";
        }

        @Override
        public void addApiKey(Map<String, Object> headers, String apiKey) {
            headers.put("Authorization", "Bearer " + apiKey);
        }
    }
}
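
To make the joining behaviour of `getFullUrl` easier to follow, here is a standalone sketch that reproduces only the stream-and-join step with JDK classes. `FullUrlJoinSketch` and `join` are hypothetical names, and the null-or-blank check stands in for `StringUtils.isNotBlank`; the APOC config lookups are deliberately left out.

```java
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical standalone sketch of the join step only; not the APOC class.
final class FullUrlJoinSketch {

    // Joins the non-blank segments with "/", skipping null or blank ones.
    static String join(String endpoint, String method, String apiVersionSuffix) {
        return Stream.of(endpoint, method, apiVersionSuffix)
                .filter(segment -> segment != null && !segment.isBlank())
                .collect(Collectors.joining("/"));
    }

    public static void main(String[] args) {
        // OpenAI-style call: the empty api-version segment is dropped.
        System.out.println(join("https://api.openai.com/v1", "embeddings", ""));

        // Azure-style call: the "?api-version=..." segment is joined with "/" as well.
        System.out.println(join(
                "https://my-resource.openai.azure.com/openai/deployments/my-deployment-id",
                "embeddings",
                "?api-version=my-api-version"));
    }
}
```

With these inputs the Azure-style result ends in `/embeddings/?api-version=my-api-version`, that is, with a slash before the query string, because the version suffix is treated as one more joined segment.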
100 changes: 100 additions & 0 deletions full/src/test/java/apoc/ml/OpenAIAzureIT.java
@@ -0,0 +1,100 @@
package apoc.ml;

import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
import static apoc.ml.OpenAITestResultUtils.assertCompletion;
import static apoc.util.TestUtil.testCall;
import static org.junit.Assume.assumeNotNull;

import apoc.util.TestUtil;
import java.util.Map;
import java.util.stream.Stream;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Ignore;
import org.junit.Test;
import org.neo4j.test.rule.DbmsRule;
import org.neo4j.test.rule.ImpermanentDbmsRule;

public class OpenAIAzureIT {
    // In Azure, the endpoints can be different
    private static String OPENAI_EMBEDDING_URL;
    private static String OPENAI_CHAT_URL;
    private static String OPENAI_COMPLETION_URL;

    private static String OPENAI_AZURE_API_VERSION;

    private static String OPENAI_KEY;

    @ClassRule
    public static DbmsRule db = new ImpermanentDbmsRule();

    @BeforeClass
    public static void setUp() throws Exception {
        OPENAI_KEY = System.getenv("OPENAI_KEY");
        // Azure OpenAI base URLs
        OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL");
        OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL");
        OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL");

        // Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>`)
        OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION");

        Stream.of(OPENAI_EMBEDDING_URL, OPENAI_CHAT_URL, OPENAI_COMPLETION_URL, OPENAI_AZURE_API_VERSION, OPENAI_KEY)
                .forEach(key -> assumeNotNull("No " + key + " environment configured", key));

        TestUtil.registerProcedure(db, OpenAI.class);
    }

    @Test
    public void embedding() {
        testCall(
                db,
                "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)",
                getParams(OPENAI_EMBEDDING_URL),
                OpenAITestResultUtils::assertEmbeddings);
    }

    @Test
    @Ignore("It returns wrong answers sometimes")
    public void completion() {
        testCall(
                db,
                "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)",
                // the completion procedure is pointed at the completion endpoint
                getParams(OPENAI_COMPLETION_URL),
                (row) -> assertCompletion(row, "gpt-35-turbo"));
    }

    @Test
    public void chatCompletion() {
        testCall(
                db,
                """
                CALL apoc.ml.openai.chat([
                {role:"system", content:"Only answer with a single word"},
                {role:"user", content:"What planet do humans live on?"}
                ], $apiKey, $conf)
                """,
                // the chat procedure is pointed at the chat endpoint
                getParams(OPENAI_CHAT_URL),
                (row) -> assertChatCompletion(row, "gpt-35-turbo"));
    }

    private static Map<String, Object> getParams(String url) {
        return Map.of(
                "apiKey",
                OPENAI_KEY,
                "conf",
                Map.of(
                        ENDPOINT_CONF_KEY,
                        url,
                        API_TYPE_CONF_KEY,
                        OpenAIRequestHandler.Type.AZURE.name(),
                        API_VERSION_CONF_KEY,
                        OPENAI_AZURE_API_VERSION,
                        // on Azure is available only "gpt-35-turbo"
                        "model",
                        "gpt-35-turbo"));
    }
}
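
The test class above skips itself when any required environment variable is unset. As a side note, a check that reports which variable names are missing could be sketched as below. `MissingEnvSketch` and `missing` are hypothetical names, and the lookup function is a parameter so the logic can be exercised without touching the real environment; this is an illustration, not a change proposed by the commit.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical helper for illustration; not part of the commit.
final class MissingEnvSketch {

    // Returns the names whose lookup yields null, preserving the input order.
    static List<String> missing(List<String> names, Function<String, String> lookup) {
        return names.stream()
                .filter(name -> lookup.apply(name) == null)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> required = List.of(
                "OPENAI_KEY",
                "OPENAI_EMBEDDING_URL",
                "OPENAI_CHAT_URL",
                "OPENAI_COMPLETION_URL",
                "OPENAI_AZURE_API_VERSION");
        // System::getenv resolves to System.getenv(String) here.
        System.out.println("Missing environment variables: " + missing(required, System::getenv));
    }
}
```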
48 changes: 48 additions & 0 deletions full/src/test/java/apoc/ml/OpenAITestResultUtils.java
@@ -0,0 +1,48 @@
package apoc.ml;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.List;
import java.util.Map;

public class OpenAITestResultUtils {
    public static void assertEmbeddings(Map<String, Object> row) {
        assertEquals(0L, row.get("index"));
        assertEquals("Some Text", row.get("text"));
        var embedding = (List<Double>) row.get("embedding");
        assertEquals(1536, embedding.size());
    }

    public static void assertCompletion(Map<String, Object> row, String expectedModel) {
        var result = (Map<String, Object>) row.get("value");
        assertTrue(result.get("created") instanceof Number);
        assertTrue(result.containsKey("choices"));
        var finishReason = (String) ((List<Map>) result.get("choices")).get(0).get("finish_reason");
        assertTrue(finishReason.matches("stop|length"));
        String text = (String) ((List<Map>) result.get("choices")).get(0).get("text");
        System.out.println("OpenAI text response for assertCompletion = " + text);
        assertTrue(text != null && !text.isBlank());
        assertTrue(text.toLowerCase().contains("blue"));
        assertTrue(result.containsKey("usage"));
        assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number);
        assertEquals(expectedModel, result.get("model"));
        assertEquals("text_completion", result.get("object"));
    }

    public static void assertChatCompletion(Map<String, Object> row, String modelId) {
        var result = (Map<String, Object>) row.get("value");
        assertTrue(result.get("created") instanceof Number);
        assertTrue(result.containsKey("choices"));

        Map message = ((List<Map<String, Map>>) result.get("choices")).get(0).get("message");
        assertEquals("assistant", message.get("role"));
        String text = (String) message.get("content");
        assertTrue(text != null && !text.isBlank());

        assertTrue(result.containsKey("usage"));
        assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number);
        assertEquals("chat.completion", result.get("object"));
        assertTrue(result.get("model").toString().startsWith(modelId));
    }
}
