Merge branch 'SciSharp:master' into Development
SignalRT authored Oct 24, 2023
2 parents cede6b0 + 5b6408b commit 32a375d
Showing 11 changed files with 425 additions and 195 deletions.
5 changes: 3 additions & 2 deletions LLama.Unittest/BeamTests.cs
@@ -32,13 +32,14 @@ public void BasicBeam()
{
const int num_beams = 2;
const int n_predict = 3;
const string prompt = "The cat sat on";

var context = _model.CreateContext(_params);

var result = new StringBuilder();

var initial_tokens = context.Tokenize("The cat sat on");
result.Append(context.DeTokenize(initial_tokens.ToArray()));
var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
context.Eval(initial_tokens, 0);

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
125 changes: 125 additions & 0 deletions LLama.Unittest/TokenTests.cs
@@ -72,4 +72,129 @@ public void TokensNotEndWithNothing()
var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
Assert.False(result);
}

[Fact]
public void TokensEndWith2()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
decoder.AddRange(tokens);

var processor = new AntipromptProcessor(new[]
{
"a fish",
"the mat",
"this is an improbably long query to be using for this method"
});
var result = processor.Add(decoder.Read());

Assert.True(result);
}

[Fact]
public void TokensEndSubstring2()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
decoder.AddRange(tokens);

var processor = new AntipromptProcessor(new[] { "at" });
var result = processor.Add(decoder.Read());

Assert.True(result);
}

[Fact]
public void TokensNotEndWith2()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
decoder.AddRange(tokens);

var processor = new AntipromptProcessor(new[]
{
"a fish",
"The cat sat on the edge of the ma",
"this is an improbably long query to be using for this method"
});
var result = processor.Add(decoder.Read());

Assert.False(result);
}

[Fact]
public void TokensNotEndWithNothing2()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
decoder.AddRange(tokens);

var processor = new AntipromptProcessor();
var result = processor.Add(decoder.Read());

Assert.False(result);
}

[Fact]
public void RoundTrip()
{
var strings = new[]
{
"Hello world",
"철수",
"😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
};

var charsArr = new char[1024];

foreach (var input in strings)
{
// Convert into llama tokens
var tokens = _model.NativeHandle.Tokenize(input, false, false, Encoding.UTF8);

// Convert tokens back into characters
var chars = _model.NativeHandle.TokensToSpan(tokens, charsArr.AsSpan(), Encoding.UTF8);

// llama.cpp adds a space to the start of strings, remove that
var output = new string(chars).TrimStart(' ');

// Check that the input equals the output
Assert.Equal(input, output);
}
}

[Fact]
public void StreamingDecoderRoundTrip()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

var strings = new[]
{
"Hello world",
"철수",
"😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
};

foreach (var input in strings)
{
decoder.Reset();

// Convert into llama tokens
var tokens = _model.NativeHandle.Tokenize(input, false, false, Encoding.UTF8);

// Add tokens to decoder
foreach (var token in tokens)
decoder.Add(token);

// llama.cpp adds a space to the start of strings, remove that
var output = decoder.Read().TrimStart(' ');

// Check that the input equals the output
Assert.Equal(input, output);
}
}
}
66 changes: 66 additions & 0 deletions LLama/AntipromptProcessor.cs
@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;

namespace LLama;

internal sealed class AntipromptProcessor
{
private int _longestAntiprompt;
private readonly List<string> _antiprompts = new();

private string? _string;

public AntipromptProcessor(IEnumerable<string>? antiprompts = null)
{
if (antiprompts != null)
SetAntiprompts(antiprompts);
}

/// <summary>
/// Add an antiprompt to the collection
/// </summary>
/// <param name="antiprompt"></param>
public void AddAntiprompt(string antiprompt)
{
_antiprompts.Add(antiprompt);
_longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
}

/// <summary>
/// Overwrite all current antiprompts with a new set
/// </summary>
/// <param name="antiprompts"></param>
public void SetAntiprompts(IEnumerable<string> antiprompts)
{
_antiprompts.Clear();
_antiprompts.AddRange(antiprompts);

_longestAntiprompt = 0;
foreach (var antiprompt in _antiprompts)
_longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
}

/// <summary>
/// Add some text and check if the buffer now ends with any antiprompt
/// </summary>
/// <param name="text"></param>
/// <returns>true if the text buffer ends with any antiprompt</returns>
public bool Add(string text)
{
_string += text;

// When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
// This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
// even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
var maxLength = Math.Max(32, _longestAntiprompt * 4);
var trimLength = Math.Max(16, _longestAntiprompt * 2);
if (_string.Length > maxLength)
_string = _string.Substring(_string.Length - trimLength);

foreach (var antiprompt in _antiprompts)
if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
return true;

return false;
}
}
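
A minimal usage sketch for the new processor. Since AntipromptProcessor is internal, this assumes code inside the LLama assembly (the unit tests above use it the same way); the "User:" antiprompt and the streamed fragments are invented for illustration:

    using System;
    using LLama;

    var processor = new AntipromptProcessor(new[] { "User:" });

    // Feed streamed text fragments in; Add returns true as soon as the
    // accumulated buffer ends with any registered antiprompt.
    foreach (var fragment in new[] { "Assistant: hi!", " Us", "er:" })
    {
        if (processor.Add(fragment))
        {
            Console.WriteLine("Antiprompt hit - stop generating.");
            break;
        }
    }

Note that Add works on the accumulated text buffer, not on a single fragment, so an antiprompt split across several fragments (as in the sketch) is still detected.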
9 changes: 2 additions & 7 deletions LLama/Extensions/IReadOnlyListExtensions.cs
@@ -36,6 +36,7 @@ internal static class IReadOnlyListExtensions
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
@@ -68,13 +69,6 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
}
}

internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
{
return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding);
}

/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
@@ -83,6 +77,7 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
{
24 changes: 24 additions & 0 deletions LLama/Extensions/ListExtensions.cs
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;

namespace LLama.Extensions
{
internal static class ListExtensions
{
#if NETSTANDARD2_0
public static void EnsureCapacity<T>(this List<T> list, int capacity)
{
if (list.Capacity < capacity)
list.Capacity = capacity;
}
#endif

public static void AddSpan<T>(this List<T> list, ReadOnlySpan<T> items)
{
list.EnsureCapacity(list.Count + items.Length);

for (var i = 0; i < items.Length; i++)
list.Add(items[i]);
}
}
}
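
A quick sketch of AddSpan, which reserves capacity once (via the EnsureCapacity polyfill on netstandard2.0) instead of growing the list element by element. Again the extension is internal, so this assumes code inside the assembly:

    using System;
    using System.Collections.Generic;
    using LLama.Extensions;

    var list = new List<int> { 1, 2 };
    ReadOnlySpan<int> extra = stackalloc int[] { 3, 4, 5 };

    // Pre-sizes the list for all five elements, then appends the span
    list.AddSpan(extra);
    // list now contains 1, 2, 3, 4, 5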
32 changes: 7 additions & 25 deletions LLama/LLamaContext.cs
@@ -102,13 +102,15 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
public string DeTokenize(IEnumerable<llama_token> tokens)
[Obsolete("Use a `StreamingTokenDecoder` instead")]
public string DeTokenize(IReadOnlyList<llama_token> tokens)
{
var sb = new StringBuilder();
foreach (var token in tokens)
NativeHandle.TokenToString(token, Encoding, sb);
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
// It should be kept around for the entire time you are decoding one stream of tokens.

return sb.ToString();
var decoder = new StreamingTokenDecoder(this);
decoder.AddRange(tokens);
return decoder.ToString();

Check warning on line 113 in LLama/LLamaContext.cs (GitHub Actions / Test (windows-release)): Possible null reference return.
}

/// <summary>
@@ -418,26 +420,6 @@ public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
}
#endregion

/// <summary>
/// Convert a token into a string
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public string TokenToString(llama_token token)
{
return NativeHandle.TokenToString(token, Encoding);
}

/// <summary>
/// Append a single token to a string builder
/// </summary>
/// <param name="token">Token to decode</param>
/// <param name="dest">string builder to append the result to</param>
public void TokenToString(llama_token token, StringBuilder dest)
{
NativeHandle.TokenToString(token, Encoding, dest);
}

/// <inheritdoc />
public void Dispose()
{
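The comment added to DeTokenize warns that it is not a model of correct StreamingTokenDecoder use: a decoder should live for the entire token stream, so that characters whose bytes are split across tokens (see the emoji round-trip tests above) decode correctly. A hedged sketch of the intended pattern, where `context` and `tokenStream` are placeholders rather than APIs from this commit:

    using System;
    using LLama;

    // Create the decoder once per stream of tokens...
    var decoder = new StreamingTokenDecoder(context);

    foreach (var token in tokenStream)
    {
        decoder.Add(token);

        // ...and Read() returns whatever text has been completed so far
        Console.Write(decoder.Read());
    }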
5 changes: 1 addition & 4 deletions LLama/LLamaExecutorBase.cs
@@ -294,10 +294,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
await InferInternal(inferenceParams, args);

if (args.ReturnValue)
{
foreach (var id in _embeds)
yield return Context.TokenToString(id);
}
yield return Context.DeTokenize(_embeds);

Check warning on line 297 in LLama/LLamaExecutorBase.cs (GitHub Actions / Test (macos-release) and Test (windows-release)): 'LLamaContext.DeTokenize(IReadOnlyList<int>)' is obsolete: 'Use a `StreamingTokenDecoder` instead'.

var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
Expand Down
11 changes: 8 additions & 3 deletions LLama/LLamaStatelessExecutor.cs
@@ -56,6 +56,9 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
Context.Dispose();
Context = _weights.CreateContext(Context.Params, _logger);

var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>());

if (inferenceParams != null)
{
if (inferenceParams.TokensKeep > Context.ContextSize)
@@ -64,7 +67,6 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?

cancellationToken.ThrowIfCancellationRequested();

var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
inferenceParams ??= new InferenceParams();

var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
@@ -95,13 +97,16 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

lastTokens.Add(id);
yield return Context.TokenToString(id);

decoder.Add(id);
var decoded = decoder.Read();
yield return decoded;

tokens.Clear();
tokens.Add(id);

// Check if any of the antiprompts have been generated
if (lastTokens.TokensEndsWithAnyString(antiprompts, Context))
if (antiprocessor.Add(decoded))
break;

// when run out of context
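Taken together, the stateless executor's new loop shape is: sample a token, decode it incrementally, yield the text, and stop once the accumulated text ends with an antiprompt. A condensed sketch of that shape, where `sampleNext` stands in for the real sampling call and the "User:" antiprompt is invented for illustration:

    using System;
    using System.Collections.Generic;
    using LLama;

    IEnumerable<string> Generate(LLamaContext context, Func<int> sampleNext)
    {
        var decoder = new StreamingTokenDecoder(context);
        var antiprocessor = new AntipromptProcessor(new[] { "User:" });

        while (true)
        {
            var id = sampleNext();      // placeholder for the sampling step
            decoder.Add(id);

            var decoded = decoder.Read();
            yield return decoded;

            // Stop as soon as the generated text ends with any antiprompt
            if (antiprocessor.Add(decoded))
                yield break;
        }
    }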
