Skip to content

Commit

Permalink
- Fixed LLamaEmbedder example
Browse files Browse the repository at this point in the history
 - Using `llama_set_embeddings` to toggle on embedding mode, so it no longer needs to be specified in the params
  • Loading branch information
martindevans committed Aug 27, 2024
1 parent 4f3f0ed commit a3028de
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
4 changes: 2 additions & 2 deletions LLama.Examples/ExampleRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ public class ExampleRunner
{ "Executor: Stateless mode chat", StatelessModeExecute.Run },
{ "Save and Load: chat session", SaveAndLoadSession.Run },
{ "Save and Load: state of model and executor", LoadAndSaveState.Run },
{ "LLama Model: Get embeddings", () => Task.Run(GetEmbeddings.Run) },
{ "LLama Model: Quantize", () => Task.Run(QuantizeModel.Run) },
{ "LLama Model: Get embeddings", GetEmbeddings.Run },
{ "LLama Model: Quantize", QuantizeModel.Run },
{ "Grammar: Constrain response to json format", GrammarJsonResponse.Run },
{ "Kernel Memory: Document Q&A", KernelMemory.Run },
{ "Kernel Memory: Save and Load", KernelMemorySaveAndLoad.Run },
Expand Down
25 changes: 18 additions & 7 deletions LLama.Examples/Examples/GetEmbeddings.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.Examples
{
public class GetEmbeddings
{
public static void Run()
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

Console.ForegroundColor = ConsoleColor.DarkGray;
var @params = new ModelParams(modelPath) { Embeddings = true };
var @params = new ModelParams(modelPath)
{
// Embedding models can return one embedding per token, or all of them can be combined ("pooled") into
// one single embedding. Setting PoolingType to "Mean" will combine all of the embeddings using mean average.
PoolingType = LLamaPoolingType.Mean,
};
using var weights = LLamaWeights.LoadFromFile(@params);
var embedder = new LLamaEmbedder(weights, @params);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine(
"""
This example displays embeddings from a text prompt.
Embeddings are numerical codes that represent information like words, images, or concepts.
These codes capture important relationships between those objects,
Embeddings are vectors that represent information like words, images, or concepts.
These vector capture important relationships between those objects,
like how similar words are in meaning or how close images are visually.
This allows machine learning models to efficiently understand and process complex data.
Embeddings of a text in LLM is sometimes useful, for example, to train other MLP models.
"""); // NOTE: this description was AI generated
""");

while (true)
{
Expand All @@ -32,8 +38,13 @@ This allows machine learning models to efficiently understand and process comple
var text = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;

float[] embeddings = embedder.GetEmbeddings(text).Result;
Console.WriteLine($"Embeddings contain {embeddings.Length:N0} floating point values:");
// Get embeddings for the text
var embeddings = await embedder.GetEmbeddings(text);

// This should have returned one single embedding vector, because PoolingType was set to Mean above.
var embedding = embeddings.Single();

Console.WriteLine($"Embeddings contain {embedding.Length:N0} floating point values:");
Console.ForegroundColor = ConsoleColor.DarkGray;
Console.WriteLine(string.Join(", ", embeddings.Take(20)) + ", ...");
Console.WriteLine();
Expand Down
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/QuantizeModel.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
namespace LLama.Examples.Examples
namespace LLama.Examples.Examples
{
public class QuantizeModel
{
public static void Run()
public static async Task Run()
{
string inputPath = UserSettings.GetModelPath();

Expand Down
2 changes: 0 additions & 2 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ private async Task CompareEmbeddings(string modelPath)
{
ContextSize = 8,
Threads = 4,
Embeddings = true,
GpuLayerCount = Constants.CIGpuLayerCount,
PoolingType = LLamaPoolingType.Mean,
};
Expand Down Expand Up @@ -74,7 +73,6 @@ private async Task NonPooledEmbeddings(string modelPath)
{
ContextSize = 8,
Threads = 4,
Embeddings = true,
GpuLayerCount = Constants.CIGpuLayerCount,
PoolingType = LLamaPoolingType.None,
};
Expand Down
3 changes: 1 addition & 2 deletions LLama/LLamaEmbedder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ public sealed class LLamaEmbedder
/// <param name="logger"></param>
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
if (!@params.Embeddings)
throw new ArgumentException("Embeddings must be true", nameof(@params));
if (@params.UBatchSize != @params.BatchSize)
throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported");

Context = weights.CreateContext(@params, logger);
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
}

/// <inheritdoc />
Expand Down

0 comments on commit a3028de

Please sign in to comment.