-
Notifications
You must be signed in to change notification settings - Fork 493
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
4 changed files
with
303 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package apoc.ml; | ||
|
||
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION; | ||
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL; | ||
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; | ||
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; | ||
|
||
import apoc.ApocConfig; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.Stream; | ||
import org.apache.commons.lang3.StringUtils; | ||
|
||
abstract class OpenAIRequestHandler { | ||
|
||
private final String defaultUrl; | ||
|
||
public OpenAIRequestHandler(String defaultUrl) { | ||
this.defaultUrl = defaultUrl; | ||
} | ||
|
||
public String getDefaultUrl() { | ||
return defaultUrl; | ||
} | ||
|
||
public abstract String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig); | ||
|
||
public abstract void addApiKey(Map<String, Object> headers, String apiKey); | ||
|
||
public String getEndpoint(Map<String, Object> procConfig, ApocConfig apocConfig) { | ||
return (String) procConfig.getOrDefault( | ||
ENDPOINT_CONF_KEY, | ||
apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl()))); | ||
} | ||
|
||
public String getFullUrl(String method, Map<String, Object> procConfig, ApocConfig apocConfig) { | ||
return Stream.of(getEndpoint(procConfig, apocConfig), method, getApiVersion(procConfig, apocConfig)) | ||
.filter(StringUtils::isNotBlank) | ||
.collect(Collectors.joining("/")); | ||
} | ||
|
||
enum Type { | ||
AZURE(new Azure(null)), | ||
OPENAI(new OpenAi("https://api.openai.com/v1")); | ||
|
||
private final OpenAIRequestHandler handler; | ||
|
||
Type(OpenAIRequestHandler handler) { | ||
this.handler = handler; | ||
} | ||
|
||
public OpenAIRequestHandler get() { | ||
return handler; | ||
} | ||
} | ||
|
||
static class Azure extends OpenAIRequestHandler { | ||
|
||
public Azure(String defaultUrl) { | ||
super(defaultUrl); | ||
} | ||
|
||
@Override | ||
public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) { | ||
return "?api-version=" | ||
+ configuration.getOrDefault( | ||
API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION)); | ||
} | ||
|
||
@Override | ||
public void addApiKey(Map<String, Object> headers, String apiKey) { | ||
headers.put("api-key", apiKey); | ||
} | ||
} | ||
|
||
static class OpenAi extends OpenAIRequestHandler { | ||
|
||
public OpenAi(String defaultUrl) { | ||
super(defaultUrl); | ||
} | ||
|
||
@Override | ||
public String getApiVersion(Map<String, Object> configuration, ApocConfig apocConfig) { | ||
return ""; | ||
} | ||
|
||
@Override | ||
public void addApiKey(Map<String, Object> headers, String apiKey) { | ||
headers.put("Authorization", "Bearer " + apiKey); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package apoc.ml; | ||
|
||
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; | ||
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; | ||
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; | ||
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; | ||
import static apoc.ml.OpenAITestResultUtils.assertCompletion; | ||
import static apoc.util.TestUtil.testCall; | ||
import static org.junit.Assume.assumeNotNull; | ||
|
||
import apoc.util.TestUtil; | ||
import java.util.Map; | ||
import java.util.stream.Stream; | ||
import org.junit.BeforeClass; | ||
import org.junit.ClassRule; | ||
import org.junit.Ignore; | ||
import org.junit.Test; | ||
import org.neo4j.test.rule.DbmsRule; | ||
import org.neo4j.test.rule.ImpermanentDbmsRule; | ||
|
||
public class OpenAIAzureIT { | ||
// In Azure, the endpoints can be different | ||
private static String OPENAI_EMBEDDING_URL; | ||
private static String OPENAI_CHAT_URL; | ||
private static String OPENAI_COMPLETION_URL; | ||
|
||
private static String OPENAI_AZURE_API_VERSION; | ||
|
||
private static String OPENAI_KEY; | ||
|
||
@ClassRule | ||
public static DbmsRule db = new ImpermanentDbmsRule(); | ||
|
||
@BeforeClass | ||
public static void setUp() throws Exception { | ||
OPENAI_KEY = System.getenv("OPENAI_KEY"); | ||
// Azure OpenAI base URLs | ||
OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL"); | ||
OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL"); | ||
OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL"); | ||
|
||
// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>`) | ||
OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION"); | ||
|
||
Stream.of(OPENAI_EMBEDDING_URL, OPENAI_CHAT_URL, OPENAI_COMPLETION_URL, OPENAI_AZURE_API_VERSION, OPENAI_KEY) | ||
.forEach(key -> assumeNotNull("No " + key + " environment configured", key)); | ||
|
||
TestUtil.registerProcedure(db, OpenAI.class); | ||
} | ||
|
||
@Test | ||
public void embedding() { | ||
testCall( | ||
db, | ||
"CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)", | ||
getParams(OPENAI_EMBEDDING_URL), | ||
OpenAITestResultUtils::assertEmbeddings); | ||
} | ||
|
||
@Test | ||
@Ignore("It returns wrong answers sometimes") | ||
public void completion() { | ||
testCall( | ||
db, | ||
"CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)", | ||
getParams(OPENAI_CHAT_URL), | ||
(row) -> assertCompletion(row, "gpt-35-turbo")); | ||
} | ||
|
||
@Test | ||
public void chatCompletion() { | ||
testCall( | ||
db, | ||
""" | ||
CALL apoc.ml.openai.chat([ | ||
{role:"system", content:"Only answer with a single word"}, | ||
{role:"user", content:"What planet do humans live on?"} | ||
], $apiKey, $conf) | ||
""", | ||
getParams(OPENAI_COMPLETION_URL), | ||
(row) -> assertChatCompletion(row, "gpt-35-turbo")); | ||
} | ||
|
||
private static Map<String, Object> getParams(String url) { | ||
return Map.of( | ||
"apiKey", | ||
OPENAI_KEY, | ||
"conf", | ||
Map.of( | ||
ENDPOINT_CONF_KEY, | ||
url, | ||
API_TYPE_CONF_KEY, | ||
OpenAIRequestHandler.Type.AZURE.name(), | ||
API_VERSION_CONF_KEY, | ||
OPENAI_AZURE_API_VERSION, | ||
// on Azure is available only "gpt-35-turbo" | ||
"model", | ||
"gpt-35-turbo")); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package apoc.ml; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
import static org.junit.jupiter.api.Assertions.assertTrue; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
public class OpenAITestResultUtils { | ||
public static void assertEmbeddings(Map<String, Object> row) { | ||
assertEquals(0L, row.get("index")); | ||
assertEquals("Some Text", row.get("text")); | ||
var embedding = (List<Double>) row.get("embedding"); | ||
assertEquals(1536, embedding.size()); | ||
} | ||
|
||
public static void assertCompletion(Map<String, Object> row, String expectedModel) { | ||
var result = (Map<String, Object>) row.get("value"); | ||
assertTrue(result.get("created") instanceof Number); | ||
assertTrue(result.containsKey("choices")); | ||
var finishReason = (String) ((List<Map>) result.get("choices")).get(0).get("finish_reason"); | ||
assertTrue(finishReason.matches("stop|length")); | ||
String text = (String) ((List<Map>) result.get("choices")).get(0).get("text"); | ||
System.out.println("OpenAI text response for assertCompletion = " + text); | ||
assertTrue(text != null && !text.isBlank()); | ||
assertTrue(text.toLowerCase().contains("blue")); | ||
assertTrue(result.containsKey("usage")); | ||
assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number); | ||
assertEquals(expectedModel, result.get("model")); | ||
assertEquals("text_completion", result.get("object")); | ||
} | ||
|
||
public static void assertChatCompletion(Map<String, Object> row, String modelId) { | ||
var result = (Map<String, Object>) row.get("value"); | ||
assertTrue(result.get("created") instanceof Number); | ||
assertTrue(result.containsKey("choices")); | ||
|
||
Map message = ((List<Map<String, Map>>) result.get("choices")).get(0).get("message"); | ||
assertEquals("assistant", message.get("role")); | ||
String text = (String) message.get("content"); | ||
assertTrue(text != null && !text.isBlank()); | ||
|
||
assertTrue(result.containsKey("usage")); | ||
assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number); | ||
assertEquals("chat.completion", result.get("object")); | ||
assertTrue(result.get("model").toString().startsWith(modelId)); | ||
} | ||
} |