From 4f3f0edcef6557ff2424f7960ebde6364885cc18 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Thu, 15 Aug 2024 11:47:04 +0100
Subject: [PATCH] Totally rewrote the LLamaEmbedder, based on
 https://github.com/ggerganov/llama.cpp/tree/master/examples/embedding. The new
 embedder properly handles pooling, returning either one embedding for the
 whole sequence or one per token.

- Added `Encode` methods to `LLamaContext`
- Moved some native methods from `NativeApi` to `SafeLLamaContextHandle` and wrapped them properly
- Added `HasDecoder` property to `SafeLlamaModelHandle`. The underlying native function doesn't exist in the current version of llama.cpp; it will need to be hooked up in the next binary update
- Added some normalization methods as extensions on span/array. This required adding a dependency on `System.Numerics.Tensors`
---
 .../LLamaSharpTextEmbeddingGenerator.cs       |   9 +-
 .../LLamaSharpEmbeddingGeneration.cs          |   5 +-
 LLama.Unittest/LLamaEmbedderTests.cs          |  39 +++-
 .../Extensions/SpanNormalizationExtensions.cs | 126 +++++++++++
 LLama/LLamaContext.cs                         |  22 ++
 LLama/LLamaEmbedder.cs                        | 205 ++++++++----------
 LLama/LLamaSharp.csproj                       |   6 +-
 LLama/Native/EncodeResult.cs                  |  17 ++
 LLama/Native/LLamaBatch.cs                    |   7 +-
 LLama/Native/LLamaPoolingType.cs              |  17 ++
 LLama/Native/NativeApi.cs                     |  24 --
 LLama/Native/SafeLLamaContextHandle.cs        |  86 +++++++-
 12 files changed, 408 insertions(+), 155 deletions(-)
 create mode 100644 LLama/Extensions/SpanNormalizationExtensions.cs
 create mode 100644 LLama/Native/EncodeResult.cs
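As a usage sketch of the reworked API (the model path and prompt are illustrative assumptions; the types and calls are the ones this patch introduces):

    using LLama;
    using LLama.Common;
    using LLama.Native;

    // Hypothetical embedding GGUF - substitute your own model file
    var @params = new ModelParams("all-MiniLM-L12-v2.Q8_0.gguf")
    {
        Embeddings = true,
        // Mean pooling collapses the whole input into one vector
        PoolingType = LLamaPoolingType.Mean,
    };

    using var weights = LLamaWeights.LoadFromFile(@params);
    using var embedder = new LLamaEmbedder(weights, @params);

    // With pooling enabled this list contains exactly one vector;
    // with LLamaPoolingType.None it contains one vector per token.
    IReadOnlyList<float[]> embeddings = await embedder.GetEmbeddings("The quick brown fox");
    Console.WriteLine($"Got {embeddings.Count} vector(s) of dimension {embedder.EmbeddingSize}");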
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
index 2f6e332d5..79eaa1514 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -1,5 +1,6 @@
 using LLama;
 using LLama.Common;
+using LLama.Native;
 using Microsoft.KernelMemory;
 using Microsoft.KernelMemory.AI;
@@ -35,7 +36,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
             GpuLayerCount = config.GpuLayerCount ?? 20,
             Embeddings = true,
             MainGpu = config.MainGpu,
-            SplitMode = config.SplitMode
+            SplitMode = config.SplitMode,
+            PoolingType = LLamaPoolingType.Mean,
         };
         _weights = LLamaWeights.LoadFromFile(@params);
         _embedder = new LLamaEmbedder(_weights, @params);
@@ -59,7 +61,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
             GpuLayerCount = config.GpuLayerCount ?? 20,
             Embeddings = true,
             MainGpu = config.MainGpu,
-            SplitMode = config.SplitMode
+            SplitMode = config.SplitMode,
+            PoolingType = LLamaPoolingType.Mean,
         };
         _weights = weights;
         _embedder = new LLamaEmbedder(_weights, @params);
@@ -92,7 +95,7 @@ public void Dispose()
     public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
     {
         var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
-        return new Embedding(embeddings);
+        return new Embedding(embeddings.First());
     }
 
     /// <inheritdoc/>
diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
index 9514e1711..d50945117 100644
--- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
+++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -4,7 +4,8 @@
 namespace LLamaSharp.SemanticKernel.TextEmbedding;
 
-public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
+public sealed class LLamaSharpEmbeddingGeneration
+    : ITextEmbeddingGenerationService
 {
     private readonly LLamaEmbedder _embedder;
 
@@ -23,7 +24,7 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<s
         var result = new List<ReadOnlyMemory<float>>();
 
         foreach (var item in data)
-            result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
+            result.Add((await _embedder.GetEmbeddings(item, cancellationToken)).First());
 
         return result;
     }
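With mean pooling configured (as the Kernel Memory generator above now does), GetEmbeddings returns exactly one vector per input, which is what the .First() calls rely on. A rough sketch of comparing two pooled embeddings with the same System.Numerics.Tensors package this patch pulls in (the embedder setup is assumed; TensorPrimitives.CosineSimilarity is a real API in that package):

    using System.Numerics.Tensors;

    // Assumes 'embedder' was created with PoolingType = LLamaPoolingType.Mean
    var a = (await embedder.GetEmbeddings("The cat is cute")).First();
    var b = (await embedder.GetEmbeddings("The kitten is cute")).First();

    // Cosine similarity is scale-invariant, so unnormalized vectors are fine here
    float similarity = TensorPrimitives.CosineSimilarity(a, b);
    Console.WriteLine($"similarity = {similarity:F4}");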
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index e9d9359f2..267a24206 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,4 +1,6 @@
 using LLama.Common;
+using LLama.Extensions;
+using LLama.Native;
 using Xunit.Abstractions;
 
 namespace LLama.Unittest;
@@ -26,17 +28,18 @@ private async Task CompareEmbeddings(string modelPath)
             Threads = 4,
             Embeddings = true,
             GpuLayerCount = Constants.CIGpuLayerCount,
+            PoolingType = LLamaPoolingType.Mean,
         };
         using var weights = LLamaWeights.LoadFromFile(@params);
         using var embedder = new LLamaEmbedder(weights, @params);
 
-        var cat = await embedder.GetEmbeddings("The cat is cute");
+        var cat = (await embedder.GetEmbeddings("The cat is cute")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, cat);
 
-        var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
+        var kitten = (await embedder.GetEmbeddings("The kitten is cute")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, kitten);
 
-        var spoon = await embedder.GetEmbeddings("The spoon is not real");
+        var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, spoon);
 
         _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -64,4 +67,34 @@ public async Task EmbedCompareGenerateModel()
     {
         await CompareEmbeddings(Constants.GenerativeModelPath);
     }
+
+    private async Task NonPooledEmbeddings(string modelPath)
+    {
+        var @params = new ModelParams(modelPath)
+        {
+            ContextSize = 8,
+            Threads = 4,
+            Embeddings = true,
+            GpuLayerCount = Constants.CIGpuLayerCount,
+            PoolingType = LLamaPoolingType.None,
+        };
+        using var weights = LLamaWeights.LoadFromFile(@params);
+        using var embedder = new LLamaEmbedder(weights, @params);
+
+        var kitten = await embedder.GetEmbeddings("the kitten is kawaii");
+        foreach (var embd in kitten)
+            Assert.DoesNotContain(float.NaN, embd);
+    }
+
+    [Fact]
+    public async Task EmbeddingModelNonPooledEmbeddings()
+    {
+        await NonPooledEmbeddings(Constants.EmbeddingModelPath);
+    }
+
+    [Fact]
+    public async Task GenerativeModelNonPooledEmbeddings()
+    {
+        await NonPooledEmbeddings(Constants.GenerativeModelPath);
+    }
 }
\ No newline at end of file
diff --git a/LLama/Extensions/SpanNormalizationExtensions.cs b/LLama/Extensions/SpanNormalizationExtensions.cs
new file mode 100644
index 000000000..8ed827b64
--- /dev/null
+++ b/LLama/Extensions/SpanNormalizationExtensions.cs
@@ -0,0 +1,126 @@
+using System;
+using System.Numerics.Tensors;
+
+namespace LLama.Extensions;
+
+/// <summary>
+/// Extensions to span which apply in-place normalization
+/// </summary>
+public static class SpanNormalizationExtensions
+{
+    /// <summary>
+    /// In-place multiply every element by 32760 and divide every element in the array by the max absolute value in the array
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <returns>The same array</returns>
+    public static float[] MaxAbsoluteNormalization(this float[] vector)
+    {
+        vector.AsSpan().MaxAbsoluteNormalization();
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place multiply every element by 32760 and divide every element in the span by the max absolute value in the span
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> MaxAbsoluteNormalization(this Span<float> vector)
+    {
+        var factor = 32760 / TensorPrimitives.MaxMagnitude(vector);
+        TensorPrimitives.Multiply(vector, factor, vector);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element in the array by the sum of absolute values in the array
+    /// </summary>
+    /// <remarks>Also known as "Manhattan normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same array</returns>
+    public static float[] TaxicabNormalization(this float[] vector)
+    {
+        vector.AsSpan().TaxicabNormalization();
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element in the span by the sum of absolute values in the span
+    /// </summary>
+    /// <remarks>Also known as "Manhattan normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> TaxicabNormalization(this Span<float> vector)
+    {
+        var sumAbs = TensorPrimitives.SumOfMagnitudes(vector);
+        TensorPrimitives.Divide(vector, sumAbs, vector);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element by the euclidean length of the vector
+    /// </summary>
+    /// <remarks>Also known as "L2 normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same array</returns>
+    public static float[] EuclideanNormalization(this float[] vector)
+    {
+        vector.AsSpan().EuclideanNormalization();
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element by the euclidean length of the vector
+    /// </summary>
+    /// <remarks>Also known as "L2 normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> EuclideanNormalization(this Span<float> vector)
+    {
+        var norm = TensorPrimitives.Norm(vector);
+        TensorPrimitives.Divide(vector, norm, vector);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
+    /// <list type="bullet">
+    /// <item>For p = 1, this is taxicab normalization</item>
+    /// <item>For p = 2, this is euclidean normalization</item>
+    /// <item>As p approaches infinity, this approaches the maximum (infinity) norm</item>
+    /// </list>
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <param name="p"></param>
+    /// <returns>The same array</returns>
+    public static float[] PNormalization(this float[] vector, int p)
+    {
+        vector.AsSpan().PNormalization(p);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
+    /// <list type="bullet">
+    /// <item>For p = 1, this is taxicab normalization</item>
+    /// <item>For p = 2, this is euclidean normalization</item>
+    /// <item>As p approaches infinity, this approaches the maximum (infinity) norm</item>
+    /// </list>
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <param name="p"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> PNormalization(this Span<float> vector, int p)
+    {
+        if (p == 2)
+            return vector.EuclideanNormalization();
+
+        // The p-norm sums |x|^p; take the absolute value so odd p handles negative elements correctly
+        var sum = 0.0;
+        for (var i = 0; i < vector.Length; i++)
+            sum += MathF.Pow(MathF.Abs(vector[i]), p);
+        var divisor = (float)Math.Pow(sum, 1.0 / p);
+
+        TensorPrimitives.Divide(vector, divisor, vector);
+
+        return vector;
+    }
+}
\ No newline at end of file
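As a quick sanity check of what these helpers compute, a sketch using the extensions defined above (the numbers are just worked arithmetic):

    using LLama.Extensions;

    var v1 = new float[] { 3f, 4f };
    v1.EuclideanNormalization();   // length = 5, so v1 becomes [0.6, 0.8]

    var v2 = new float[] { 3f, 4f };
    v2.TaxicabNormalization();     // |3| + |4| = 7, so v2 becomes [~0.429, ~0.571]

    var v3 = new float[] { 3f, 4f };
    v3.PNormalization(1);          // p = 1 gives the same result as taxicab normalization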
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index ca38d49e4..9dd2a6394 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -379,6 +379,28 @@ public bool ShouldAddBosToken()
     }
 
     #region eval overloads
+    /// <summary>
+    /// Encode a batch of tokens with the encoder part of an encoder-decoder model
+    /// </summary>
+    public EncodeResult Encode(LLamaBatch batch)
+    {
+        if (batch.TokenCount == 0)
+            return EncodeResult.Ok;
+        if (batch.TokenCount > BatchSize)
+            throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+        return (EncodeResult)NativeHandle.Encode(batch);
+    }
+
+    /// <summary>
+    /// Encode a batch of tokens with the encoder part of an encoder-decoder model, running on a background thread
+    /// </summary>
+    /// <param name="batch"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    public Task<EncodeResult> EncodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+    {
+        return Task.Run(() => Encode(batch), cancellationToken);
+    }
+
     /// <summary>
     ///
     /// </summary>
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index d050707e8..f48fcf693 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -1,138 +1,119 @@
-using LLama.Native;
 using System;
-using LLama.Exceptions;
-using LLama.Abstractions;
-using Microsoft.Extensions.Logging;
+using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Exceptions;
+using LLama.Native;
+using Microsoft.Extensions.Logging;
 
-namespace LLama
+namespace LLama;
+
+/// <summary>
+/// Generate high dimensional embedding vectors from text
+/// </summary>
+public sealed class LLamaEmbedder
+    : IDisposable
 {
     /// <summary>
-    /// The embedder for LLama, which supports getting embeddings from text.
+    /// Dimension of embedding vectors
     /// </summary>
-    public sealed class LLamaEmbedder
-        : IDisposable
-    {
-        /// <summary>
-        /// Dimension of embedding vectors
-        /// </summary>
-        public int EmbeddingSize => Context.EmbeddingSize;
-
-        /// <summary>
-        /// LLama Context
-        /// </summary>
-        public LLamaContext Context { get; }
-
-        /// <summary>
-        /// Create a new embedder, using the given LLamaWeights
-        /// </summary>
-        /// <param name="weights"></param>
-        /// <param name="params"></param>
-        /// <param name="logger"></param>
-        public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
-        {
-            if (!@params.Embeddings)
-                throw new ArgumentException("Embeddings must be true", nameof(@params));
-
-            Context = weights.CreateContext(@params, logger);
-        }
-
-        /// <summary>
-        /// Get the embeddings of the text.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
-        {
-            return GetEmbeddings(text, true, cancellationToken);
-        }
+    public int EmbeddingSize => Context.EmbeddingSize;
 
-        /// <summary>
-        /// Get the embeddings of the text.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="addBos">Add bos to the text.</param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
-        {
-            var tokens = Context.Tokenize(text, addBos);
-            if (tokens.Length > Context.ContextSize)
-                throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));
-
-            // Evaluate prompt in batch-size chunks
-            var n_past = 0;
-            var batch = new LLamaBatch();
-            var batchSize = (int)Context.BatchSize;
-            for (var i = 0; i < tokens.Length; i += batchSize)
-            {
-                var n_eval = tokens.Length - i;
-                if (n_eval > batchSize)
-                    n_eval = batchSize;
-
-                batch.Clear();
-                batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true);
-                n_past += n_eval;
+    /// <summary>
+    /// LLama Context
+    /// </summary>
+    public LLamaContext Context { get; }
 
-                var returnCode = await Context.DecodeAsync(batch, cancellationToken);
-                if (returnCode != 0)
-                    throw new LLamaDecodeError(returnCode);
-            }
+    /// <summary>
+    /// Create a new embedder, using the given LLamaWeights
+    /// </summary>
+    /// <param name="weights"></param>
+    /// <param name="params"></param>
+    /// <param name="logger"></param>
+    public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
+    {
+        if (!@params.Embeddings)
+            throw new ArgumentException("Embeddings must be true", nameof(@params));
+        if (@params.UBatchSize != @params.BatchSize)
+            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
+        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
+            throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported");
+
+        Context = weights.CreateContext(@params, logger);
+    }
 
-            var embeddings = GetEmbeddingsArray();
+    /// <inheritdoc />
+    public void Dispose()
+    {
+        Context.Dispose();
+    }
 
-            // Remove everything we just evaluated from the context cache
-            Context.NativeHandle.KvCacheClear();
+    /// <summary>
+    /// Get high dimensional embedding vectors for the given text. Depending on the pooling type used when constructing
+    /// this <see cref="LLamaEmbedder"/> this may return an embedding vector per token, or one single embedding vector for the entire string.
+    /// </summary>
+    /// <remarks>
+    /// Embedding vectors are not normalized, consider using one of the extensions in <see cref="SpanNormalizationExtensions"/>.
+    /// </remarks>
+    /// <param name="input"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="RuntimeError"></exception>
+    public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
+    {
+        // Add all of the tokens to the batch
+        var tokens = Context.Tokenize(input);
+        var batch = new LLamaBatch();
+        for (var i = 0; i < tokens.Length; i++)
+            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);
 
-            // Normalize the embeddings vector
-            // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
-            Normalize(embeddings);
+        // clear previous kv_cache values
+        Context.NativeHandle.KvCacheClear();
 
-            return embeddings;
-        }
+        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
+        cancellationToken.ThrowIfCancellationRequested();
 
-        private float[] GetEmbeddingsArray()
+        // Run model
+        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
         {
-            unsafe
+            case (true, false):
             {
-                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
-
-                if (embeddings == null)
-                    embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
-
-                if (embeddings == null)
-                    return [ ];
+                var result = await Context.EncodeAsync(batch, cancellationToken);
+                if (result != EncodeResult.Ok)
+                    throw new RuntimeError($"Failed to encode: {result}");
+                break;
+            }
 
-                return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
+            case (false, true):
+            {
+                var result = await Context.DecodeAsync(batch, cancellationToken);
+                if (result != DecodeResult.Ok)
+                    throw new RuntimeError($"Failed to decode: {result}");
+                break;
             }
+
+            default:
+                throw new NotSupportedException("Unsupported model type");
         }
 
-        private static void Normalize(Span<float> embeddings)
+        // Extract results
+        var poolingType = Context.NativeHandle.PoolingType;
+        var resultsCount = poolingType == LLamaPoolingType.None ? tokens.Length : 1;
+        var results = new List<float[]>(resultsCount);
+
+        if (poolingType == LLamaPoolingType.None)
         {
-            // Calculate length
-            var lengthSqr = 0.0;
-            foreach (var value in embeddings)
-                lengthSqr += value * value;
-            var length = (float)Math.Sqrt(lengthSqr);
-
-            // Do not divide by length if it is zero
-            if (length <= float.Epsilon)
-                return;
-
-            // Normalize
-            for (var i = 0; i < embeddings.Length; i++)
-                embeddings[i] /= length;
+            var positions = batch.GetLogitPositions();
+            foreach (var (_, pos) in positions)
+                results.Add(Context.NativeHandle.GetEmbeddingsIth(pos).ToArray());
         }
-
-        /// <inheritdoc />
-        public void Dispose()
+        else
         {
-            Context.Dispose();
+            results.Add(Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero).ToArray());
         }
+
+        Context.NativeHandle.KvCacheClear();
+
+        return results;
+    }
-}
+}
\ No newline at end of file
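With PoolingType.None the list returned by GetEmbeddings above has one entry per token. A small sketch of consuming that mode (the embedder setup and prompt are assumed):

    // Assumes 'embedder' was constructed with PoolingType = LLamaPoolingType.None
    var perToken = await embedder.GetEmbeddings("The quick brown fox");
    for (var i = 0; i < perToken.Count; i++)
    {
        float[] vector = perToken[i];
        Console.WriteLine($"token {i}: {vector.Length} dims, first = {vector[0]:F4}");
    }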
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index addda27f2..7b6e11c50 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -50,6 +50,7 @@
+    <PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
@@ -83,11 +84,6 @@
         OverwriteReadOnlyFiles="true"
         Include="*.dll;*.so;*.dylib;*.metal;" />
-
diff --git a/LLama/Native/EncodeResult.cs b/LLama/Native/EncodeResult.cs
new file mode 100644
index 000000000..31bafc098
--- /dev/null
+++ b/LLama/Native/EncodeResult.cs
@@ -0,0 +1,17 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_encode
+/// </summary>
+public enum EncodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index 91a2fafbc..c66bd9277 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -281,11 +281,8 @@ public void Clear()
     /// <summary>
     /// Get the positions where logits can be sampled from
    /// </summary>
     /// <returns></returns>
-    internal Span<(LLamaSeqId, int)> GetLogitPositions(Span<(LLamaSeqId, int)> dest)
+    internal IReadOnlyList<(LLamaSeqId, int)> GetLogitPositions()
     {
-        for (var i = 0; i < _logitPositions.Count; i++)
-            dest[i] = _logitPositions[i];
-
-        return dest.Slice(0, _logitPositions.Count);
+        return _logitPositions;
     }
 }
\ No newline at end of file
diff --git a/LLama/Native/LLamaPoolingType.cs b/LLama/Native/LLamaPoolingType.cs
index 31c615d7e..ab0b75457 100644
--- a/LLama/Native/LLamaPoolingType.cs
+++ b/LLama/Native/LLamaPoolingType.cs
@@ -1,3 +1,5 @@
+using LLama.Abstractions;
+
 namespace LLama.Native;
 
 /// <summary>
@@ -6,9 +8,24 @@ namespace LLama.Native;
 /// <remarks>llama_pooling_type</remarks>
 public enum LLamaPoolingType
 {
+    /// <summary>
+    /// No specific pooling type. Use the model default if this is specified in <see cref="IContextParams.PoolingType"/>
+    /// </summary>
     Unspecified = -1,
+
+    /// <summary>
+    /// Do not pool embeddings (per-token embeddings)
+    /// </summary>
     None = 0,
+
+    /// <summary>
+    /// Take the mean of every token embedding
+    /// </summary>
     Mean = 1,
+
+    /// <summary>
+    /// Return the embedding for the special "CLS" token
+    /// </summary>
     CLS = 2,
 
     Last = 3,
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 8d967e670..46ec79813 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -131,30 +131,6 @@ public static void llama_empty_call()
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
     public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);
 
-    /// <summary>
-    /// Get the pooling type for this context
-    /// </summary>
-    /// <param name="ctx"></param>
-    /// <returns></returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);
-
-    /// <summary>
-    /// Get the embeddings for the a specific sequence.
-    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
-    /// </summary>
-    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
-
-    /// <summary>
-    /// Get the embeddings for the ith sequence.
-    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
-    /// </summary>
-    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
-
     /// <summary>
     /// Get all output token embeddings.
     /// When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, the embeddings for which
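A sketch of how the pooling choice changes what comes back, mirroring the dispatch LLamaEmbedder performs (the context setup is assumed; the helper name is hypothetical):

    using LLama.Native;

    static int ExpectedVectorCount(SafeLLamaContextHandle ctx, int tokenCount)
    {
        // None -> per-token embeddings; Mean/CLS/Last -> one pooled vector per sequence
        return ctx.PoolingType == LLamaPoolingType.None ? tokenCount : 1;
    }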
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index b5932aa04..dee74f590 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Text;
 using LLama.Exceptions;
 
@@ -62,6 +63,11 @@ public uint BatchThreads
         set => llama_set_n_threads(this, GenerationThreads, value);
     }
 
+    /// <summary>
+    /// Get the pooling type for this context
+    /// </summary>
+    public LLamaPoolingType PoolingType => llama_pooling_type(this);
+
     /// <summary>
     /// Get the model which this context is using
     /// </summary>
@@ -169,7 +175,7 @@ static SafeLLamaContextHandle()
     private static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
 
     /// <summary>
-    /// Processes a batch of tokens with the ecoder part of the encoder-decoder model. Stores the encoder output
+    /// Processes a batch of tokens with the encoder part of the encoder-decoder model. Stores the encoder output
     /// internally for later use by the decoder cross-attention layers.
     /// </summary>
@@ -365,6 +371,30 @@ static SafeLLamaContextHandle()
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_lora_adapter_clear(SafeLLamaContextHandle context);
 
+    /// <summary>
+    /// Get the pooling type for this context
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <returns></returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);
+
+    /// <summary>
+    /// Get the embeddings for a specific sequence.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
+
+    /// <summary>
+    /// Get the embeddings for the ith sequence.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
 #endregion
 
 #region LoRA
@@ -410,6 +440,7 @@ public void ClearLoraAdapters()
     }
 #endregion
 
+    #region GetLogits
     /// <summary>
     /// Token logits obtained from the last call to llama_decode
     /// The logits for the last token are stored in the last row
@@ -444,6 +475,43 @@ public Span<float> GetLogitsIth(int i)
             return new Span<float>(logits, model.VocabCount);
         }
     }
+    #endregion
+
+    #region GetEmbeddings()
+    /// <summary>
+    /// Get the embeddings for the ith sequence.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    public Span<float> GetEmbeddingsIth(LLamaPos pos)
+    {
+        var model = ThrowIfDisposed();
+
+        unsafe
+        {
+            var embd = llama_get_embeddings_ith(this, pos.Value);
+            Debug.Assert(embd != null);
+            return new Span<float>(embd, model.EmbeddingSize);
+        }
+    }
+
+    /// <summary>
+    /// Get the embeddings for a specific sequence.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    public Span<float> GetEmbeddingsSeq(LLamaSeqId seq)
+    {
+        var model = ThrowIfDisposed();
+
+        unsafe
+        {
+            var embd = llama_get_embeddings_seq(this, seq);
+            Debug.Assert(embd != null);
+            return new Span<float>(embd, model.EmbeddingSize);
+        }
+    }
+    #endregion
 
     #region tokens
     /// <summary>
@@ -495,6 +563,22 @@ public void Synchronize()
         llama_synchronize(this);
     }
 
+    /// <summary>
+    /// Processes a batch of tokens with the encoder part of the encoder-decoder model. Stores the encoder output
+    /// internally for later use by the decoder cross-attention layers.
+    /// </summary>
+    /// <param name="batch"></param>
+    /// <returns>0 = success<br />&lt; 0 = error</returns>
+    public DecodeResult Encode(LLamaBatch batch)
+    {
+        if (batch.TokenCount == 0)
+            return DecodeResult.Ok;
+
+        lock (GlobalInferenceLock)
+        using (batch.ToNativeBatch(out var nb))
+            return (DecodeResult)llama_encode(this, nb);
+    }
+
     /// <summary>
     ///
     /// </summary>
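
To round this off, a rough sketch of the encode-then-read-embeddings flow these handle methods enable, following the same steps LLamaEmbedder takes internally (an encoder-only model and an existing 'context' with Embeddings = true are assumed):

    var tokens = context.Tokenize("some text");
    var batch = new LLamaBatch();
    for (var i = 0; i < tokens.Length; i++)
        batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

    var result = await context.EncodeAsync(batch);
    if (result != EncodeResult.Ok)
        throw new RuntimeError($"Failed to encode: {result}");

    // With a pooling mode active, read the single pooled vector for sequence zero
    float[] pooled = context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero).ToArray();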