From a3028def9002de8bef500f085feaf0db85422a9e Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 15 Aug 2024 12:20:17 +0100 Subject: [PATCH] - Fixed LLamaEmbedder example - Using `llama_set_embeddings` to toggle on embedding mode, so it no longer needs to be specified in the params --- LLama.Examples/ExampleRunner.cs | 4 ++-- LLama.Examples/Examples/GetEmbeddings.cs | 25 +++++++++++++++++------- LLama.Examples/Examples/QuantizeModel.cs | 4 ++-- LLama.Unittest/LLamaEmbedderTests.cs | 2 -- LLama/LLamaEmbedder.cs | 3 +-- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index e0feae696..019172fd5 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -19,8 +19,8 @@ public class ExampleRunner { "Executor: Stateless mode chat", StatelessModeExecute.Run }, { "Save and Load: chat session", SaveAndLoadSession.Run }, { "Save and Load: state of model and executor", LoadAndSaveState.Run }, - { "LLama Model: Get embeddings", () => Task.Run(GetEmbeddings.Run) }, - { "LLama Model: Quantize", () => Task.Run(QuantizeModel.Run) }, + { "LLama Model: Get embeddings", GetEmbeddings.Run }, + { "LLama Model: Quantize", QuantizeModel.Run }, { "Grammar: Constrain response to json format", GrammarJsonResponse.Run }, { "Kernel Memory: Document Q&A", KernelMemory.Run }, { "Kernel Memory: Save and Load", KernelMemorySaveAndLoad.Run }, diff --git a/LLama.Examples/Examples/GetEmbeddings.cs b/LLama.Examples/Examples/GetEmbeddings.cs index ad844004e..a249a5bc4 100644 --- a/LLama.Examples/Examples/GetEmbeddings.cs +++ b/LLama.Examples/Examples/GetEmbeddings.cs @@ -1,15 +1,21 @@ using LLama.Common; +using LLama.Native; namespace LLama.Examples.Examples { public class GetEmbeddings { - public static void Run() + public static async Task Run() { string modelPath = UserSettings.GetModelPath(); Console.ForegroundColor = ConsoleColor.DarkGray; - var @params = new ModelParams(modelPath) { Embeddings = true }; + var @params = new ModelParams(modelPath) + { + // Embedding models can return one embedding per token, or all of them can be combined ("pooled") into + // one single embedding. Setting PoolingType to "Mean" will combine all of the embeddings using mean average. + PoolingType = LLamaPoolingType.Mean, + }; using var weights = LLamaWeights.LoadFromFile(@params); var embedder = new LLamaEmbedder(weights, @params); @@ -17,12 +23,12 @@ public static void Run() Console.WriteLine( """ This example displays embeddings from a text prompt. - Embeddings are numerical codes that represent information like words, images, or concepts. - These codes capture important relationships between those objects, + Embeddings are vectors that represent information like words, images, or concepts. + These vector capture important relationships between those objects, like how similar words are in meaning or how close images are visually. This allows machine learning models to efficiently understand and process complex data. Embeddings of a text in LLM is sometimes useful, for example, to train other MLP models. - """); // NOTE: this description was AI generated + """); while (true) { @@ -32,8 +38,13 @@ This allows machine learning models to efficiently understand and process comple var text = Console.ReadLine(); Console.ForegroundColor = ConsoleColor.White; - float[] embeddings = embedder.GetEmbeddings(text).Result; - Console.WriteLine($"Embeddings contain {embeddings.Length:N0} floating point values:"); + // Get embeddings for the text + var embeddings = await embedder.GetEmbeddings(text); + + // This should have returned one single embedding vector, because PoolingType was set to Mean above. + var embedding = embeddings.Single(); + + Console.WriteLine($"Embeddings contain {embedding.Length:N0} floating point values:"); Console.ForegroundColor = ConsoleColor.DarkGray; Console.WriteLine(string.Join(", ", embeddings.Take(20)) + ", ..."); Console.WriteLine(); diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs index 233b59678..a1f7ca1bd 100644 --- a/LLama.Examples/Examples/QuantizeModel.cs +++ b/LLama.Examples/Examples/QuantizeModel.cs @@ -1,8 +1,8 @@ -namespace LLama.Examples.Examples +namespace LLama.Examples.Examples { public class QuantizeModel { - public static void Run() + public static async Task Run() { string inputPath = UserSettings.GetModelPath(); diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 267a24206..f48d1ef45 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -26,7 +26,6 @@ private async Task CompareEmbeddings(string modelPath) { ContextSize = 8, Threads = 4, - Embeddings = true, GpuLayerCount = Constants.CIGpuLayerCount, PoolingType = LLamaPoolingType.Mean, }; @@ -74,7 +73,6 @@ private async Task NonPooledEmbeddings(string modelPath) { ContextSize = 8, Threads = 4, - Embeddings = true, GpuLayerCount = Constants.CIGpuLayerCount, PoolingType = LLamaPoolingType.None, }; diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index f48fcf693..ed6240359 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -33,14 +33,13 @@ public sealed class LLamaEmbedder /// public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { - if (!@params.Embeddings) - throw new ArgumentException("Embeddings must be true", nameof(@params)); if (@params.UBatchSize != @params.BatchSize) throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params)); if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true }) throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported"); Context = weights.CreateContext(@params, logger); + NativeApi.llama_set_embeddings(Context.NativeHandle, true); } ///