Skip to content

Commit

Permalink
Merge pull request #1150 from andreadimaio/main
Browse files Browse the repository at this point in the history
Add support for structured output in Ollama
  • Loading branch information
geoand authored Dec 12, 2024
2 parents e7b1bf4 + cead717 commit d2a0f7e
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.UserName;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.test.QuarkusUnitTest;

public class OllamaJsonOutputTest extends WiremockAware {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
.overrideConfigKey("quarkus.langchain4j.ollama.chat-model.format", "json");

@Description("A person")
public record Person(
@Description("The firstname") String firstname,
@Description("The lastname") String lastname) {
}

@Singleton
@RegisterAiService
interface AiService {
Person extractPerson(@UserName String text);
}

@Inject
AiService aiService;

@Test
void extract() {
wiremock().register(
post(urlEqualTo("/api/chat"))
.withRequestBody(equalToJson(
"""
{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "Tell me something about Alan Wake\\nYou must answer strictly in the following JSON format: {\\n\\\"firstname\\\": (The firstname; type: string),\\n\\\"lastname\\\": (The lastname; type: string)\\n}"
}
],
"stream": false,
"options": {
"temperature": 0.8,
"top_k": 40,
"top_p": 0.9
},
"tools": [],
"format": "json"
}"""))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"model": "llama3.2",
"created_at": "2024-12-11T15:21:23.422542932Z",
"message": {
"role": "assistant",
"content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}"
},
"done_reason": "stop",
"done": true,
"total_duration": 8125806496,
"load_duration": 4223887064,
"prompt_eval_count": 31,
"prompt_eval_duration": 1331000000,
"eval_count": 18,
"eval_duration": 2569000000
}""")));

var result = aiService.extractPerson("Tell me something about Alan Wake");
assertEquals(new Person("Alan", "Wake"), result);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.UserName;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.test.QuarkusUnitTest;

public class OllamaStructuredOutputTest extends WiremockAware {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");

@Description("A person")
public record Person(
@Description("The firstname") String firstname,
@Description("The lastname") String lastname) {
}

@Singleton
@RegisterAiService
interface AiService {
Person extractPerson(@UserName String text);
}

@Inject
AiService aiService;

@Test
void extract() {
wiremock().register(
post(urlEqualTo("/api/chat"))
.withRequestBody(equalToJson("""
{
"model": "llama3.2",
"messages": [{"role": "user", "content": "Tell me something about Alan Wake"}],
"stream": false,
"options" : {
"temperature" : 0.8,
"top_k" : 40,
"top_p" : 0.9
},
"format": {
"type": "object",
"description": "A person",
"properties": {
"firstname": {
"description": "The firstname",
"type": "string"
},
"lastname": {
"description": "The lastname",
"type": "string"
}
},
"required": [
"firstname",
"lastname"
]
}
}
"""))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"model": "llama3.2",
"created_at": "2024-12-11T15:21:23.422542932Z",
"message": {
"role": "assistant",
"content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}"
},
"done_reason": "stop",
"done": true,
"total_duration": 8125806496,
"load_duration": 4223887064,
"prompt_eval_count": 31,
"prompt_eval_duration": 1331000000,
"eval_count": 18,
"eval_duration": 2569000000
}""")));

var result = aiService.extractPerson("Tell me something about Alan Wake");
assertEquals(new Person("Alan", "Wake"), result);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.service.UserName;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.test.QuarkusUnitTest;

public class OllamaTextOutputTest extends WiremockAware {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");

@Singleton
@RegisterAiService
interface AiService {
String question(@UserName String text);
}

@Inject
AiService aiService;

@Test
void extract() {
wiremock().register(
post(urlEqualTo("/api/chat"))
.withRequestBody(equalToJson(
"""
{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "Tell me something about Alan Wake"
}
],
"stream": false,
"options": {
"temperature": 0.8,
"top_k": 40,
"top_p": 0.9
},
"tools": []
}"""))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"model": "llama3.2",
"created_at": "2024-12-11T15:21:23.422542932Z",
"message": {
"role": "assistant",
"content": "He is a writer!"
},
"done_reason": "stop",
"done": true,
"total_duration": 8125806496,
"load_duration": 4223887064,
"prompt_eval_count": 31,
"prompt_eval_duration": 1331000000,
"eval_count": 18,
"eval_duration": 2569000000
}""")));

var result = aiService.question("Tell me something about Alan Wake");
assertEquals("He is a writer!", result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

import java.util.List;

public record ChatRequest(String model, List<Message> messages, List<Tool> tools, Options options, String format,
import com.fasterxml.jackson.databind.annotation.JsonSerialize;

public record ChatRequest(
String model,
List<Message> messages,
List<Tool> tools,
Options options,
@JsonSerialize(using = FormatJsonSerializer.class) String format,
Boolean stream) {

public static Builder builder() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.quarkiverse.langchain4j.ollama;

import java.io.IOException;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;

public class FormatJsonSerializer extends JsonSerializer<String> {

@Override
public void serialize(String value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
if (value == null)
return;
else if (value.startsWith("{") && value.endsWith("}"))
gen.writeRawValue(value);
else
gen.writeString(value);
}
}
Loading

0 comments on commit d2a0f7e

Please sign in to comment.