
Commit

Improve mono data processing performance
zhongkaifu committed Dec 14, 2023
1 parent b3185fc commit 90c3563
Showing 1 changed file with 102 additions and 64 deletions.
166 changes: 102 additions & 64 deletions Seq2SeqSharp/Corpus/MonoCorpus.cs
@@ -18,9 +18,18 @@
 using System.IO.MemoryMappedFiles;
 using System.Linq;
 using System.Threading;
+using System.Threading.Tasks;

 namespace Seq2SeqSharp.Tools
 {
+
+    public class IndexData
+    {
+        public Dictionary<long, LinkedList<long>> len2offsets = new Dictionary<long, LinkedList<long>>();
+        public Dictionary<long, long> len2lengths = new Dictionary<long, long>();
+        public string filePath;
+    }
+
     public class MonoCorpus<T> : ICorpus<T> where T : ISntPairBatch, new()
     {
         internal int m_maxTgtTokenSize = 32;
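The new `IndexData` class above bundles what `BuildIndex` previously returned as a flat tuple: a length-to-offsets index, per-length record counts, and the path of a per-file temp binary file, one instance per input file. A small sketch of consuming the new `IndexData[]` return shape (the stub data below is hypothetical, not from the commit):

```csharp
using System;
using System.Collections.Generic;

// Trimmed copy of the commit's DTO, enough to show the new return shape.
public class IndexData
{
    public Dictionary<long, LinkedList<long>> len2offsets = new Dictionary<long, LinkedList<long>>();
    public Dictionary<long, long> len2lengths = new Dictionary<long, long>();
    public string filePath;
}

class IndexDataSketch
{
    // Stand-in for BuildIndex(): two fake per-file indexes (hypothetical data).
    static IndexData[] BuildIndexStub()
    {
        var a = new IndexData { filePath = "shard0.tmp" };
        a.len2lengths[5] = 120; a.len2lengths[9] = 40;
        var b = new IndexData { filePath = "shard1.tmp" };
        b.len2lengths[5] = 80;
        return new[] { a, b };
    }

    static void Main()
    {
        IndexData[] indexDatas = BuildIndexStub();
        long totalRecords = 0;
        foreach (IndexData d in indexDatas)
            foreach (var kv in d.len2lengths)   // per-length record counts for one file
                totalRecords += kv.Value;
        Console.WriteLine($"{indexDatas.Length} temp files, {totalRecords} records");
    }
}
```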
@@ -102,24 +111,31 @@ public List<Dictionary<string, long>> CountTokenFreqs()
         }


-        private (Dictionary<long, LinkedList<long>>, Dictionary<long, long>, string) BuildIndex()
+        private IndexData[] BuildIndex()
         {
             Logger.WriteLine(Logger.Level.debug, $"Start to build index for data set.");

             SortedDictionary<int, int> dictTgtLenDist = new SortedDictionary<int, int>();
             int corpusSize = 0;
             int tooLongTgtSntCnt = 0;
-            string randomFileName = Path.GetRandomFileName();
-            Logger.WriteLine($"Loading and shuffling corpus from '{m_tgtFileList.Count}' files.");
-
-            string binaryDataSetFilePath = randomFileName + ".tmp";
-            BinaryWriter bw = new BinaryWriter(new FileStream(binaryDataSetFilePath, FileMode.Create));
+            Logger.WriteLine($"Loading and shuffling corpus from '{m_tgtFileList.Count}' files.");
+            object locker = new object();

-            Dictionary<long, LinkedList<long>> len2offsets = new Dictionary<long, LinkedList<long>>();
-            Dictionary<long, long> len2lengths = new Dictionary<long, long>();
+            IndexData[] indexDatas = new IndexData[m_tgtFileList.Count];

-            for (int i = 0; i < m_tgtFileList.Count; i++)
+            Parallel.For(0, m_tgtFileList.Count, i =>
             {
+                string randomFileName = m_tgtFileList[i];
+                string binaryDataSetFilePath = randomFileName + ".tmp";
+                BinaryWriter bw = new BinaryWriter(new FileStream(binaryDataSetFilePath, FileMode.Create));
+                Dictionary<long, LinkedList<long>> len2offsets = new Dictionary<long, LinkedList<long>>();
+                Dictionary<long, long> len2lengths = new Dictionary<long, long>();
                 StreamReader srTgt = new StreamReader(m_tgtFileList[i]);
                 while (true)
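This hunk is the heart of the change: the sequential per-file `for` loop becomes a `Parallel.For`, and the temp file, `BinaryWriter`, and index dictionaries move inside the loop body so each file is indexed with purely private state. A minimal self-contained sketch of that pattern, with hypothetical shard files standing in for `m_tgtFileList`:

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;

class ParallelIndexSketch
{
    static void Main()
    {
        // Hypothetical stand-ins for the corpus shards (m_tgtFileList in the commit).
        string[] inputFiles = { "shard0.txt", "shard1.txt" };
        File.WriteAllLines(inputFiles[0], new[] { "a b c", "d e" });
        File.WriteAllLines(inputFiles[1], new[] { "f g h i" });

        // One result slot per file: distinct slots need no lock.
        var len2offsets = new Dictionary<long, LinkedList<long>>[inputFiles.Length];

        Parallel.For(0, inputFiles.Length, i =>
        {
            // Per-iteration state, as in the commit: each file gets its own
            // temp binary file, writer, and index dictionary.
            var index = new Dictionary<long, LinkedList<long>>();
            using (var bw = new BinaryWriter(File.Create(inputFiles[i] + ".tmp")))
            {
                foreach (string line in File.ReadLines(inputFiles[i]))
                {
                    long offset = bw.BaseStream.Position;  // where this record starts
                    long len = line.Split(' ').Length;     // bucket by token count
                    if (!index.TryGetValue(len, out var offsets))
                        index[len] = offsets = new LinkedList<long>();
                    offsets.AddLast(offset);
                    bw.Write(line);
                }
            }
            len2offsets[i] = index;
        });

        Console.WriteLine($"Indexed {inputFiles.Length} shards in parallel.");
    }
}
```

Writing each result into its own slot of a pre-sized array is what lets the workers finish without taking any lock on the results.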
@@ -137,22 +153,19 @@ public List<Dictionary<string, long>> CountTokenFreqs()
                     if (m_showTokenDist)
                     {
-                        if (dictTgtLenDist.ContainsKey(rawSntPair.TgtTokenSize / 100) == false)
-                        {
-                            dictTgtLenDist.Add(rawSntPair.TgtTokenSize / 100, 0);
-                        }
-                        dictTgtLenDist[rawSntPair.TgtTokenSize / 100]++;
+                        lock (locker)
+                        {
+                            if (dictTgtLenDist.ContainsKey(rawSntPair.TgtTokenSize / 100) == false)
+                            {
+                                dictTgtLenDist.Add(rawSntPair.TgtTokenSize / 100, 0);
+                            }
+                            dictTgtLenDist[rawSntPair.TgtTokenSize / 100]++;
+                        }
                     }

-                    bool hasTooLongSent = false;
                     if (rawSntPair.TgtTokenSize > m_maxTgtTokenSize)
                     {
                         Interlocked.Increment(ref tooLongTgtSntCnt);
-                        hasTooLongSent = true;
-                    }
-
-                    if (hasTooLongSent)
-                    {
                         continue;
                     }
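The state that is still shared across workers gets the cheapest suitable guard here: the `SortedDictionary` histogram (which is not thread-safe) goes behind a `lock`, while the plain `int` counter uses `Interlocked.Increment`. A small sketch of that split, with illustrative names:

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

class CounterSketch
{
    static readonly object s_locker = new object();
    static readonly SortedDictionary<int, int> s_lenDist = new SortedDictionary<int, int>();
    static int s_tooLong = 0;

    static void Record(int tokenCount, int maxLen)
    {
        lock (s_locker) // SortedDictionary mutations must be serialized
        {
            int bucket = tokenCount / 100;
            if (!s_lenDist.ContainsKey(bucket))
                s_lenDist.Add(bucket, 0);
            s_lenDist[bucket]++;
        }

        if (tokenCount > maxLen)
            Interlocked.Increment(ref s_tooLong); // atomic, no lock needed for an int
    }

    static void Main()
    {
        Parallel.For(0, 10_000, i => Record(tokenCount: i % 500, maxLen: 128));
        Console.WriteLine($"buckets={s_lenDist.Count}, tooLong={s_tooLong}");
    }
}
```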
@@ -181,9 +194,21 @@ public List<Dictionary<string, long>> CountTokenFreqs()
                 }
                 srTgt.Close();
-            }

-            bw.Close();
+                bw.Close();
+                indexDatas[i] = new IndexData();
+                indexDatas[i].len2lengths = len2lengths;
+                indexDatas[i].len2offsets = len2offsets;
+                indexDatas[i].filePath = binaryDataSetFilePath;
+            });
+
+
+


             Logger.WriteLine(Logger.Level.debug, $"Shuffled '{corpusSize}' sentence pairs.");

@@ -220,7 +245,7 @@ public List<Dictionary<string, long>> CountTokenFreqs()

             Logger.WriteLine(Logger.Level.debug, $"Finished to build index for data set.");

-            return (len2offsets, len2lengths, binaryDataSetFilePath);
+            return indexDatas;
         }
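For context on what these indexes hold: `BinaryWriter.Write(string)` writes a length-prefixed record, so capturing `BaseStream.Position` before each write produces an offset that a `BinaryReader` can later `Seek` to and `ReadString` from. That write-then-seek round trip, sketched with a throwaway file:

```csharp
using System;
using System.Collections.Generic;
using System.IO;

class OffsetIndexSketch
{
    static void Main()
    {
        string path = Path.GetRandomFileName() + ".tmp"; // throwaway temp file
        var offsets = new List<long>();

        using (var bw = new BinaryWriter(File.Create(path)))
        {
            foreach (string snt in new[] { "a b c", "d e", "f g h i" })
            {
                offsets.Add(bw.BaseStream.Position); // remember where the record starts
                bw.Write(snt);                       // length-prefixed string record
            }
        }

        using (var br = new BinaryReader(File.OpenRead(path)))
        {
            // Read records back in arbitrary order via their saved offsets.
            br.BaseStream.Seek(offsets[2], SeekOrigin.Begin);
            Console.WriteLine(br.ReadString()); // "f g h i"
            br.BaseStream.Seek(offsets[0], SeekOrigin.Begin);
            Console.WriteLine(br.ReadString()); // "a b c"
        }

        File.Delete(path);
    }
}
```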


@@ -246,69 +271,82 @@ public void PrepareDataSet()
             try
             {
                 m_batchNumInTotal = 0;
-                (var length2offsets, var length2counts, string tmpDataSetFilePath) = BuildIndex();
-                long totalRecordsNum = 0;
-                foreach (var pair in length2offsets)
-                {
-                    totalRecordsNum += length2counts[pair.Key];
-                }
+                var indexDatas = BuildIndex();
+                object locker = new object();

                 Logger.WriteLine(Logger.Level.debug, $"Start to sort and shuffle data set by length.");

-                m_indexedDataSetFilePath = tmpDataSetFilePath + ".sorted";
+                m_indexedDataSetFilePath = Path.GetRandomFileName() + ".sorted";
                 using (BinaryWriter bw = new BinaryWriter(new FileStream(m_indexedDataSetFilePath, FileMode.Create, FileAccess.Write, FileShare.None, 40960000)))
-                using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(tmpDataSetFilePath))
-                using (MemoryMappedViewStream mms = mmf.CreateViewStream())
                 {
-                    using (BinaryReader br = new BinaryReader(mms))
-                    {
-                        while (length2offsets.Count > 0)
-                        {
-                            long length = GetNextLength(length2counts, totalRecordsNum);
-                            LinkedList<long> offsets = length2offsets[length];
-
-                            int totalTgtTokenSize = 0;
-                            int sentSize = 0;
-                            List<string> tgtLines = new List<string>();
-                            while (totalTgtTokenSize < m_maxTokenSizePerBatch && offsets.Any())
-                            {
-                                long offset = offsets.First.Value;
-                                offsets.RemoveFirst();
-                                length2counts[length]--;
-                                totalRecordsNum--;
-
-                                br.BaseStream.Seek(offset, SeekOrigin.Begin);
-                                string tgtLine = br.ReadString();
-                                totalTgtTokenSize += tgtLine.Split(' ').Length;
-                                tgtLines.Add(tgtLine);
-                                sentSize++;
-                            }
-
-                            bw.Write(sentSize);
-                            bw.Write(String.Join("\n", tgtLines));
-
-                            m_batchNumInTotal++;
-                            if (m_batchNumInTotal % 10000 == 0)
-                            {
-                                Logger.WriteLine(Logger.Level.debug, $"Batch '{m_batchNumInTotal}' has been processed.");
-                            }
-
-                            if (offsets.Any() == false)
-                            {
-                                length2offsets.Remove(length);
-                                length2counts.Remove(length);
-                            }
-                        }
-
-                        bw.Write(-1);
-                    }
-                }
-
-                File.Delete(tmpDataSetFilePath);
+                    Parallel.ForEach(indexDatas, (indexData) =>
+                    {
+                        long totalRecordsNum = 0;
+                        var length2offsets = indexData.len2offsets;
+                        var length2counts = indexData.len2lengths;
+                        foreach (var pair in length2offsets)
+                        {
+                            totalRecordsNum += length2counts[pair.Key];
+                        }
+
+                        var tmpDataSetFilePath = indexData.filePath;
+                        using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(tmpDataSetFilePath))
+                        using (MemoryMappedViewStream mms = mmf.CreateViewStream())
+                        {
+                            using (BinaryReader br = new BinaryReader(mms))
+                            {
+                                while (length2offsets.Count > 0)
+                                {
+                                    long length = GetNextLength(length2counts, totalRecordsNum);
+                                    LinkedList<long> offsets = length2offsets[length];
+
+                                    int totalTgtTokenSize = 0;
+                                    int sentSize = 0;
+                                    List<string> tgtLines = new List<string>();
+                                    while (totalTgtTokenSize < m_maxTokenSizePerBatch && offsets.Any())
+                                    {
+                                        long offset = offsets.First.Value;
+                                        offsets.RemoveFirst();
+                                        length2counts[length]--;
+                                        totalRecordsNum--;
+
+                                        br.BaseStream.Seek(offset, SeekOrigin.Begin);
+                                        string tgtLine = br.ReadString();
+                                        totalTgtTokenSize += tgtLine.Split(' ').Length;
+                                        tgtLines.Add(tgtLine);
+                                        sentSize++;
+                                    }
+
+                                    lock (locker)
+                                    {
+                                        bw.Write(sentSize);
+                                        bw.Write(String.Join("\n", tgtLines));
+                                    }
+                                    Interlocked.Increment(ref m_batchNumInTotal);
+                                    if (m_batchNumInTotal % 10000 == 0)
+                                    {
+                                        Logger.WriteLine(Logger.Level.debug, $"Batch '{m_batchNumInTotal}' has been processed.");
+                                    }
+
+                                    if (offsets.Any() == false)
+                                    {
+                                        length2offsets.Remove(length);
+                                        length2counts.Remove(length);
+                                    }
+                                }
+                            }
+                        }
+
+                        File.Delete(tmpDataSetFilePath);
+                    });
+
+                    bw.Write(-1);
+                }

                 Logger.WriteLine($"Finished to sort and shuffle data set by length. Total batch size = '{m_batchNumInTotal}'");
             }
             catch (Exception err)
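The rewritten `PrepareDataSet` fans out one worker per `IndexData`, each reading its own memory-mapped temp file, while every worker appends to the single sorted output writer. The two writes that make up one record stay under a single `lock` so records from different workers cannot interleave, and the batch counter moves to `Interlocked.Increment`. A reduced sketch of that single-shared-writer pattern (data and file names are illustrative):

```csharp
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

class SharedWriterSketch
{
    static void Main()
    {
        // Hypothetical per-shard batches standing in for the per-file indexes.
        var shards = new[] { new[] { "a b", "c" }, new[] { "d e f" } };
        int batchNum = 0;
        object locker = new object();

        using (var bw = new BinaryWriter(File.Create("merged.sorted")))
        {
            Parallel.ForEach(shards, shard =>
            {
                foreach (string batch in shard)
                {
                    lock (locker)
                    {
                        // Both writes of one record under one lock, so records
                        // from different workers stay contiguous in the output.
                        bw.Write(batch.Split(' ').Length);
                        bw.Write(batch);
                    }
                    Interlocked.Increment(ref batchNum);
                }
            });

            bw.Write(-1); // end-of-data sentinel, mirroring the commit
        }

        Console.WriteLine($"Wrote {batchNum} batches.");
    }
}
```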
