From 7e8fee2e92cc0c0bfc80b5a9d875100b317f38c9 Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Mon, 7 Oct 2024 07:24:47 -0700 Subject: [PATCH] Fix model saving bug for GQA --- Seq2SeqSharp/Layers/GroupQueryAttention.cs | 4 ++-- Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs | 2 ++ Seq2SeqSharp/Seq2SeqSharp.csproj | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Seq2SeqSharp/Layers/GroupQueryAttention.cs b/Seq2SeqSharp/Layers/GroupQueryAttention.cs index 8fc4848..01717f4 100644 --- a/Seq2SeqSharp/Layers/GroupQueryAttention.cs +++ b/Seq2SeqSharp/Layers/GroupQueryAttention.cs @@ -108,7 +108,7 @@ public GroupQueryAttention(string name, int num_heads, int num_kv_groups, int d_ /// Transformered output tensor public (IWeightTensor, IWeightTensor) Perform(IWeightTensor inputQ, IWeightTensor inputK, IWeightTensor inputV, IWeightTensor keyMask, int batchSize, IComputeGraph graph, bool outputAttenWeights = false, Dictionary cachedTensors = null) { - string keyName = $"{m_name}_MultiHeadAttention_3"; + string keyName = $"{m_name}_GroupQueryAttention_3"; using IComputeGraph g = graph.CreateSubGraph(keyName); int seqLenQ = inputQ.Rows / batchSize; @@ -340,7 +340,7 @@ public void Load(IModel stream) public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int batchSize, IComputeGraph graph, Dictionary cachedTensors = null) { - string keyName = $"{m_name}_MultiHeadAttention_3"; + string keyName = $"{m_name}_GroupQueryAttention_1"; using IComputeGraph g = graph.CreateSubGraph(keyName); int seqLenQ = inputQ.Rows / batchSize; diff --git a/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs b/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs index 715950f..1bc5af0 100644 --- a/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs +++ b/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs @@ -284,6 +284,8 @@ public Model_4_ProtoBufSerializer(Model m) ExpertsPerTokenFactor = m.ExpertsPerTokenFactor; PEType= m.PEType; NormType = m.NormType; + MultiHeadAttentionType = m.MultiHeadAttentionType; + KVGroupNum = m.KVGroupNum; } public static Model_4_ProtoBufSerializer Create(Model m) => new Model_4_ProtoBufSerializer(m); diff --git a/Seq2SeqSharp/Seq2SeqSharp.csproj b/Seq2SeqSharp/Seq2SeqSharp.csproj index 4d8d4ff..d3dd409 100644 --- a/Seq2SeqSharp/Seq2SeqSharp.csproj +++ b/Seq2SeqSharp/Seq2SeqSharp.csproj @@ -15,7 +15,7 @@ AnyCPU false bin\ - 2.8.13 + 2.8.14 Seq2SeqSharp is a tensor based fast & flexible encoder-decoder deep neural network framework written by .NET (C#). It can be used for sequence-to-sequence task, sequence-labeling task and sequence-classification task and other NLP tasks. Seq2SeqSharp supports both CPUs (x86, x64 and ARM64) and GPUs. It's powered by .NET core, so Seq2SeqSharp can run on both Windows and Linux without any modification and recompilation. README.md Seq2SeqSharp