diff --git a/models/transformers/bert_w2ner.go b/models/transformers/bert_w2ner.go index c40a37a..6b8deb0 100644 --- a/models/transformers/bert_w2ner.go +++ b/models/transformers/bert_w2ner.go @@ -61,6 +61,15 @@ func generateInitDistInputs(size int) [][]int32 { return matrix } +func generateAllTrueSlice(size int) []bool { + matrix := make([]bool, size) + for i := 0; i < size; i++ { + matrix[i] = true + } + + return matrix +} + ///////////////////////////////////////// Bert Service Pre-Process Function ///////////////////////////////////////// // getBertInputFeature Get Bert-W2NER Feature (before Make HTTP or GRPC Request). @@ -100,9 +109,9 @@ func (w *W2NerModelService) getBertInputFeature(batchInferData [][]string) []*W2 start := 0 for j := 0; j < padTokenLen; j++ { - gridMask2d[j] = make([]bool, padTokenLen) + gridMask2d[j] = generateAllTrueSlice(padTokenLen) - if j+1 >= len(inferTokens) { + if j >= len(inferTokens) { pieces2word[j] = make([]bool, len(batchInputFeatures[i].TokenIDs)) continue } @@ -127,7 +136,6 @@ func (w *W2NerModelService) getBertInputFeature(batchInferData [][]string) []*W2 // distInputs for i, inferTokens := range batchInferTokens { - // len(inferTokens) distInputs := generateInitDistInputs(len(inferTokens)) for j := 0; j < len(inferTokens); j++ { for k := 0; k < len(inferTokens); k++ {