Skip to content

Commit

Permalink
fix data error with pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
sunhailin committed Mar 4, 2024
1 parent 1e14372 commit fab4a50
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions models/transformers/bert_w2ner.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ func generateInitDistInputs(size int) [][]int32 {
return matrix
}

func generateAllTrueSlice(size int) []bool {
matrix := make([]bool, size)
for i := 0; i < size; i++ {
matrix[i] = true
}

return matrix
}

///////////////////////////////////////// Bert Service Pre-Process Function /////////////////////////////////////////

// getBertInputFeature Get Bert-W2NER Feature (before Make HTTP or GRPC Request).
Expand Down Expand Up @@ -100,9 +109,9 @@ func (w *W2NerModelService) getBertInputFeature(batchInferData [][]string) []*W2

start := 0
for j := 0; j < padTokenLen; j++ {
gridMask2d[j] = make([]bool, padTokenLen)
gridMask2d[j] = generateAllTrueSlice(padTokenLen)

if j+1 >= len(inferTokens) {
if j >= len(inferTokens) {
pieces2word[j] = make([]bool, len(batchInputFeatures[i].TokenIDs))
continue
}
Expand All @@ -127,7 +136,6 @@ func (w *W2NerModelService) getBertInputFeature(batchInferData [][]string) []*W2

// distInputs
for i, inferTokens := range batchInferTokens {
// len(inferTokens)
distInputs := generateInitDistInputs(len(inferTokens))
for j := 0; j < len(inferTokens); j++ {
for k := 0; k < len(inferTokens); k++ {
Expand Down

0 comments on commit fab4a50

Please sign in to comment.