Commit
* Resolve conflicts

Co-authored-by: Himani Chauhan <[email protected]>
Showing 3 changed files with 98 additions and 81 deletions.
oak-upgrade/src/main/java/org/apache/jackrabbit/oak/upgrade/ImageEmbeddingComparison.java (111 changes: 59 additions, 52 deletions)
@@ -1,84 +1,91 @@
 package org.apache.jackrabbit.oak.upgrade;

 import ai.djl.Application;
 import ai.djl.Model;
 import ai.djl.ModelException;
 import ai.djl.inference.Predictor;
 import ai.djl.modality.Classifications;
 import ai.djl.modality.cv.Image;
 import ai.djl.modality.cv.ImageFactory;
 import ai.djl.modality.cv.transform.Normalize;
 import ai.djl.modality.cv.transform.Resize;
 import ai.djl.modality.cv.transform.ToTensor;
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.repository.zoo.Criteria;
 import ai.djl.repository.zoo.ModelZoo;
 import ai.djl.translate.Batchifier;
 import ai.djl.translate.TranslateException;
 import ai.djl.translate.Translator;
 import ai.djl.translate.TranslatorContext;
 import ai.djl.translate.TranslatorFactory;

 import java.io.IOException;
 import java.io.FileInputStream;
 import java.io.InputStream;

 public class ImageEmbeddingComparison {

-    public static void compareImages(InputStream sourceStream, InputStream targetStream) throws IOException, ModelException, TranslateException {
-        // Load the source and target streams (replace with actual streams)
-        // InputStream sourceStream = getSourceStream(); // Replace with your actual stream
-        // InputStream targetStream = getTargetStream(); // Replace with your actual stream
+    public static void compareImages(InputStream inputStream1, InputStream inputStream2) throws Exception {
+        // Paths to your images
+        // InputStream inputStream1 = new FileInputStream("path_to_image1.jpg");
+        // InputStream inputStream2 = new FileInputStream("path_to_image2.jpg");

-        // Pre-process the images from input streams
-        Image sourceImage = ImageFactory.getInstance().fromInputStream(sourceStream);
-        Image targetImage = ImageFactory.getInstance().fromInputStream(targetStream);
+        // Load images using DJL ImageFactory
+        Image img1 = ImageFactory.getInstance().fromInputStream(inputStream1);
+        Image img2 = ImageFactory.getInstance().fromInputStream(inputStream2);

-        // Load the pre-trained CLIP model
-        String modelPath = "huggingface/clip-vit-base-patch16"; // Replace with actual Hugging Face model path
-
-        // Load the model using PyTorch engine
-        Model model = Model.newInstance(modelPath); // Set application to IMAGE_CLASSIFICATION
-
-        // Extract embeddings from both images using CLIP
-        float[] sourceEmbedding = extractImageEmbedding(model, sourceImage);
-        float[] targetEmbedding = extractImageEmbedding(model, targetImage);
-
-        // Calculate cosine similarity between the image embeddings
-        double similarity = cosineSimilarity(sourceEmbedding, targetEmbedding);
-        System.out.println("Cosine Similarity: " + similarity);
-    }
+        // Load the pre-trained ResNet model for feature extraction (ResNet50 or ResNet18, etc.)
+        Criteria<Image, NDArray> criteria = Criteria.builder()
+                .setTypes(Image.class, NDArray.class)
+                .optTranslator(new FeatureExtractionTranslator())
+                .build();

+        try (Model model = ModelZoo.loadModel(criteria)) {
+            // Create a predictor for feature extraction
+            try (Predictor<Image, NDArray> predictor = model.newPredictor(new FeatureExtractionTranslator())) {
+                // Extract features for both images
+                NDArray feature1 = predictor.predict(img1);
+                NDArray feature2 = predictor.predict(img2);
+
+                // Compute cosine similarity
+                float similarity = computeCosineSimilarity(feature1, feature2);
+                System.out.println("Cosine Similarity: " + similarity);
+            }
+        }
+    }

-    private static float[] extractImageEmbedding(Model model, Image image) throws ModelException, TranslateException {
-        Translator<Image, float[]> translator = new Translator<Image, float[]>() {
-            @Override
-            public NDList processInput(TranslatorContext ctx, Image input) {
-                NDArray array = input.toNDArray(ctx.getNDManager());
-                return new NDList(array);
-            }
-
-            @Override
-            public float[] processOutput(TranslatorContext ctx, NDList list) {
-                return list.singletonOrThrow().toFloatArray();
-            }
-
-            @Override
-            public Batchifier getBatchifier() {
-                return null; // No batching needed
-            }
-        };
-
-        // Use the predictor to extract image embedding
-        try (Predictor<Image, float[]> predictor = model.newPredictor(translator)) {
-            return predictor.predict(image);
-        }
-    }
+    // Helper function to compute cosine similarity
+    private static float computeCosineSimilarity(NDArray a, NDArray b) {
+        float dotProduct = a.dot(b).getFloat();
+        float normA = a.norm().getFloat();
+        float normB = b.norm().getFloat();
+        return dotProduct / (normA * normB);
+    }

-    // Method to calculate cosine similarity
-    private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
-        double dotProduct = 0.0;
-        double normA = 0.0;
-        double normB = 0.0;
-
-        for (int i = 0; i < vectorA.length; i++) {
-            dotProduct += vectorA[i] * vectorB[i];
-            normA += Math.pow(vectorA[i], 2);
-            normB += Math.pow(vectorB[i], 2);
-        }
-
-        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
-    }
+    // Translator for feature extraction using ResNet
+    private static class FeatureExtractionTranslator implements Translator<Image, NDArray> {
+
+        @Override
+        public NDList processInput(TranslatorContext ctx, Image input) {
+            // Resize, convert to tensor, and normalize the image
+            NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
+            array = array.toType(DataType.FLOAT32, false);
+            array = array.div(255f); // Normalize pixel values to [0, 1]
+            // array = Resize.resize(array, new Shape(224, 224, 3));
+            // array = Normalize.normalize(array, new float[]{0.485f, 0.456f, 0.406f},
+            //         new float[]{0.229f, 0.224f, 0.225f});
+            array = array.transpose(2, 0, 1); // Convert to CHW format for PyTorch
+            return new NDList(array.expandDims(0)); // Add batch dimension
+        }
+
+        @Override
+        public NDArray processOutput(TranslatorContext ctx, NDList list) {
+            return list.singletonOrThrow(); // Extract the model output
+        }
+
+        @Override
+        public Batchifier getBatchifier() {
+            return null; // No batching needed
+        }
+    }
 }
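For quick local testing of the new compareImages method, a minimal caller might look like the sketch below. The class name ImageEmbeddingComparisonDemo and the image paths are placeholders, and it assumes the DJL API plus a PyTorch engine dependency are on the classpath; the method itself prints the cosine similarity to stdout.

import java.io.FileInputStream;
import java.io.InputStream;

import org.apache.jackrabbit.oak.upgrade.ImageEmbeddingComparison;

public class ImageEmbeddingComparisonDemo {

    public static void main(String[] args) throws Exception {
        // Placeholder paths; replace with real image files when running locally.
        try (InputStream first = new FileInputStream("image1.jpg");
             InputStream second = new FileInputStream("image2.jpg")) {
            // Prints "Cosine Similarity: <value>" for the two image embeddings.
            ImageEmbeddingComparison.compareImages(first, second);
        }
    }
}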
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
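Note that the committed Criteria only sets the input/output types and the translator, so which zoo model ModelZoo.loadModel resolves depends on what is available at runtime. A more explicit variant could pin the model; the sketch below is an assumption about the DJL model zoo (artifact id "resnet", filter "layers" = "50", PyTorch engine), not part of this commit, and it reuses the commit's translator via a parameter.

import ai.djl.Application;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDArray;
import ai.djl.repository.zoo.Criteria;
import ai.djl.translate.Translator;

// Sketch only: builds a more explicit Criteria than the one in this commit.
final class PinnedCriteriaSketch {

    static Criteria<Image, NDArray> pinnedCriteria(Translator<Image, NDArray> translator) {
        return Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION) // look only at image models
                .setTypes(Image.class, NDArray.class)
                .optEngine("PyTorch")       // assumes the PyTorch engine dependency is present
                .optArtifactId("resnet")    // assumed zoo artifact id
                .optFilter("layers", "50")  // assumed filter selecting a ResNet-50 variant
                .optTranslator(translator)  // e.g. the commit's FeatureExtractionTranslator
                .build();
    }
}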