-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1150 from andreadimaio/main
Add support for structured output in Ollama
- Loading branch information
Showing
8 changed files
with
382 additions
and
25 deletions.
There are no files selected for viewing
93 changes: 93 additions & 0 deletions
93
...ment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package io.quarkiverse.langchain4j.ollama.deployment; | ||
|
||
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.post; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; | ||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import jakarta.inject.Inject; | ||
import jakarta.inject.Singleton; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.model.output.structured.Description; | ||
import dev.langchain4j.service.UserName; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkiverse.langchain4j.testing.internal.WiremockAware; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
public class OllamaJsonOutputTest extends WiremockAware { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) | ||
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) | ||
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false") | ||
.overrideConfigKey("quarkus.langchain4j.ollama.chat-model.format", "json"); | ||
|
||
@Description("A person") | ||
public record Person( | ||
@Description("The firstname") String firstname, | ||
@Description("The lastname") String lastname) { | ||
} | ||
|
||
@Singleton | ||
@RegisterAiService | ||
interface AiService { | ||
Person extractPerson(@UserName String text); | ||
} | ||
|
||
@Inject | ||
AiService aiService; | ||
|
||
@Test | ||
void extract() { | ||
wiremock().register( | ||
post(urlEqualTo("/api/chat")) | ||
.withRequestBody(equalToJson( | ||
""" | ||
{ | ||
"model": "llama3.2", | ||
"messages": [ | ||
{ | ||
"role": "user", | ||
"content": "Tell me something about Alan Wake\\nYou must answer strictly in the following JSON format: {\\n\\\"firstname\\\": (The firstname; type: string),\\n\\\"lastname\\\": (The lastname; type: string)\\n}" | ||
} | ||
], | ||
"stream": false, | ||
"options": { | ||
"temperature": 0.8, | ||
"top_k": 40, | ||
"top_p": 0.9 | ||
}, | ||
"tools": [], | ||
"format": "json" | ||
}""")) | ||
.willReturn(aResponse() | ||
.withHeader("Content-Type", "application/json") | ||
.withBody(""" | ||
{ | ||
"model": "llama3.2", | ||
"created_at": "2024-12-11T15:21:23.422542932Z", | ||
"message": { | ||
"role": "assistant", | ||
"content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}" | ||
}, | ||
"done_reason": "stop", | ||
"done": true, | ||
"total_duration": 8125806496, | ||
"load_duration": 4223887064, | ||
"prompt_eval_count": 31, | ||
"prompt_eval_duration": 1331000000, | ||
"eval_count": 18, | ||
"eval_duration": 2569000000 | ||
}"""))); | ||
|
||
var result = aiService.extractPerson("Tell me something about Alan Wake"); | ||
assertEquals(new Person("Alan", "Wake"), result); | ||
} | ||
} |
103 changes: 103 additions & 0 deletions
103
...rc/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStructuredOutputTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
package io.quarkiverse.langchain4j.ollama.deployment; | ||
|
||
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.post; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; | ||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import jakarta.inject.Inject; | ||
import jakarta.inject.Singleton; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.model.output.structured.Description; | ||
import dev.langchain4j.service.UserName; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkiverse.langchain4j.testing.internal.WiremockAware; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
public class OllamaStructuredOutputTest extends WiremockAware { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) | ||
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) | ||
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false"); | ||
|
||
@Description("A person") | ||
public record Person( | ||
@Description("The firstname") String firstname, | ||
@Description("The lastname") String lastname) { | ||
} | ||
|
||
@Singleton | ||
@RegisterAiService | ||
interface AiService { | ||
Person extractPerson(@UserName String text); | ||
} | ||
|
||
@Inject | ||
AiService aiService; | ||
|
||
@Test | ||
void extract() { | ||
wiremock().register( | ||
post(urlEqualTo("/api/chat")) | ||
.withRequestBody(equalToJson(""" | ||
{ | ||
"model": "llama3.2", | ||
"messages": [{"role": "user", "content": "Tell me something about Alan Wake"}], | ||
"stream": false, | ||
"options" : { | ||
"temperature" : 0.8, | ||
"top_k" : 40, | ||
"top_p" : 0.9 | ||
}, | ||
"format": { | ||
"type": "object", | ||
"description": "A person", | ||
"properties": { | ||
"firstname": { | ||
"description": "The firstname", | ||
"type": "string" | ||
}, | ||
"lastname": { | ||
"description": "The lastname", | ||
"type": "string" | ||
} | ||
}, | ||
"required": [ | ||
"firstname", | ||
"lastname" | ||
] | ||
} | ||
} | ||
""")) | ||
.willReturn(aResponse() | ||
.withHeader("Content-Type", "application/json") | ||
.withBody(""" | ||
{ | ||
"model": "llama3.2", | ||
"created_at": "2024-12-11T15:21:23.422542932Z", | ||
"message": { | ||
"role": "assistant", | ||
"content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}" | ||
}, | ||
"done_reason": "stop", | ||
"done": true, | ||
"total_duration": 8125806496, | ||
"load_duration": 4223887064, | ||
"prompt_eval_count": 31, | ||
"prompt_eval_duration": 1331000000, | ||
"eval_count": 18, | ||
"eval_duration": 2569000000 | ||
}"""))); | ||
|
||
var result = aiService.extractPerson("Tell me something about Alan Wake"); | ||
assertEquals(new Person("Alan", "Wake"), result); | ||
} | ||
} |
84 changes: 84 additions & 0 deletions
84
...ment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package io.quarkiverse.langchain4j.ollama.deployment; | ||
|
||
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.post; | ||
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; | ||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import jakarta.inject.Inject; | ||
import jakarta.inject.Singleton; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.service.UserName; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkiverse.langchain4j.testing.internal.WiremockAware; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
public class OllamaTextOutputTest extends WiremockAware { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) | ||
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) | ||
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false"); | ||
|
||
@Singleton | ||
@RegisterAiService | ||
interface AiService { | ||
String question(@UserName String text); | ||
} | ||
|
||
@Inject | ||
AiService aiService; | ||
|
||
@Test | ||
void extract() { | ||
wiremock().register( | ||
post(urlEqualTo("/api/chat")) | ||
.withRequestBody(equalToJson( | ||
""" | ||
{ | ||
"model": "llama3.2", | ||
"messages": [ | ||
{ | ||
"role": "user", | ||
"content": "Tell me something about Alan Wake" | ||
} | ||
], | ||
"stream": false, | ||
"options": { | ||
"temperature": 0.8, | ||
"top_k": 40, | ||
"top_p": 0.9 | ||
}, | ||
"tools": [] | ||
}""")) | ||
.willReturn(aResponse() | ||
.withHeader("Content-Type", "application/json") | ||
.withBody(""" | ||
{ | ||
"model": "llama3.2", | ||
"created_at": "2024-12-11T15:21:23.422542932Z", | ||
"message": { | ||
"role": "assistant", | ||
"content": "He is a writer!" | ||
}, | ||
"done_reason": "stop", | ||
"done": true, | ||
"total_duration": 8125806496, | ||
"load_duration": 4223887064, | ||
"prompt_eval_count": 31, | ||
"prompt_eval_duration": 1331000000, | ||
"eval_count": 18, | ||
"eval_duration": 2569000000 | ||
}"""))); | ||
|
||
var result = aiService.question("Tell me something about Alan Wake"); | ||
assertEquals("He is a writer!", result); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 20 additions & 0 deletions
20
.../ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/FormatJsonSerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package io.quarkiverse.langchain4j.ollama; | ||
|
||
import java.io.IOException; | ||
|
||
import com.fasterxml.jackson.core.JsonGenerator; | ||
import com.fasterxml.jackson.databind.JsonSerializer; | ||
import com.fasterxml.jackson.databind.SerializerProvider; | ||
|
||
public class FormatJsonSerializer extends JsonSerializer<String> { | ||
|
||
@Override | ||
public void serialize(String value, JsonGenerator gen, SerializerProvider serializers) throws IOException { | ||
if (value == null) | ||
return; | ||
else if (value.startsWith("{") && value.endsWith("}")) | ||
gen.writeRawValue(value); | ||
else | ||
gen.writeString(value); | ||
} | ||
} |
Oops, something went wrong.