Skip to content

Commit f72c9d2

Browse files
authored
Reduce Tiktoken Creation Memory Allocation (#7202)
1 parent 34eb579 commit f72c9d2

File tree

3 files changed

+60
-12
lines changed

3 files changed

+60
-12
lines changed

src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,37 @@
3232
<Files ParameterType="Microsoft.Build.Framework.ITaskItem[]" Required="true" />
3333
</ParameterGroup>
3434
<Task>
35+
<Using Namespace="System.Globalization" />
3536
<Using Namespace="System.IO" />
3637
<Using Namespace="System.IO.Compression" />
3738
<Code Type="Fragment" Language="cs">
3839
<![CDATA[
39-
foreach(var file in Files)
40+
foreach (var file in Files)
4041
{
41-
using var sourceStream = File.OpenRead(file.GetMetadata("FullPath"));
42+
string fileName = file.GetMetadata("FullPath");
43+
string fileContent = File.ReadAllText(fileName);
44+
int capacity = 1;
45+
int eolIndex = 0;
46+
do
47+
{
48+
if ((eolIndex = fileContent.IndexOf('\n', eolIndex)) >= 0)
49+
{
50+
eolIndex++;
51+
capacity++;
52+
}
53+
else
54+
{
55+
break;
56+
}
57+
} while (eolIndex < fileContent.Length);
58+
59+
using var sourceStream = File.OpenRead(fileName);
4260
using var reader = new StreamReader(sourceStream);
4361
using var destStream = new DeflateStream(File.Create(file.GetMetadata("Destination")), CompressionLevel.Optimal);
4462
using var streamWriter = new StreamWriter(destStream);
4563
64+
streamWriter.WriteLine($"Capacity: {capacity.ToString(CultureInfo.InvariantCulture)}");
65+
4666
string line;
4767
int destLineNumber = 0;
4868

src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,37 @@ private void CacheSpecialTokensEncoding(IReadOnlyDictionary<string, int>? specia
156156
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, (int Id, string Token)>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTiktokenBpeAsync(
157157
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
158158
{
159-
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
160-
var vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>();
161-
var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();
159+
Dictionary<ReadOnlyMemory<byte>, int> encoder;
160+
Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab;
161+
Dictionary<int, ReadOnlyMemory<byte>> decoder;
162162

163163
try
164164
{
165165
// Don't dispose the reader as it will dispose the underlying stream vocabStream. The caller is responsible for disposing the stream.
166166
StreamReader reader = new StreamReader(vocabStream);
167-
string? line;
168-
do
167+
string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
168+
169+
const string capacity = "Capacity: ";
170+
int suggestedCapacity = 0; // default capacity
171+
if (line is not null && line.StartsWith(capacity, StringComparison.Ordinal))
169172
{
170-
line = useAsync ?
171-
await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
172-
reader.ReadLine();
173-
} while (line is not null && line.Length == 0);
173+
if (!Helpers.TryParseInt32(line, capacity.Length, out suggestedCapacity))
174+
{
175+
throw new FormatException($"Invalid format in the BPE vocab file stream");
176+
}
177+
178+
line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
179+
}
180+
181+
encoder = new Dictionary<ReadOnlyMemory<byte>, int>(suggestedCapacity, ReadOnlyMemoryByteComparer.Instance);
182+
vocab = new Dictionary<StringSpanOrdinalKey, (int Id, string Token)>(suggestedCapacity);
183+
decoder = new Dictionary<int, ReadOnlyMemory<byte>>(suggestedCapacity);
184+
185+
// skip empty lines
186+
while (line is not null && line.Length == 0)
187+
{
188+
line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
189+
}
174190

175191
if (line is not null && line.IndexOf(' ') < 0)
176192
{

test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
using System.Reflection;
1313
using System.Text;
1414
using System.Text.Json;
15-
using System.Text.Json.Serialization;
1615
using System.Threading.Tasks;
1716
using Xunit;
1817

@@ -501,9 +500,22 @@ public void TestCreationUsingModel(string modelName)
501500
{
502501
RemoteExecutor.Invoke(static (name) =>
503502
{
503+
#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
504+
long allocation = GC.GetAllocatedBytesForCurrentThread();
505+
#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
506+
504507
Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(name);
505508
Assert.True(tokenizer is TiktokenTokenizer);
506509
Assert.NotNull(tokenizer.PreTokenizer);
510+
511+
#if NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
512+
int entriesCount = GetEncoder((tokenizer as TiktokenTokenizer)!)!.Count;
513+
allocation = GC.GetAllocatedBytesForCurrentThread() - allocation;
514+
515+
// entriesCount * 260 is average memory allocation during the initialization for the models we carry data files for.
516+
// this allocation is not the size of the cache but it includes all temporary allocations during the initialization.
517+
Assert.True((entriesCount * 260) > allocation, $"Memory allocation of {entriesCount} entries for {name}: {allocation} bytes");
518+
#endif // NET8_0_OR_GREATER || NETFRAMEWORK_4_8_OR_GREATER
507519
}, modelName).Dispose();
508520
}
509521

0 commit comments

Comments
 (0)