Skip to content

Commit

Permalink
Merge pull request #55 from sunhailin-Leo/dev
Browse files Browse the repository at this point in the history
Version 2.0.0
  • Loading branch information
sunhailin-Leo authored Mar 6, 2024
2 parents 12b356f + a8862a5 commit 12acc38
Show file tree
Hide file tree
Showing 25 changed files with 1,675 additions and 590 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* text=crlf
10 changes: 8 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ on:
branches:
- main
pull_request:

permissions:
contents: read
pull-requests: read
checks: write

jobs:
lint:
runs-on: ubuntu-latest
Expand All @@ -17,5 +23,5 @@ jobs:
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v4
with:
version: v1.54.2
args: --enable=nolintlint,gochecknoinits,bodyclose,gofumpt --verbose
version: v1.56.2
args: --enable=nolintlint,gochecknoinits,bodyclose,gocritic --verbose
2 changes: 1 addition & 1 deletion .github/workflows/security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ jobs:
- name: Run Gosec Security Scanner
uses: securego/[email protected]
with:
args: '-exclude=G104,G304,G402 ./...'
args: '-exclude=G103,G104,G304,G402 ./...'
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
Build:
strategy:
matrix:
go-version: [1.19.x, 1.20.x, 1.21.x, 1.22.x]
go-version: [1.21.x, 1.22.x]
platform: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.idea
*.iws
*.iml
*.ipr
28 changes: 15 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Triton Inference Server - Golang API

### Feature

* Based On `Golang 1.18`
* Based On `Golang 1.21`
* Support HTTP/GRPC
* Easy to use it
* Maybe High Performance
Expand All @@ -46,12 +46,11 @@ import (
"fmt"
"time"

"github.com/valyala/fasthttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/sunhailin-Leo/triton-service-go/models/bert"
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
"github.com/sunhailin-Leo/triton-service-go/models/transformers"
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
"github.com/valyala/fasthttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

const (
Expand All @@ -67,22 +66,19 @@ const (
)

// testGenerateModelInferRequest Triton Input
func testGenerateModelInferRequest(batchSize, maxSeqLength int) []*nvidia_inferenceserver.ModelInferRequest_InferInputTensor {
func testGenerateModelInferRequest() []*nvidia_inferenceserver.ModelInferRequest_InferInputTensor {
return []*nvidia_inferenceserver.ModelInferRequest_InferInputTensor{
{
Name: tBertModelSegmentIdsKey,
Datatype: tBertModelSegmentIdsDataType,
Shape: []int64{int64(batchSize), int64(maxSeqLength)},
},
{
Name: tBertModelInputIdsKey,
Datatype: tBertModelInputIdsDataType,
Shape: []int64{int64(batchSize), int64(maxSeqLength)},
},
{
Name: tBertModelInputMaskKey,
Datatype: tBertModelInputMaskDataType,
Shape: []int64{int64(batchSize), int64(maxSeqLength)},
},
}
}
Expand Down Expand Up @@ -124,13 +120,13 @@ func main() {
}

// Service
bertService, initErr := bert.NewModelService(
bertService, initErr := bert.NewBertModelService(
vocabPath, httpAddr, defaultHttpClient, defaultGRPCClient,
testGenerateModelInferRequest, testGenerateModelInferOutputRequest, testModerInferCallback)
if initErr != nil {
panic(initErr)
}
bertService = bertService.SetChineseTokenize(false).SetMaxSeqLength(maxSeqLen)
bertService.SetChineseTokenize(false).SetMaxSeqLength(maxSeqLen)
// infer
inferResultV1, inferErr := bertService.ModelInfer([]string{"<Data>"}, "<Model Name>", "<Model Version>", 1*time.Second)
if inferErr != nil {
Expand All @@ -144,6 +140,12 @@ func main() {

### Version

* version 2.0.0 - Coming soon
* **No longer compatible with Go version 1.18, 1.19, 1.20**
* refactor `models` package and rename package from `bert` to `transformers`.
* **Incompatible with previous versions, calls require simple modifications**
* Add `W2NER` model(Based on Bert, but used for NER tasks)

* version 1.4.6 - 2023/07/27
* remove `github.com/goccy/go-json` and set `encoding/json` to default json marshal/unmarshal.
* add `JsonEncoder` and `JsonDecoder` API to adapt other json parser.
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/sunhailin-Leo/triton-service-go

go 1.19
go 1.21

require (
github.com/valyala/fasthttp v1.52.0
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
Expand Down
184 changes: 184 additions & 0 deletions models/base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package models

import (
"time"

"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
"github.com/sunhailin-Leo/triton-service-go/utils"
)

// GenerateModelInferRequest is the callback that builds the input tensor
// descriptors for a Triton model infer call.
type GenerateModelInferRequest func() []*nvidia_inferenceserver.ModelInferRequest_InferInputTensor

// GenerateModelInferOutputRequest is the callback that builds the requested
// output tensor descriptors for a Triton model infer call. params are
// forwarded from the caller and are implementation-specific.
// NOTE: `any` is an exact alias for `interface{}` (Go 1.18+), so existing
// implementations remain compatible; the module targets Go 1.21.
type GenerateModelInferOutputRequest func(params ...any) []*nvidia_inferenceserver.ModelInferRequest_InferRequestedOutputTensor

// ModelService is the shared base for Triton-backed model services. It holds
// the Triton client, tokenizer-related flags, and the callbacks used to build
// infer requests and decode responses.
type ModelService struct {
	IsGRPC            bool // when true, calls go through gRPC instead of HTTP (see SetModelInferWithGRPC)
	IsChinese         bool // when true, use Chinese tokenization (see SetChineseTokenize)
	IsChineseCharMode bool // char mode for Chinese tokenization — exact semantics defined by the tokenizer; confirm there
	IsReturnPosArray  bool // when true, tokenizer also returns position info (see SetTokenizerReturnPosInfo)
	MaxSeqLength      int  // maximum sequence length for model input (see SetMaxSeqLength)
	ModelName         string
	TritonService     *nvidia_inferenceserver.TritonClientService
	InferCallback     nvidia_inferenceserver.DecoderFunc
	GenerateModelInferRequest       GenerateModelInferRequest
	GenerateModelInferOutputRequest GenerateModelInferOutputRequest
}

////////////////////////////////////////////////// Flag Switch API //////////////////////////////////////////////////

// SetMaxSeqLength sets the maximum sequence length used for model inference
// input and returns the service for chaining.
func (m *ModelService) SetMaxSeqLength(maxSeqLen int) *ModelService {
	m.MaxSeqLength = maxSeqLen
	return m
}

// SetChineseTokenize enables Chinese tokenization for infer data; isCharMode
// selects char mode (semantics defined by the tokenizer). Returns the service
// for chaining.
func (m *ModelService) SetChineseTokenize(isCharMode bool) *ModelService {
	m.IsChinese, m.IsChineseCharMode = true, isCharMode
	return m
}

// UnsetChineseTokenize disables Chinese tokenization (and char mode) and
// returns the service for chaining.
func (m *ModelService) UnsetChineseTokenize() *ModelService {
	m.IsChinese, m.IsChineseCharMode = false, false
	return m
}

// SetModelInferWithGRPC switches inference calls to gRPC and returns the
// service for chaining.
func (m *ModelService) SetModelInferWithGRPC() *ModelService {
	m.IsGRPC = true
	return m
}

// UnsetModelInferWithGRPC switches inference calls away from gRPC and returns
// the service for chaining.
func (m *ModelService) UnsetModelInferWithGRPC() *ModelService {
	m.IsGRPC = false
	return m
}

// GetModelInferIsGRPC reports whether inference calls are made over gRPC.
func (m *ModelService) GetModelInferIsGRPC() bool { return m.IsGRPC }

// GetTokenizerIsChineseMode reports whether Chinese tokenization is enabled.
func (m *ModelService) GetTokenizerIsChineseMode() bool { return m.IsChinese }

// SetTokenizerReturnPosInfo makes the tokenizer also return position info and
// returns the service for chaining.
func (m *ModelService) SetTokenizerReturnPosInfo() *ModelService {
	m.IsReturnPosArray = true
	return m
}

// UnsetTokenizerReturnPosInfo stops the tokenizer from returning position
// info and returns the service for chaining.
func (m *ModelService) UnsetTokenizerReturnPosInfo() *ModelService {
	m.IsReturnPosArray = false
	return m
}

// SetModelName stores the model name as "<prefix>-<name>"; the result must
// match the model name in Triton's config.pbtxt. Returns the service for
// chaining.
func (m *ModelService) SetModelName(modelPrefix, modelName string) *ModelService {
	m.ModelName = modelPrefix + "-" + modelName
	return m
}

// SetModelNameWithoutDash sets the model name exactly as given, without the
// "<prefix>-<name>" joining performed by SetModelName.
func (m *ModelService) SetModelNameWithoutDash(modelName string) *ModelService {
	m.ModelName = modelName

	return m
}

// GetModelName returns the currently configured model name.
func (m *ModelService) GetModelName() string {
	return m.ModelName
}

// SetSecondaryServerURL sets a secondary server URL (HTTP only). It is a
// no-op when the Triton service has not been initialized. Returns the
// service for chaining.
func (m *ModelService) SetSecondaryServerURL(url string) *ModelService {
	if m.TritonService == nil {
		return m
	}
	m.TritonService.SetSecondaryServerURL(url)
	return m
}

// SetJsonEncoder sets a custom JSON encoder on the underlying Triton service.
// Guards against a nil TritonService — consistent with SetSecondaryServerURL,
// which already nil-checks; without the guard this would panic when called
// before the client is initialized. Returns the service for chaining.
func (m *ModelService) SetJsonEncoder(encoder utils.JSONMarshal) *ModelService {
	if m.TritonService != nil {
		m.TritonService.SetJSONEncoder(encoder)
	}
	return m
}

// SetJsonDecoder sets a custom JSON decoder on the underlying Triton service.
// Guards against a nil TritonService — consistent with SetSecondaryServerURL,
// which already nil-checks; without the guard this would panic when called
// before the client is initialized. Returns the service for chaining.
// NOTE(review): the underlying client mixes SetJSONEncoder / SetJsonDecoder
// casing; Go convention is SetJSONDecoder — consider renaming there.
func (m *ModelService) SetJsonDecoder(decoder utils.JSONUnmarshal) *ModelService {
	if m.TritonService != nil {
		m.TritonService.SetJsonDecoder(decoder)
	}
	return m
}

////////////////////////////////////////////////// Flag Switch API //////////////////////////////////////////////////

//////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////

// CheckServerReady reports whether the Triton server is ready, delegating to
// the underlying client.
func (m *ModelService) CheckServerReady(requestTimeout time.Duration) (bool, error) {
	ready, err := m.TritonService.CheckServerReady(requestTimeout)
	return ready, err
}

// CheckServerAlive reports whether the Triton server is alive, delegating to
// the underlying client.
func (m *ModelService) CheckServerAlive(requestTimeout time.Duration) (bool, error) {
	alive, err := m.TritonService.CheckServerAlive(requestTimeout)
	return alive, err
}

// CheckModelReady reports whether the given model/version is ready on the
// Triton server, delegating to the underlying client.
func (m *ModelService) CheckModelReady(modelName, modelVersion string, requestTimeout time.Duration) (bool, error) {
	ready, err := m.TritonService.CheckModelReady(modelName, modelVersion, requestTimeout)
	return ready, err
}

// GetServerMeta fetches the Triton server metadata via the underlying client.
func (m *ModelService) GetServerMeta(requestTimeout time.Duration) (*nvidia_inferenceserver.ServerMetadataResponse, error) {
	return m.TritonService.ServerMetadata(requestTimeout)
}

// GetModelMeta fetches metadata for the given model/version via the
// underlying client.
func (m *ModelService) GetModelMeta(modelName, modelVersion string, requestTimeout time.Duration) (*nvidia_inferenceserver.ModelMetadataResponse, error) {
	return m.TritonService.ModelMetadataRequest(modelName, modelVersion, requestTimeout)
}

// GetAllModelInfo fetches the repository index (all model info) via the
// underlying client; isReady filters for ready models.
func (m *ModelService) GetAllModelInfo(repoName string, isReady bool, requestTimeout time.Duration) (*nvidia_inferenceserver.RepositoryIndexResponse, error) {
	return m.TritonService.ModelIndex(repoName, isReady, requestTimeout)
}

// GetModelConfig fetches the configuration of the given model/version via the
// underlying client. The concrete type of the first return value depends on
// the client implementation, so it is returned as any.
// NOTE: `any` is an exact alias for `interface{}` (Go 1.18+), so callers are
// unaffected; the module targets Go 1.21.
func (m *ModelService) GetModelConfig(modelName, modelVersion string, requestTimeout time.Duration) (any, error) {
	return m.TritonService.ModelConfiguration(modelName, modelVersion, requestTimeout)
}

// GetModelInferStats fetches inference statistics for the given model/version
// via the underlying client.
func (m *ModelService) GetModelInferStats(modelName, modelVersion string, requestTimeout time.Duration) (*nvidia_inferenceserver.ModelStatisticsResponse, error) {
	return m.TritonService.ModelInferStats(modelName, modelVersion, requestTimeout)
}

//////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////
Loading

0 comments on commit 12acc38

Please sign in to comment.