-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #201 from sebastienblanc/ollama-embedding
Add embeddings endpoint support for Ollama
- Loading branch information
Showing
9 changed files
with
241 additions
and
10 deletions.
There are no files selected for viewing
16 changes: 16 additions & 0 deletions
16
...src/main/java/io/quarkiverse/langchain4j/ollama/deployment/EmbeddingModelBuildConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package io.quarkiverse.langchain4j.ollama.deployment; | ||
|
||
import java.util.Optional; | ||
|
||
import io.quarkus.runtime.annotations.ConfigDocDefault; | ||
import io.quarkus.runtime.annotations.ConfigGroup; | ||
|
||
@ConfigGroup | ||
public interface EmbeddingModelBuildConfig { | ||
|
||
/** | ||
* Whether the model should be enabled | ||
*/ | ||
@ConfigDocDefault("true") | ||
Optional<Boolean> enabled(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package io.quarkiverse.langchain4j.ollama; | ||
|
||
public class EmbeddingRequest { | ||
|
||
private final String model; | ||
private final String prompt; | ||
|
||
private EmbeddingRequest(Builder builder) { | ||
model = builder.model; | ||
prompt = builder.prompt; | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public String getModel() { | ||
return model; | ||
} | ||
|
||
public String getPrompt() { | ||
return prompt; | ||
} | ||
|
||
public static final class Builder { | ||
private String model = "llama2"; | ||
private String prompt; | ||
|
||
private Builder() { | ||
} | ||
|
||
public Builder model(String val) { | ||
model = val; | ||
return this; | ||
} | ||
|
||
public Builder prompt(String val) { | ||
prompt = val; | ||
return this; | ||
} | ||
|
||
public EmbeddingRequest build() { | ||
return new EmbeddingRequest(this); | ||
} | ||
} | ||
} |
40 changes: 40 additions & 0 deletions
40
ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package io.quarkiverse.langchain4j.ollama; | ||
|
||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; | ||
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; | ||
|
||
@JsonDeserialize(builder = EmbeddingResponse.Builder.class) | ||
public class EmbeddingResponse { | ||
|
||
private float[] embedding; | ||
|
||
private EmbeddingResponse(Builder builder) { | ||
embedding = builder.embedding; | ||
} | ||
|
||
public float[] getEmbedding() { | ||
return embedding; | ||
} | ||
|
||
public void setEmbedding(float[] embedding) { | ||
this.embedding = embedding; | ||
} | ||
|
||
@JsonPOJOBuilder(withPrefix = "") | ||
public static final class Builder { | ||
private float[] embedding; | ||
|
||
private Builder() { | ||
} | ||
|
||
public Builder embedding(float[] val) { | ||
embedding = val; | ||
return this; | ||
} | ||
|
||
public EmbeddingResponse build() { | ||
return new EmbeddingResponse(this); | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaEmbeddingModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package io.quarkiverse.langchain4j.ollama; | ||
|
||
import java.time.Duration; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import dev.langchain4j.data.embedding.Embedding; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.embedding.EmbeddingModel; | ||
import dev.langchain4j.model.output.Response; | ||
|
||
public class OllamaEmbeddingModel implements EmbeddingModel { | ||
|
||
private final OllamaClient client; | ||
private final String model; | ||
|
||
private OllamaEmbeddingModel(Builder builder) { | ||
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses); | ||
model = builder.model; | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
@Override | ||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) { | ||
List<Embedding> embeddings = new ArrayList<>(); | ||
|
||
textSegments.forEach(textSegment -> { | ||
EmbeddingRequest request = EmbeddingRequest.builder() | ||
.model(model) | ||
.prompt(textSegment.text()) | ||
.build(); | ||
|
||
EmbeddingResponse response = client.embedding(request); | ||
|
||
embeddings.add(Embedding.from(response.getEmbedding())); | ||
}); | ||
|
||
return Response.from(embeddings); | ||
} | ||
|
||
public static final class Builder { | ||
private String baseUrl = "http://localhost:11434"; | ||
private Duration timeout = Duration.ofSeconds(10); | ||
private String model; | ||
|
||
private boolean logRequests = false; | ||
private boolean logResponses = false; | ||
|
||
private Builder() { | ||
} | ||
|
||
public Builder baseUrl(String val) { | ||
baseUrl = val; | ||
return this; | ||
} | ||
|
||
public Builder timeout(Duration val) { | ||
this.timeout = val; | ||
return this; | ||
} | ||
|
||
public Builder model(String val) { | ||
model = val; | ||
return this; | ||
} | ||
|
||
public Builder logRequests(boolean logRequests) { | ||
this.logRequests = logRequests; | ||
return this; | ||
} | ||
|
||
public Builder logResponses(boolean logResponses) { | ||
this.logResponses = logResponses; | ||
return this; | ||
} | ||
|
||
public OllamaEmbeddingModel build() { | ||
return new OllamaEmbeddingModel(this); | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters