Merge pull request #964 from stephentoub/meai
Add Microsoft.Extensions.AI support for IChatClient / IEmbeddingGenerator
martindevans authored Nov 1, 2024
2 parents b2c5e3f + 0d7875f commit 3f176be
Showing 6 changed files with 265 additions and 3 deletions.
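In short, this PR adds two Microsoft.Extensions.AI entry points: LLamaEmbedder now implements IEmbeddingGenerator<string, Embedding<float>>, and a new AsChatClient extension adapts any ILLamaExecutor to IChatClient. A minimal sketch of the resulting surface, assuming standard LLamaSharp setup (the model path and parameter values below are placeholders, not taken from this diff):

// Hedged sketch: "model.gguf" is a placeholder; embedding generation
// additionally requires a context created with pooling enabled.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);

// LLamaEmbedder now doubles as an IEmbeddingGenerator<string, Embedding<float>>.
using var embedder = new LLamaEmbedder(weights, parameters);
IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

// Any ILLamaExecutor can now be exposed as an IChatClient.
using var context = weights.CreateContext(parameters);
IChatClient client = new InteractiveExecutor(context).AsChatClient();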
22 changes: 22 additions & 0 deletions LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,6 +1,7 @@
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.AI;
using Xunit.Abstractions;

namespace LLama.Unittest;
@@ -41,6 +42,27 @@ private async Task CompareEmbeddings(string modelPath)
var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, spoon);

var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
Assert.NotNull(generator.Metadata);
Assert.Equal(nameof(LLamaEmbedder), generator.Metadata.ProviderName);
Assert.NotNull(generator.Metadata.ModelId);
Assert.NotEmpty(generator.Metadata.ModelId);
Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Null(generator.GetService<string>());

var embeddings = await generator.GenerateAsync(
[
"The cat is cute",
"The kitten is cute",
"The spoon is not real"
]);
Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.True(embeddings.Usage?.InputTokenCount is 19 or 20);
Assert.True(embeddings.Usage?.TotalTokenCount is 19 or 20);

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
169 changes: 169 additions & 0 deletions LLama/Extensions/LLamaExecutorExtensions.cs
@@ -0,0 +1,169 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Common;
using LLama.Sampling;
using Microsoft.Extensions.AI;

namespace LLama.Abstractions;

/// <summary>
/// Extension methods for the <see cref="ILLamaExecutor" /> interface.
/// </summary>
public static class LLamaExecutorExtensions
{
/// <summary>Gets an <see cref="IChatClient"/> instance for the specified <see cref="ILLamaExecutor"/>.</summary>
/// <param name="executor">The executor.</param>
/// <param name="historyTransform">The <see cref="IHistoryTransform"/> to use to transform an input list messages into a prompt.</param>
/// <param name="outputTransform">The <see cref="ITextStreamTransform"/> to use to transform the output into text.</param>
/// <returns>An <see cref="IChatClient"/> instance for the provided <see cref="ILLamaExecutor" />.</returns>
/// <exception cref="ArgumentNullException"><paramref name="executor"/> is null.</exception>
public static IChatClient AsChatClient(
this ILLamaExecutor executor,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null) =>
new LLamaExecutorChatClient(
executor ?? throw new ArgumentNullException(nameof(executor)),
historyTransform,
outputTransform);

private sealed class LLamaExecutorChatClient(
ILLamaExecutor executor,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null) : IChatClient
{
private static readonly InferenceParams s_defaultParams = new();
private static readonly DefaultSamplingPipeline s_defaultPipeline = new();
private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];
[ThreadStatic]
private static Random? t_random;

private readonly ILLamaExecutor _executor = executor;
private readonly IHistoryTransform _historyTransform = historyTransform ?? new AppendAssistantHistoryTransform();
private readonly ITextStreamTransform _outputTransform = outputTransform ??
new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);

/// <inheritdoc/>
public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));

/// <inheritdoc/>
public void Dispose() { }

/// <inheritdoc/>
public TService? GetService<TService>(object? key = null) where TService : class =>
typeof(TService) == typeof(ILLamaExecutor) ? (TService)_executor :
this as TService;

/// <inheritdoc/>
public async Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

StringBuilder text = new();
await foreach (var token in _outputTransform.TransformAsync(result))
{
text.Append(token);
}

return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
{
CreatedAt = DateTime.UtcNow,
};
}

/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

await foreach (var token in _outputTransform.TransformAsync(result))
{
yield return new()
{
CreatedAt = DateTime.UtcNow,
Role = ChatRole.Assistant,
Text = token,
};
}
}

/// <summary>Format the chat messages into a string prompt.</summary>
private string CreatePrompt(IList<ChatMessage> messages)
{
if (messages is null)
{
throw new ArgumentNullException(nameof(messages));
}

ChatHistory history = new();

if (_executor is not StatefulExecutorBase seb ||
seb.GetStateData() is InteractiveExecutor.InteractiveExecutorState { IsPromptRun: true })
{
foreach (var message in messages)
{
history.AddMessage(
message.Role == ChatRole.System ? AuthorRole.System :
message.Role == ChatRole.Assistant ? AuthorRole.Assistant :
AuthorRole.User,
string.Concat(message.Contents.OfType<TextContent>()));
}
}
else
{
// Stateful executor that has already processed its initial prompt: append only the most recent message.
history.AddMessage(AuthorRole.User, string.Concat(messages.LastOrDefault()?.Contents.OfType<TextContent>() ?? []));
}

return _historyTransform.HistoryToText(history);
}

/// <summary>Convert the chat options to inference parameters.</summary>
private static InferenceParams? CreateInferenceParams(ChatOptions? options)
{
List<string> antiPrompts = new(s_antiPrompts);
if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
{
antiPrompts.AddRange(anti);
}

return new()
{
AntiPrompts = antiPrompts,
TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
SamplingPipeline = new DefaultSamplingPipeline()
{
AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
Seed = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Seed), out uint seed) is true ? seed : (uint)(t_random ??= new()).Next(),
TailFreeZ = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TailFreeZ), out float tfz) is true ? tfz : s_defaultPipeline.TailFreeZ,
Temperature = options?.Temperature ?? 0,
TopP = options?.TopP ?? 0,
TopK = options?.TopK ?? s_defaultPipeline.TopK,
TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
},
};
}

/// <summary>A default transform that appends "Assistant: " to the end.</summary>
private sealed class AppendAssistantHistoryTransform : LLamaTransforms.DefaultHistoryTransform
{
public override string HistoryToText(ChatHistory history) =>
$"{base.HistoryToText(history)}{AuthorRole.Assistant}: ";
}
}
}
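For illustration, a hedged sketch of driving the adapter above; the executor construction is assumed, and the AdditionalProperties keys are exactly the names probed by CreateInferenceParams:

// Sketch only: assumes an ILLamaExecutor instance named executor.
IChatClient client = executor.AsChatClient();

var options = new ChatOptions
{
    MaxOutputTokens = 128,
    Temperature = 0.7f,
    // LLama-specific knobs travel by name through AdditionalProperties,
    // where CreateInferenceParams looks them up.
    AdditionalProperties = new()
    {
        [nameof(InferenceParams.AntiPrompts)] = new[] { "User:" },
    },
};

var completion = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "What is the capital of France?")],
    options);
Console.WriteLine(completion.Message.Text);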
12 changes: 12 additions & 0 deletions LLama/Extensions/SpanNormalizationExtensions.cs
@@ -81,6 +81,18 @@ public static Span<float> EuclideanNormalization(this Span<float> vector)
return vector;
}

/// <summary>
/// Creates a new array containing an L2 normalization of the input vector.
/// </summary>
/// <param name="vector"></param>
/// <returns>The same span</returns>
public static float[] EuclideanNormalization(this ReadOnlySpan<float> vector)
{
var result = new float[vector.Length];
TensorPrimitives.Divide(vector, TensorPrimitives.Norm(vector), result);
return result;
}

/// <summary>
/// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
/// <list type="bullet">
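A small hedged usage sketch of the new ReadOnlySpan<float> overload added above (values illustrative):

// Unlike the in-place Span<float> version, this overload leaves the
// input untouched and returns a freshly allocated, normalized array.
ReadOnlySpan<float> v = stackalloc float[] { 3f, 4f };
float[] unit = v.EuclideanNormalization(); // [0.6, 0.8]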
54 changes: 54 additions & 0 deletions LLama/LLamaEmbedder.EmbeddingGenerator.cs
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using Microsoft.Extensions.AI;

namespace LLama;

public partial class LLamaEmbedder
: IEmbeddingGenerator<string, Embedding<float>>
{
private EmbeddingGeneratorMetadata? _metadata;

/// <inheritdoc />
EmbeddingGeneratorMetadata IEmbeddingGenerator<string, Embedding<float>>.Metadata =>
_metadata ??= new(
nameof(LLamaEmbedder),
modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
dimensions: EmbeddingSize);

/// <inheritdoc />
TService? IEmbeddingGenerator<string, Embedding<float>>.GetService<TService>(object? key) where TService : class =>
typeof(TService) == typeof(LLamaContext) ? (TService)(object)Context :
this as TService;

/// <inheritdoc />
async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
{
if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
{
throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
}

GeneratedEmbeddings<Embedding<float>> results = new()
{
Usage = new() { InputTokenCount = 0 },
};

foreach (var value in values)
{
var (embeddings, tokenCount) = await GetEmbeddingsWithTokenCount(value, cancellationToken).ConfigureAwait(false);
Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding when pooling is enabled.");

results.Usage.InputTokenCount += tokenCount;
results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
}

results.Usage.TotalTokenCount = results.Usage.InputTokenCount;

return results;
}
}
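As a hedged sketch, here is how the embedder might be consumed through the new abstraction (construction of the LLamaEmbedder instance is assumed, with pooling enabled on its context):

// Sketch only: assumes an existing LLamaEmbedder named embedder.
IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

var embeddings = await generator.GenerateAsync(["first sentence", "second sentence"]);

// GeneratedEmbeddings<T> is a list of embeddings plus aggregate usage data.
Console.WriteLine($"{embeddings.Count} vectors, {embeddings.Usage?.InputTokenCount} input tokens");
ReadOnlyMemory<float> vector = embeddings[0].Vector;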
9 changes: 6 additions & 3 deletions LLama/LLamaEmbedder.cs
Expand Up @@ -12,7 +12,7 @@ namespace LLama;
/// <summary>
/// Generate high dimensional embedding vectors from text
/// </summary>
public sealed class LLamaEmbedder
public sealed partial class LLamaEmbedder
: IDisposable
{
/// <summary>
@@ -58,7 +58,10 @@ public void Dispose()
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
/// <exception cref="NotSupportedException"></exception>
public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default) =>
(await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings;

private async Task<(IReadOnlyList<float[]> Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default)
{
// Add all of the tokens to the batch
var tokens = Context.Tokenize(input);
@@ -113,6 +116,6 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)

Context.NativeHandle.KvCacheClear();

return results;
return (results, tokens.Length);
}
}
2 changes: 2 additions & 0 deletions LLama/LLamaSharp.csproj
@@ -49,6 +49,8 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.1" />
<PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
</ItemGroup>
