Code cleanup driven by R# suggestions:
 - Made `NativeApi` into a `static class` (it's not intended to be instantiated)
 - Moved `LLamaTokenType` enum out into a separate file
 - Made `LLamaSeqId` and `LLamaPos` into `record struct` types, which gives value equality etc. for free (see the short sketch after the file summary below)
martindevans committed Jan 2, 2024
1 parent a408335 commit f860f88
Showing 22 changed files with 126 additions and 140 deletions.
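As a quick sketch of what the `record struct` change provides, using an illustrative stand-in type rather than the library's real `LLamaPos`/`LLamaSeqId` definitions:

```csharp
using System;
using System.Collections.Generic;

var a = new Pos(42);
var b = new Pos(42);
Console.WriteLine(a == b);  // True: equality operators and Equals are generated
Console.WriteLine(a);       // "Pos { Value = 42 }": ToString is generated too
Console.WriteLine(new HashSet<Pos> { a, b }.Count); // 1: GetHashCode agrees with equality

// Illustrative stand-in only; the real types wrap an int the same way but live in LLama/Native.
public record struct Pos(int Value);
```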
2 changes: 1 addition & 1 deletion LLama/Abstractions/IModelParams.cs
@@ -214,7 +214,7 @@ public sealed record MetadataOverride
/// <summary>
/// Get the key being overridden by this override
/// </summary>
public string Key { get; init; }
public string Key { get; }

internal LLamaModelKvOverrideType Type { get; }

9 changes: 2 additions & 7 deletions LLama/Common/ChatHistory.cs
@@ -1,5 +1,4 @@
using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;

@@ -37,6 +36,7 @@ public enum AuthorRole
/// </summary>
public class ChatHistory
{
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };

/// <summary>
/// Chat message representation
@@ -96,12 +96,7 @@ public void AddMessage(AuthorRole authorRole, string content)
/// <returns></returns>
public string ToJson()
{
return JsonSerializer.Serialize(
this,
new JsonSerializerOptions()
{
WriteIndented = true
});
return JsonSerializer.Serialize(this, _jsonOptions);
}

/// <summary>
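For context on the `ChatHistory` change above: `JsonSerializerOptions` caches serialization metadata internally, so reusing a single static instance avoids rebuilding that state on every `ToJson()` call. A minimal standalone sketch of the same pattern (type and member names are illustrative, not repository code):

```csharp
using System.Text.Json;

public class Report
{
    // One shared, reusable options instance for the whole type.
    private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };

    public string Title { get; set; } = "";

    // Reuses the cached options instead of allocating a new instance per call.
    public string ToJson() => JsonSerializer.Serialize(this, _jsonOptions);
}
```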
1 change: 0 additions & 1 deletion LLama/Common/FixedSizeQueue.cs
@@ -2,7 +2,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using LLama.Extensions;

namespace LLama.Common
{
2 changes: 2 additions & 0 deletions LLama/Common/InferenceParams.cs
@@ -18,11 +18,13 @@ public record InferenceParams
/// number of tokens to keep from initial prompt
/// </summary>
public int TokensKeep { get; set; } = 0;

/// <summary>
/// how many new tokens to predict (n_predict); set to -1 to generate indefinitely
/// until the response completes.
/// </summary>
public int MaxTokens { get; set; } = -1;

/// <summary>
/// logit bias for specific tokens
/// </summary>
1 change: 1 addition & 0 deletions LLama/Extensions/DictionaryExtensions.cs
@@ -15,6 +15,7 @@ public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TK

internal static TValue GetValueOrDefaultImpl<TKey, TValue>(IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
{
// ReSharper disable once CanSimplifyDictionaryTryGetValueWithGetValueOrDefault (this is a shim for that method!)
return dictionary.TryGetValue(key, out var value) ? value : defaultValue;
}
}
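As an aside, this helper is a shim for the BCL `GetValueOrDefault` (as the suppression comment notes), presumably because that method is not available for `IReadOnlyDictionary` on all of the library's target frameworks. An equivalent standalone usage sketch with illustrative data:

```csharp
using System;
using System.Collections.Generic;

var scores = new Dictionary<string, int> { ["alice"] = 3 };

// Same fallback behaviour as the shim: use the default when the key is absent.
int alice = scores.TryGetValue("alice", out var a) ? a : 0; // 3
int bob   = scores.TryGetValue("bob",   out var b) ? b : 0; // 0
Console.WriteLine($"{alice}, {bob}");
```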
8 changes: 7 additions & 1 deletion LLama/Grammars/Grammar.cs
@@ -15,7 +15,7 @@ public sealed class Grammar
/// <summary>
/// Index of the initial rule to start from
/// </summary>
public ulong StartRuleIndex { get; set; }
public ulong StartRuleIndex { get; }

/// <summary>
/// The rules which make up this grammar
@@ -121,6 +121,12 @@ private void PrintRule(StringBuilder output, GrammarRule rule)
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;

case LLamaGrammarElementType.END:
case LLamaGrammarElementType.ALT:
case LLamaGrammarElementType.RULE_REF:
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
default:
output.Append("] ");
break;
11 changes: 1 addition & 10 deletions LLama/LLamaContext.cs
@@ -43,7 +43,7 @@ public sealed class LLamaContext
/// <summary>
/// The context params set for this context
/// </summary>
public IContextParams Params { get; set; }
public IContextParams Params { get; }

/// <summary>
/// The native handle, which is used to be passed to the native APIs
@@ -56,15 +56,6 @@ public sealed class LLamaContext
/// </summary>
public Encoding Encoding { get; }

internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;

_logger = logger;
Encoding = @params.Encoding;
NativeHandle = nativeContext;
}

/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
18 changes: 8 additions & 10 deletions LLama/LLamaEmbedder.cs
@@ -12,17 +12,15 @@ namespace LLama
public sealed class LLamaEmbedder
: IDisposable
{
private readonly LLamaContext _ctx;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
public int EmbeddingSize => Context.EmbeddingSize;

/// <summary>
/// LLama Context
/// </summary>
public LLamaContext Context => this._ctx;
public LLamaContext Context { get; }

/// <summary>
/// Create a new embedder, using the given LLamaWeights
@@ -33,7 +31,7 @@ public sealed class LLamaEmbedder
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
@params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params, logger);
Context = weights.CreateContext(@params, logger);
}

/// <summary>
@@ -72,20 +70,20 @@ public float[] GetEmbeddings(string text)
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
{
var embed_inp_array = _ctx.Tokenize(text, addBos);
var embed_inp_array = Context.Tokenize(text, addBos);

// TODO(Rinne): deal with log of prompt

if (embed_inp_array.Length > 0)
_ctx.Eval(embed_inp_array, 0);
Context.Eval(embed_inp_array, 0);

(CI warning on line 78, GitHub Actions Test linux-release and windows-release: 'LLamaContext.Eval(int[], int)' is obsolete: 'use llama_decode() instead')

unsafe
{
var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();

return new Span<float>(embeddings, EmbeddingSize).ToArray();
return embeddings.ToArray();
}
}

@@ -94,7 +92,7 @@ public float[] GetEmbeddings(string text, bool addBos)
/// </summary>
public void Dispose()
{
_ctx.Dispose();
Context.Dispose();
}

}
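For orientation, a rough usage sketch of the embedder after this change. The model path is a placeholder and the `ModelParams` surface is assumed from elsewhere in the library, so treat this as a sketch rather than a verified snippet:

```csharp
using System;
using LLama;
using LLama.Common;

// Sketch only: "path/to/model.gguf" is a placeholder, and ModelParams is assumed
// to implement the IContextParams expected by LLamaEmbedder.
var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var embedder = new LLamaEmbedder(weights, parameters);

float[] embedding = embedder.GetEmbeddings("Hello, world.");
Console.WriteLine($"Embedding dimension: {embedder.EmbeddingSize}");
```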
2 changes: 1 addition & 1 deletion LLama/LLamaWeights.cs
@@ -64,7 +64,7 @@ public sealed class LLamaWeights
/// </summary>
public IReadOnlyDictionary<string, string> Metadata { get; set; }

internal LLamaWeights(SafeLlamaModelHandle weights)
private LLamaWeights(SafeLlamaModelHandle weights)
{
NativeHandle = weights;
Metadata = weights.ReadMetadata();
4 changes: 2 additions & 2 deletions LLama/Native/LLamaKvCacheView.cs
@@ -14,7 +14,7 @@ public struct LLamaKvCacheViewCell
/// May be negative if the cell is not populated.
/// </summary>
public LLamaPos pos;
};
}

/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
@@ -130,7 +130,7 @@ public ref LLamaKvCacheView GetView()
}
}

partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Create an empty KV cache view. (use only for debugging purposes)
4 changes: 2 additions & 2 deletions LLama/Native/LLamaPos.cs
@@ -6,7 +6,7 @@ namespace LLama.Native;
/// Indicates position in a sequence
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaPos
public record struct LLamaPos
{
/// <summary>
/// The raw value
@@ -17,7 +17,7 @@ public struct LLamaPos
/// Create a new LLamaPos
/// </summary>
/// <param name="value"></param>
public LLamaPos(int value)
private LLamaPos(int value)
{
Value = value;
}
4 changes: 2 additions & 2 deletions LLama/Native/LLamaSeqId.cs
@@ -6,7 +6,7 @@ namespace LLama.Native;
/// ID for a sequence in a batch
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaSeqId
public record struct LLamaSeqId
{
/// <summary>
/// The raw value
@@ -17,7 +17,7 @@ public struct LLamaSeqId
/// Create a new LLamaSeqId
/// </summary>
/// <param name="value"></param>
public LLamaSeqId(int value)
private LLamaSeqId(int value)
{
Value = value;
}
12 changes: 12 additions & 0 deletions LLama/Native/LLamaTokenType.cs
@@ -0,0 +1,12 @@
namespace LLama.Native;

public enum LLamaTokenType
{
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
LLAMA_TOKEN_TYPE_CONTROL = 3,
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
LLAMA_TOKEN_TYPE_UNUSED = 5,
LLAMA_TOKEN_TYPE_BYTE = 6,
}
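These values appear to mirror the `llama_token_type` enum in llama.cpp's C header, though that mapping is an assumption rather than something this commit states. A hypothetical helper showing how the enum might be consumed:

```csharp
// Hypothetical helper for illustration only (not part of the commit).
static bool IsControlToken(LLamaTokenType type)
    => type == LLamaTokenType.LLAMA_TOKEN_TYPE_CONTROL;
```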
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.BeamSearch.cs
@@ -3,7 +3,7 @@

namespace LLama.Native;

public partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Type of pointer to the beam_search_callback function.
4 changes: 2 additions & 2 deletions LLama/Native/NativeApi.Grammar.cs
@@ -5,7 +5,7 @@ namespace LLama.Native
{
using llama_token = Int32;

public unsafe partial class NativeApi
public static partial class NativeApi
{
/// <summary>
/// Create a new grammar from the given set of grammar rules
@@ -15,7 +15,7 @@ public unsafe partial class NativeApi
/// <param name="start_rule_index"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
public static extern unsafe IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);

/// <summary>
/// Free all memory from the given SafeLLamaGrammarHandle
(Remaining changed files not shown.)

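Several hunks above change `partial class NativeApi` to `static partial class NativeApi`: repeating the `static` modifier on each partial declaration makes the no-instantiation intent explicit in every file. A minimal sketch of the pattern using a stand-in name rather than the real class:

```csharp
// Illustrative only; the compiler merges all partial declarations into one
// static class, which can never be instantiated.
public static partial class NativeStub
{
    public static int Add(int a, int b) => a + b;
}

public static partial class NativeStub
{
    public static int Square(int x) => x * x;
}
```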