diff --git a/Seq2SeqSharp/Applications/Decoder.cs b/Seq2SeqSharp/Applications/Decoder.cs index cb9abba..99f0f5b 100644 --- a/Seq2SeqSharp/Applications/Decoder.cs +++ b/Seq2SeqSharp/Applications/Decoder.cs @@ -38,8 +38,8 @@ public static MultiProcessorNetworkWrapper CreateDecoders(IModel model decoder = new MultiProcessorNetworkWrapper( new GPTDecoder("GPTDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(), isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, - expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType, attentionType: options.AttentionType, multiHeadAttentionType: options.MultiHeadAttentionType, - KVGroupNum: options.KVGroupNum), raDeviceIds.ToArray()); + expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType, attentionType: options.AttentionType, multiHeadAttentionType: model.MultiHeadAttentionType, + KVGroupNum: model.KVGroupNum), raDeviceIds.ToArray()); } else { diff --git a/Seq2SeqSharp/Models/IModel.cs b/Seq2SeqSharp/Models/IModel.cs index d41e883..f4d2255 100644 --- a/Seq2SeqSharp/Models/IModel.cs +++ b/Seq2SeqSharp/Models/IModel.cs @@ -45,6 +45,9 @@ public interface IModel public int MaxSegmentNum { get; set; } + public MultiHeadAttentionTypeEnums MultiHeadAttentionType { get; set; } + public int KVGroupNum { get; set; } + public void AddWeights(string name, float[] weights); public float[] GetWeights(string name);