diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelConnection.java new file mode 100644 index 00000000..2b41d375 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelConnection.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to mark a method that provides an embedding model connection resource descriptor. + * + *

Methods annotated with this annotation should return a {@link + * org.apache.flink.agents.api.resource.ResourceDescriptor} that describes how to configure and + * create an embedding model connection. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface EmbeddingModelConnection {} diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelSetup.java new file mode 100644 index 00000000..b5a5b9a6 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/EmbeddingModelSetup.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to mark a method that provides an embedding model setup resource descriptor. + * + *

Methods annotated with this annotation should return a {@link + * org.apache.flink.agents.api.resource.ResourceDescriptor} that describes how to configure and + * create an embedding model setup. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface EmbeddingModelSetup {} diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java new file mode 100644 index 00000000..35dd1394 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.embedding.model; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Abstraction of embedding model connection. + * + *

Responsible for managing embedding model service connection configurations, such as Service + * address, API key, Connection timeout, Model name, Authentication information, etc. + * + *

This class follows the parameter pattern where additional configuration options can be passed + * through a Map<String, Object> parameters argument. Common parameters include: + * + *

+ */ +public abstract class BaseEmbeddingModelConnection extends Resource { + + public BaseEmbeddingModelConnection( + ResourceDescriptor descriptor, BiFunction getResource) { + super(descriptor, getResource); + } + + @Override + public ResourceType getResourceType() { + return ResourceType.EMBEDDING_MODEL_CONNECTION; + } + + /** + * Generate embeddings for a single text input. + * + * @param text The input text to generate embeddings for + * @param parameters Additional parameters to configure the embedding request + * @return An array of floating-point values representing the text embeddings + */ + public abstract float[] embed(String text, Map parameters); + + /** + * Generate embeddings for multiple text inputs. + * + * @param texts The list of input texts to generate embeddings for + * @param parameters Additional parameters to configure the embedding request + * @return A list of arrays, each containing floating-point values representing the text + * embeddings + */ + public abstract List embed(List texts, Map parameters); + + /** + * Get the dimension of the embeddings produced by this model. + * + * @return The embedding dimension + */ + public abstract int getEmbeddingDimension(); +} diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java new file mode 100644 index 00000000..c367fd03 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.embedding.model; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Base class for embedding model setup configurations. + * + *

This class provides common setup functionality for embedding models, including connection + * management and model configuration. + */ +public abstract class BaseEmbeddingModelSetup extends Resource { + protected final String connection; + protected String model; + + public BaseEmbeddingModelSetup( + ResourceDescriptor descriptor, BiFunction getResource) { + super(descriptor, getResource); + this.connection = descriptor.getArgument("connection"); + this.model = descriptor.getArgument("model"); + } + /** Returns this setup's own request parameters; the embed overloads merge per-call parameters over them. */ + public abstract Map getParameters(); + + @Override + public ResourceType getResourceType() { + return ResourceType.EMBEDDING_MODEL; + } + + /** + * Get the embedding model connection. + * + * @return The embedding model connection instance + */ + public BaseEmbeddingModelConnection getConnection() { + return (BaseEmbeddingModelConnection) + getResource.apply(connection, ResourceType.EMBEDDING_MODEL_CONNECTION); + } + + /** + * Get the model name. + * + * @return The model name + */ + public String getModel() { + return model; + } + + /** + * Generate embeddings for the given text. + * + * @param text The input text to generate embeddings for + * @return An array of floating-point values representing the text embeddings + */ + public float[] embed(String text) { + return this.embed(text, Collections.emptyMap()); + } + /** Same as {@link #embed(String)}, with per-call parameters overriding the setup's own. */ + public float[] embed(String text, Map parameters) { + BaseEmbeddingModelConnection connection = getConnection(); + // Merge into a private copy so the map returned by getParameters() is never mutated. + Map params = new java.util.HashMap<>(this.getParameters()); + params.putAll(parameters); + + return connection.embed(text, params); + } + + /** + * Generate embeddings for multiple texts. 
+ * + * @param texts The list of input texts to generate embeddings for + * @return A list of arrays, each containing floating-point values representing the text + * embeddings + */ + public List embed(List texts) { + return this.embed(texts, Collections.emptyMap()); + } + /** Same as {@link #embed(List)}, with per-call parameters overriding the setup's own. */ + public List embed(List texts, Map parameters) { + BaseEmbeddingModelConnection connection = getConnection(); + // Merge into a private copy so the map returned by getParameters() is never mutated. + Map params = new java.util.HashMap<>(this.getParameters()); + params.putAll(parameters); + + return connection.embed(texts, params); + } + + /** + * Get the dimension of the embeddings produced by this model. + * + * @return The embedding dimension + */ + public int getEmbeddingDimension() { + return getConnection().getEmbeddingDimension(); + } +} diff --git a/examples/pom.xml b/examples/pom.xml index 1b8d543e..02407efc 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -69,7 +69,11 @@ under the License. flink-agents-integrations-chat-models-ollama ${project.version} - + + org.apache.flink + flink-agents-integrations-embedding-models-ollama + ${project.version} + \ No newline at end of file diff --git a/examples/src/main/java/org/apache/flink/agents/examples/WorkflowEmbeddingsAgentExampleJob.java b/examples/src/main/java/org/apache/flink/agents/examples/WorkflowEmbeddingsAgentExampleJob.java new file mode 100644 index 00000000..a94df5ff --- /dev/null +++ b/examples/src/main/java/org/apache/flink/agents/examples/WorkflowEmbeddingsAgentExampleJob.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.examples; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.examples.agents.EmbeddingsAgent; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +/** + * Example workflow demonstrating how to use the WorkflowEmbeddingsAgentExample to generate + * embeddings for streaming text data. + * + *

This example shows: 1. Setting up a Flink streaming job with the embedding agent 2. Processing + * text documents to generate vector embeddings 3. Handling both structured and unstructured text + * input 4. Monitoring embedding generation results + * + *

Prerequisites: - Ollama server running on localhost:11434 - nomic-embed-text model available: + * `ollama pull nomic-embed-text` + */ +public class WorkflowEmbeddingsAgentExampleJob { + + /** Sample text documents for embedding generation. */ + private static final String[] SAMPLE_DOCUMENTS = { + "Apache Flink is a framework and distributed processing engine for stateful computations over unbounded and bounded data streams.", + "Machine learning algorithms can learn patterns from data and make predictions on new, unseen data.", + "Vector embeddings capture semantic meaning of text in high-dimensional numerical representations.", + "Natural language processing enables computers to understand, interpret, and generate human language.", + "Deep learning uses neural networks with multiple layers to model and understand complex patterns.", + "Retrieval-Augmented Generation combines information retrieval with text generation for better AI responses.", + "Semantic search uses vector similarity to find relevant documents based on meaning rather than keywords.", + "Large language models are trained on vast amounts of text data to understand and generate human-like text.", + "Data streaming allows real-time processing of continuous data flows in distributed systems.", + "Artificial intelligence systems can process and analyze large volumes of data to extract insights." + }; + + /** Custom source function that generates sample text documents. 
*/ + public static class SampleTextSource implements SourceFunction { + private volatile boolean running = true; + private int documentIndex = 0; + + @Override + public void run(SourceContext ctx) throws Exception { + while (running) { + // Send structured JSON documents + if (documentIndex < SAMPLE_DOCUMENTS.length) { + String document = SAMPLE_DOCUMENTS[documentIndex]; + String jsonDoc = + String.format( + "{\"id\": \"doc_%d\", \"text\": \"%s\", \"category\": \"tech\", \"timestamp\": %d}", + documentIndex + 1, + document.replace("\"", "\\\""), + System.currentTimeMillis()); + ctx.collect(jsonDoc); + documentIndex++; + } else { + // Send some plain text documents + String plainText = + "This is a plain text document number " + + (documentIndex - SAMPLE_DOCUMENTS.length + 1) + + " for embedding generation testing."; + ctx.collect(plainText); + documentIndex++; + + if (documentIndex > SAMPLE_DOCUMENTS.length + 5) { + running = false; // Stop after processing all documents + } + } + + // Wait 2 seconds between documents + Thread.sleep(2000); + } + } + + @Override + public void cancel() { + running = false; + } + } + + public static void main(String[] args) throws Exception { + // Set up Flink execution environment + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); // Use single parallelism for deterministic processing + + // Set up Agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Create data stream of text documents + DataStream textStream = + env.addSource(new SampleTextSource()).name("Sample Text Source"); + + // Process with embedding agent using the correct pattern + DataStream embeddingResults = + agentsEnv.fromDataStream(textStream).apply(new EmbeddingsAgent()).toDataStream(); + + // Print results with detailed information + embeddingResults + .map( + event -> { + if (event instanceof 
org.apache.flink.agents.api.OutputEvent) { + org.apache.flink.agents.api.OutputEvent outputEvent = + (org.apache.flink.agents.api.OutputEvent) event; + Object payload = outputEvent.getOutput(); + + if (payload instanceof java.util.Map) { + @SuppressWarnings("unchecked") + java.util.Map result = + (java.util.Map) payload; + + if (result.containsKey("error")) { + return String.format("ERROR: %s", result.get("error")); + } else { + return String.format( + "EMBEDDING GENERATED - ID: %s, Dimension: %s, Norm: %.4f, Text: '%.100s...'", + result.get("id"), + result.get("dimension"), + result.get("norm"), + result.get("text")); + } + } + } + return "Processed: " + event.toString(); + }) + .print() + .name("Print Results"); + + // Execute the Flink job + env.execute("Workflow Embeddings Agent Example"); + } +} diff --git a/examples/src/main/java/org/apache/flink/agents/examples/agents/EmbeddingsAgent.java b/examples/src/main/java/org/apache/flink/agents/examples/agents/EmbeddingsAgent.java new file mode 100644 index 00000000..201b42ca --- /dev/null +++ b/examples/src/main/java/org/apache/flink/agents/examples/agents/EmbeddingsAgent.java @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.examples.agents; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.Agent; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.annotation.EmbeddingModelSetup; +import org.apache.flink.agents.api.annotation.Tool; +import org.apache.flink.agents.api.annotation.ToolParam; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * An agent that generates embeddings for each row of data using Ollama embedding models. + * + *

This agent receives text data, processes it to generate high-dimensional vector embeddings, + * and outputs the results with metadata. It demonstrates how to integrate embedding models into + * Flink Agents workflows for vector-based processing and similarity search applications. + * + *

The agent supports various embedding models available in Ollama such as: - nomic-embed-text + * (768 dimensions) - mxbai-embed-large (1024 dimensions) - all-minilm (384 dimensions) + */ +public class EmbeddingsAgent extends Agent { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @EmbeddingModelConnection + public static ResourceDescriptor ollamaEmbeddingConnection() { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName()) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 60) + .addInitialArgument("model", "nomic-embed-text") + .build(); + } + + @EmbeddingModelSetup + public static ResourceDescriptor embeddingModel() { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "ollamaEmbeddingConnection") + .addInitialArgument("model", "nomic-embed-text") + .build(); + } + + /** + * Tool for storing embeddings in a vector database. + * + * @param id The unique identifier for the text data + * @param text The original text content + * @param embedding The generated embedding vector + * @param dimension The dimension of the embedding vector + */ + @Tool( + description = + "Store embeddings in a vector database for similarity search and retrieval.") + public static void storeEmbedding( + @ToolParam(name = "id") String id, + @ToolParam(name = "text") String text, + @ToolParam(name = "embedding") float[] embedding, + @ToolParam(name = "dimension") int dimension) { + + // In a real implementation, this would store in a vector database like Pinecone, Weaviate, + // etc. + System.out.printf( + "Storing embedding: ID=%s, Text='%s...', Dimension=%d%n", + id, text.substring(0, Math.min(50, text.length())), dimension); + System.out.printf( + "Embedding preview: [%.4f, %.4f, %.4f, ...]%n", + embedding[0], embedding[1], embedding[2]); + } + + /** + * Tool for calculating similarity between embeddings. 
+ * + * @param embedding1 First embedding vector + * @param embedding2 Second embedding vector + * @return Cosine similarity score between -1 and 1 + */ + @Tool(description = "Calculate cosine similarity between two embedding vectors.") + public static float calculateSimilarity( + @ToolParam(name = "embedding1") float[] embedding1, + @ToolParam(name = "embedding2") float[] embedding2) { + + if (embedding1.length != embedding2.length) { + throw new IllegalArgumentException("Embedding dimensions must match"); + } + // Accumulate in double so rounding error does not grow with the vector length. + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < embedding1.length; i++) { + dotProduct += (double) embedding1[i] * embedding2[i]; + normA += (double) embedding1[i] * embedding1[i]; + normB += (double) embedding2[i] * embedding2[i]; + } + + if (normA == 0.0 || normB == 0.0) { + return 0.0f; + } + // Clamp: rounding can push the ratio marginally outside the documented [-1, 1] range. + float similarity = (float) Math.max(-1.0, Math.min(1.0, dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)))); + System.out.printf("Calculated similarity: %.4f%n", similarity); + return similarity; + } + + /** Process input event and generate embeddings for the text data. 
*/ + @Action(listenEvents = {InputEvent.class}) + public static void processInput(InputEvent event, RunnerContext ctx) throws Exception { + String input = (String) event.getInput(); + MAPPER.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + + // Parse input as a text document with optional metadata + Map inputData; + try { + inputData = MAPPER.readValue(input, Map.class); + } catch (Exception e) { + // If not JSON, treat as plain text + inputData = new HashMap<>(); + inputData.put("text", input); + inputData.put("id", "doc_" + System.currentTimeMillis()); + } + // Parsed JSON values need not be strings (e.g. a numeric or null id): check/convert instead of casting. + String text = inputData.get("text") instanceof String ? (String) inputData.get("text") : null; + String id = String.valueOf(inputData.get("id") != null ? inputData.get("id") : "doc_" + System.currentTimeMillis()); + + if (text == null || text.trim().isEmpty()) { + ctx.sendEvent( + new OutputEvent( + Map.of("error", "No text content found in input", "input", input))); + return; + } + + // Store data in short-term memory for later use + ctx.getShortTermMemory().set("id", id); + ctx.getShortTermMemory().set("text", text); + ctx.getShortTermMemory().set("originalInput", inputData); + + try { + // Use the actual Ollama embedding model directly + float[] embedding = generateRealEmbedding(text, ctx); + int dimension = embedding.length; + + // Store the embedding using the tool + storeEmbedding(id, text, embedding, dimension); + + // Create output with embedding results + Map result = new HashMap<>(); + result.put("id", id); + result.put("text", text); + result.put("embedding", embedding); + result.put("dimension", dimension); + result.put("embeddingPreview", Arrays.copyOf(embedding, Math.min(5, embedding.length))); + result.put("metadata", inputData); + result.put("timestamp", System.currentTimeMillis()); + + // Calculate some statistics + float norm = 0.0f; + for (float value 
: embedding) { + norm += value * value; + } + result.put("norm", Math.sqrt(norm)); + + ctx.sendEvent(new OutputEvent(result)); + + System.out.printf( + "Generated embedding for text: '%s' (ID: 
%s, Dimension: %d)%n", + text.substring(0, Math.min(100, text.length())), id, dimension); + + } catch (Exception e) { + System.err.printf( + "Error generating embedding for text '%s': %s%n", text, e.getMessage()); + ctx.sendEvent( + new OutputEvent( + Map.of( + "error", "Failed to generate embedding: " + e.getMessage(), + "id", id, + "text", text))); + } + } + + /** + * Generate real embeddings using Ollama embedding model via the framework resource system. This + * uses the context to retrieve the managed embedding model setup resource. + */ + private static float[] generateRealEmbedding(String text, RunnerContext ctx) { + try { + System.out.println("Attempting to retrieve embeddingModel resource..."); + + // Use the embedding model setup resource as intended by the framework + OllamaEmbeddingModelSetup embeddingModel = + (OllamaEmbeddingModelSetup) + ctx.getResource( + "embeddingModel", + org.apache.flink.agents.api.resource.ResourceType + .EMBEDDING_MODEL); + + System.out.println("Successfully retrieved embeddingModel resource"); + + // Generate the embedding using the managed Ollama model setup + float[] embedding = embeddingModel.embed(text); + System.out.printf("Generated Ollama embedding with dimension: %d%n", embedding.length); + return embedding; + + } catch (Exception e) { + System.err.printf( + "FAILED to generate real embedding for text '%s': %s%n", + text.substring(0, Math.min(50, text.length())), e.getMessage()); + e.printStackTrace(); + // Re-throw the exception instead of falling back to mock + throw new RuntimeException("Ollama embedding generation failed", e); + } + } + + /** Data class for structured text input with metadata. 
*/ + public static class TextDocument { + private String id; + private String text; + private String category; + private Map metadata; + + // Constructors + public TextDocument() {} + + public TextDocument(String id, String text, String category, Map metadata) { + this.id = id; + this.text = text; + this.category = category; + this.metadata = metadata; + } + + // Getters and setters + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public String getCategory() { + return category; + } + + public void setCategory(String category) { + this.category = category; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + } + + /** Data class for embedding results. */ + public static class EmbeddingResult { + private String id; + private String text; + private float[] embedding; + private int dimension; + private double norm; + private long timestamp; + private Map metadata; + + // Constructors + public EmbeddingResult() {} + + public EmbeddingResult( + String id, + String text, + float[] embedding, + int dimension, + double norm, + long timestamp, + Map metadata) { + this.id = id; + this.text = text; + this.embedding = embedding; + this.dimension = dimension; + this.norm = norm; + this.timestamp = timestamp; + this.metadata = metadata; + } + + // Getters and setters + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public float[] getEmbedding() { + return embedding; + } + + public void setEmbedding(float[] embedding) { + this.embedding = embedding; + } + + public int getDimension() { + return dimension; + } + + public void setDimension(int dimension) { + this.dimension = 
dimension; + } + + public double getNorm() { + return norm; + } + + public void setNorm(double norm) { + this.norm = norm; + } + + public long getTimestamp() { + return timestamp; + } + + public void setTimestamp(long timestamp) { + this.timestamp = timestamp; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata; + } + } +} diff --git a/integrations/chat-models/ollama/pom.xml b/integrations/chat-models/ollama/pom.xml index 96850c44..942a42d6 100644 --- a/integrations/chat-models/ollama/pom.xml +++ b/integrations/chat-models/ollama/pom.xml @@ -46,7 +46,7 @@ under the License. io.github.ollama4j ollama4j - 1.1.0 + ${ollama4j.version} diff --git a/integrations/embedding-models/ollama/pom.xml b/integrations/embedding-models/ollama/pom.xml new file mode 100644 index 00000000..e9f6e72e --- /dev/null +++ b/integrations/embedding-models/ollama/pom.xml @@ -0,0 +1,53 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-embedding-models + 0.2-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-embedding-models-ollama + Flink Agents : Integrations: Embedding Models: Ollama + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + org.apache.flink + flink-agents-plan + ${project.version} + + + + io.github.ollama4j + ollama4j + ${ollama4j.version} + + + + diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java new file mode 100644 index 00000000..dcb73bfe --- /dev/null +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.ollama; + +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.exceptions.OllamaBaseException; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * An embedding model integration for Ollama powered by the ollama4j client. + * + *

This implementation adapts the generic Flink Agents embedding model interface to Ollama's + * embedding API. It supports various embedding models available in Ollama such as: + * + *

    + *
  • nomic-embed-text + *
  • mxbai-embed-large + *
  • all-minilm + *
  • And other embedding models supported by Ollama + *
+ * + *

See also {@link BaseEmbeddingModelConnection} for the common resource abstractions and + * lifecycle. + * + *

Example usage: + * + *

{@code
+ * public class MyAgent extends Agent {
+ *   // Register the embedding model connection via @EmbeddingModelConnection metadata.
+ *   @EmbeddingModelConnection
+ *   public static ResourceDescriptor ollama() {
+ *     return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName())
+ *                 .addInitialArgument("host", "http://localhost:11434") // the Ollama server URL
+ *                 .addInitialArgument("model", "nomic-embed-text") // the embedding model name
+ *                 .build();
+ *   }
+ * }
+ * }
+ */
+public class OllamaEmbeddingModelConnection extends BaseEmbeddingModelConnection {
+
+    private final OllamaAPI ollamaAPI;
+    private final String host;
+    private final String defaultModel;
+
+    /**
+     * Embedding dimension observed for {@link #defaultModel}; {@code null} until the first
+     * successful embedding produced with that model. Access is not synchronized.
+     */
+    private Integer cachedDimension;
+
+    /**
+     * Creates a connection from the descriptor's {@code "host"} and {@code "model"} arguments.
+     *
+     * <p>A missing {@code "host"} falls back to {@code http://localhost:11434}; a missing {@code
+     * "model"} falls back to {@code nomic-embed-text}.
+     *
+     * @param descriptor resource descriptor carrying the connection arguments
+     * @param getResource resource lookup function, passed through to the base class
+     */
+    public OllamaEmbeddingModelConnection(
+            ResourceDescriptor descriptor,
+            BiFunction<String, ResourceType, Resource> getResource) {
+        super(descriptor, getResource);
+        String configuredHost = descriptor.getArgument("host");
+        String configuredModel = descriptor.getArgument("model");
+        this.host = configuredHost != null ? configuredHost : "http://localhost:11434";
+        this.defaultModel = configuredModel != null ? configuredModel : "nomic-embed-text";
+
+        this.ollamaAPI = new OllamaAPI(host);
+    }
+
+    /**
+     * Embeds a single text.
+     *
+     * @param text text to embed; must not be null or blank
+     * @param parameters optional call parameters; the {@code "model"} entry overrides the default
+     *     model. May be null.
+     * @return the embedding vector
+     * @throws IllegalArgumentException if {@code text} is null or blank
+     * @throws RuntimeException if the Ollama call fails or returns no values
+     */
+    @Override
+    public float[] embed(String text, Map<String, Object> parameters) {
+        return embedSingle(text, resolveModel(parameters));
+    }
+
+    /**
+     * Embeds every text of the list, one Ollama request per text, preserving order.
+     *
+     * @param texts texts to embed; the list and each entry must not be null or blank
+     * @param parameters optional call parameters; the {@code "model"} entry overrides the default
+     *     model. May be null.
+     * @return one embedding vector per input text
+     * @throws IllegalArgumentException if the list is null/empty or contains a null/blank text
+     * @throws RuntimeException if an Ollama call fails or returns no values
+     */
+    @Override
+    public List<float[]> embed(List<String> texts, Map<String, Object> parameters) {
+        return embedBatch(texts, resolveModel(parameters));
+    }
+
+    /** Picks the {@code "model"} call parameter, or the default model when it is absent or null. */
+    private String resolveModel(Map<String, Object> parameters) {
+        Object requested = parameters == null ? null : parameters.get("model");
+        return requested == null ? defaultModel : (String) requested;
+    }
+
+    private float[] embedSingle(String text, String model) {
+        if (text == null || text.trim().isEmpty()) {
+            throw new IllegalArgumentException("Text cannot be null or empty");
+        }
+
+        // Only the remote call sits inside the try, so errors raised by this method itself
+        // are not re-wrapped as "Unexpected error".
+        List<? extends Number> embeddings;
+        try {
+            embeddings = ollamaAPI.generateEmbeddings(model, text);
+        } catch (OllamaBaseException e) {
+            throw new RuntimeException(
+                    "Ollama API error while generating embeddings for text with model '"
+                            + model
+                            + "': "
+                            + e.getMessage(),
+                    e);
+        } catch (IOException | InterruptedException e) {
+            if (e instanceof InterruptedException) {
+                // Keep the interruption visible to callers further up the stack.
+                Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException(
+                    "Communication error with Ollama server at "
+                            + host
+                            + " while generating embeddings: "
+                            + e.getMessage(),
+                    e);
+        } catch (Exception e) {
+            throw new RuntimeException(
+                    "Unexpected error while generating embeddings for text with model '"
+                            + model
+                            + "': "
+                            + e.getMessage(),
+                    e);
+        }
+
+        if (embeddings == null || embeddings.isEmpty()) {
+            throw new RuntimeException(
+                    "Received empty embeddings from Ollama for model: " + model);
+        }
+
+        float[] result = new float[embeddings.size()];
+        for (int i = 0; i < result.length; i++) {
+            result[i] = embeddings.get(i).floatValue();
+        }
+
+        // The cache answers getEmbeddingDimension() for the default model, so a vector
+        // produced by a per-call override model must not populate it.
+        if (cachedDimension == null && defaultModel.equals(model)) {
+            cachedDimension = result.length;
+        }
+
+        return result;
+    }
+
+    private List<float[]> embedBatch(List<String> texts, String model) {
+        if (texts == null || texts.isEmpty()) {
+            throw new IllegalArgumentException("Texts list cannot be null or empty");
+        }
+        // Validate everything first so a bad entry fails before any request is sent.
+        for (String text : texts) {
+            if (text == null || text.trim().isEmpty()) {
+                throw new IllegalArgumentException("Text in list cannot be null or empty");
+            }
+        }
+
+        List<float[]> results = new ArrayList<>(texts.size());
+        for (String text : texts) {
+            results.add(embedSingle(text, model));
+        }
+        return results;
+    }
+
+    /**
+     * Returns the embedding dimension of the default model.
+     *
+     * <p>The value is taken from the cache when available; otherwise the text {@code "test"} is
+     * embedded with the default model to measure it. If that probe fails, a built-in value is
+     * returned for {@code nomic-embed-text} (768), {@code mxbai-embed-large} (1024) and {@code
+     * all-minilm} (384); these fallback values are not cached.
+     *
+     * @return the embedding dimension
+     * @throws RuntimeException if the probe fails and the default model has no built-in value
+     */
+    @Override
+    public int getEmbeddingDimension() {
+        if (cachedDimension != null) {
+            return cachedDimension;
+        }
+
+        try {
+            Map<String, Object> probeParams = new HashMap<>();
+            probeParams.put("model", defaultModel);
+            cachedDimension = embed("test", probeParams).length;
+            return cachedDimension;
+        } catch (Exception e) {
+            // Locale.ROOT keeps the lookup independent of the JVM's default locale.
+            switch (defaultModel.toLowerCase(java.util.Locale.ROOT)) {
+                case "nomic-embed-text":
+                    return 768;
+                case "mxbai-embed-large":
+                    return 1024;
+                case "all-minilm":
+                    return 384;
+                default:
+                    throw new RuntimeException(
+                            "Could not determine embedding dimension for model: "
+                                    + defaultModel
+                                    + ". Cause: "
+                                    + e.getMessage(),
+                            e);
+            }
+        }
+    }
+
+    /**
+     * Check if the specified model is available on the Ollama server.
+     *
+     * <p>The server's model list is consulted first. If listing fails, availability is decided
+     * by whether a probe embedding of the text {@code "test"} with that model succeeds.
+     *
+     * @param model the model name to look for; null yields {@code false}
+     * @return true if the model is listed, or if listing failed and the probe succeeded
+     */
+    public boolean isModelAvailable(String model) {
+        if (model == null) {
+            return false;
+        }
+        try {
+            return ollamaAPI.listModels().stream()
+                    .anyMatch(modelInfo -> modelInfo.getName().equals(model));
+        } catch (Exception listFailure) {
+            // Listing is best-effort; fall back to a probe request.
+            try {
+                Map<String, Object> probeParams = new HashMap<>();
+                probeParams.put("model", model);
+                embed("test", probeParams);
+                return true;
+            } catch (Exception probeFailure) {
+                return false;
+            }
+        }
+    }
+
+    /** Get the default embedding model name. */
+    public String getDefaultModel() {
+        return defaultModel;
+    }
+}
diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java new file mode 100644 index 00000000..fc7228c0 --- /dev/null +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.ollama; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * An embedding model setup for Ollama powered by the ollama4j client. + * + *

This implementation adapts the generic Flink Agents embedding model interface to Ollama's + * embedding API. It supports various embedding models available in Ollama such as: - + * nomic-embed-text (768 dimensions) - mxbai-embed-large (1024 dimensions) - all-minilm (384 + * dimensions) - And other embedding models supported by Ollama + * + *

This class implements the parameter pattern where configuration options are passed through a + * Map<String, Object> parameters argument. The getParameters() method provides default + * parameters that can be overridden when calling embed methods. + * + *

See also {@link BaseEmbeddingModelSetup} for the common resource abstractions and lifecycle. + * + *

Example usage: + * + *

{@code
+ * public class MyAgent extends Agent {
+ *   // Register the embedding model setup via @EmbeddingModelSetup metadata.
+ *   @EmbeddingModelSetup
+ *   public static ResourceDescriptor ollama() {
+ *     return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName())
+ *                 .addInitialArgument("connection", "myConnection") // the name of OllamaEmbeddingModelConnection
+ *                 .addInitialArgument("model", "nomic-embed-text") // the model name
+ *                 .build();
+ *   }
+ * }
+ * }
+ */
+public class OllamaEmbeddingModelSetup extends BaseEmbeddingModelSetup {
+
+    /**
+     * Dimension measured for this setup's own model when it differs from the connection's
+     * default model; {@code null} until first measured. Access is not synchronized.
+     */
+    private Integer probedDimension;
+
+    /**
+     * Creates the setup; all arguments are handed to the base class.
+     *
+     * @param descriptor resource descriptor of this setup
+     * @param getResource resource lookup function, passed through to the base class
+     */
+    public OllamaEmbeddingModelSetup(
+            ResourceDescriptor descriptor,
+            BiFunction<String, ResourceType, Resource> getResource) {
+        super(descriptor, getResource);
+    }
+
+    /**
+     * Builds the parameters sent with embed calls.
+     *
+     * @return a new map containing {@code "model"} when this setup has a model, otherwise an
+     *     empty map
+     */
+    @Override
+    public Map<String, Object> getParameters() {
+        Map<String, Object> parameters = new HashMap<>();
+
+        // Add the model name if specified
+        if (model != null) {
+            parameters.put("model", model);
+        }
+
+        return parameters;
+    }
+
+    /** Returns the base class's connection, narrowed to the Ollama connection type. */
+    @Override
+    public OllamaEmbeddingModelConnection getConnection() {
+        return (OllamaEmbeddingModelConnection) super.getConnection();
+    }
+
+    /**
+     * Get the dimension of the embeddings produced by the configured Ollama model.
+     *
+     * <p>When this setup has no model of its own, or its model equals the connection's default
+     * model, the connection's dimension is returned. Otherwise the dimension is measured once
+     * by embedding the text {@code "test"} with this setup's parameters, because the
+     * connection only reports the dimension of its default model.
+     *
+     * @return The embedding dimension
+     */
+    @Override
+    public int getEmbeddingDimension() {
+        OllamaEmbeddingModelConnection connection = getConnection();
+        if (model == null || model.equals(connection.getDefaultModel())) {
+            return connection.getEmbeddingDimension();
+        }
+        if (probedDimension == null) {
+            probedDimension = connection.embed("test", getParameters()).length;
+        }
+        return probedDimension;
+    }
+
+    /**
+     * Check if the specified model is available on the Ollama server.
+     *
+     * @param model The model name to check
+     * @return true if the model is available, false otherwise
+     */
+    public boolean isModelAvailable(String model) {
+        return getConnection().isModelAvailable(model);
+    }
+
+    /**
+     * Get the default embedding model name of the underlying connection. This is the
+     * connection's default, which may differ from the model this setup passes in {@link
+     * #getParameters()}.
+     *
+     * @return The connection's default model name
+     */
+    public String getDefaultModel() {
+        return getConnection().getDefaultModel();
+    }
+}
diff --git a/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java new file mode 100644 index 00000000..47c9dc6c --- /dev/null +++ b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.ollama; + +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Offline tests for {@link OllamaEmbeddingModelConnection}: they only construct objects and
+ * never send a request, so no Ollama server is needed.
+ */
+class OllamaEmbeddingModelConnectionTest {
+
+    /** Resource lookup stub that resolves nothing; these tests never look up resources. */
+    private static final BiFunction<String, ResourceType, Resource> DUMMY_RESOURCE =
+            (name, type) -> null;
+
+    /** Descriptor pointing at a local server with the nomic-embed-text model. */
+    private static ResourceDescriptor buildDescriptor() {
+        return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName())
+                .addInitialArgument("host", "http://localhost:11434")
+                .addInitialArgument("model", "nomic-embed-text")
+                .build();
+    }
+
+    @Test
+    @DisplayName("Create OllamaEmbeddingModelConnection without calling the server")
+    void testCreateAndEmbed() {
+        OllamaEmbeddingModelConnection conn =
+                new OllamaEmbeddingModelConnection(buildDescriptor(), DUMMY_RESOURCE);
+        assertNotNull(conn);
+        assertEquals("nomic-embed-text", conn.getDefaultModel());
+        // embed is not called here because it requires a real Ollama server
+    }
+
+    @Test
+    @DisplayName("EmbeddingModelConnection annotation is absent from the connection class")
+    void testAnnotationPresence() {
+        // The annotation targets methods, so the class itself never carries it.
+        assertNull(
+                OllamaEmbeddingModelConnection.class.getAnnotation(EmbeddingModelConnection.class));
+    }
+
+    @Test
+    @DisplayName("A setup subclass can supply an Ollama connection")
+    void testSetupAnnotationPresence() {
+        class DummySetup extends BaseEmbeddingModelSetup {
+            public DummySetup(
+                    ResourceDescriptor descriptor,
+                    BiFunction<String, ResourceType, Resource> getResource) {
+                super(descriptor, getResource);
+            }
+
+            @Override
+            public Map<String, Object> getParameters() {
+                Map<String, Object> parameters = new HashMap<>();
+                if (model != null) {
+                    parameters.put("model", model);
+                }
+                return parameters;
+            }
+
+            @Override
+            public BaseEmbeddingModelConnection getConnection() {
+                return new OllamaEmbeddingModelConnection(buildDescriptor(), DUMMY_RESOURCE);
+            }
+        }
+        DummySetup setup = new DummySetup(buildDescriptor(), DUMMY_RESOURCE);
+        assertNotNull(setup.getConnection());
+    }
+}
diff --git a/integrations/embedding-models/pom.xml b/integrations/embedding-models/pom.xml new file mode 100644 index 00000000..473929b9 --- /dev/null +++ b/integrations/embedding-models/pom.xml @@ -0,0 +1,37 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations + 0.2-SNAPSHOT + + + flink-agents-integrations-embedding-models + Flink Agents : Integrations: Embedding Models + pom + + + ollama + + + diff --git a/integrations/pom.xml b/integrations/pom.xml index 657cb1f9..3509b561 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -30,8 +30,13 @@ under the License. 
Flink Agents : Integrations: pom + + 1.1.0 + + chat-models + embedding-models \ No newline at end of file diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index c7d388e7..7cb81386 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -22,6 +22,8 @@ import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.annotation.ChatModelConnection; import org.apache.flink.agents.api.annotation.ChatModelSetup; +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.annotation.EmbeddingModelSetup; import org.apache.flink.agents.api.annotation.Prompt; import org.apache.flink.agents.api.annotation.Tool; import org.apache.flink.agents.api.resource.Resource; @@ -372,6 +374,10 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { extractResource(ResourceType.CHAT_MODEL, method); } else if (method.isAnnotationPresent(ChatModelConnection.class)) { extractResource(ResourceType.CHAT_MODEL_CONNECTION, method); + } else if (method.isAnnotationPresent(EmbeddingModelSetup.class)) { + extractResource(ResourceType.EMBEDDING_MODEL, method); + } else if (method.isAnnotationPresent(EmbeddingModelConnection.class)) { + extractResource(ResourceType.EMBEDDING_MODEL_CONNECTION, method); } }