Skip to content

Commit

Permalink
version 1.4.6
Browse files Browse the repository at this point in the history
  • Loading branch information
sunhailin-Leo committed Jul 27, 2023
1 parent e3d13a2 commit cc518ce
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 50 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ func main() {

### Version

* version 1.4.6 - 2023/07/27
* remove `github.com/goccy/go-json` and make `encoding/json` the default JSON marshal/unmarshal.
* add `JsonEncoder` and `JsonDecoder` APIs to adapt other JSON parsers.

* version 1.4.5 - 2023/07/12
* update go.mod
* fix Chinese tokenizer error
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module github.com/sunhailin-Leo/triton-service-go
go 1.18

require (
github.com/goccy/go-json v0.10.2
github.com/valyala/fasthttp v1.48.0
golang.org/x/text v0.11.0
google.golang.org/grpc v1.56.2
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
Expand Down
33 changes: 15 additions & 18 deletions models/bert/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ import (
"strings"
"time"

"github.com/goccy/go-json"
"github.com/valyala/fasthttp"
"google.golang.org/grpc"

"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
"github.com/sunhailin-Leo/triton-service-go/utils"
"github.com/valyala/fasthttp"
"google.golang.org/grpc"
)

const (
Expand All @@ -25,7 +23,6 @@ const (
)

type ModelService struct {
// isTraceDuration bool
isGRPC bool
isChinese bool
isChineseCharMode bool
Expand All @@ -42,18 +39,6 @@ type ModelService struct {

////////////////////////////////////////////////// Flag Switch API //////////////////////////////////////////////////

//// SetModelInferWithTrace Set model infer trace obj.
// func (m *ModelService) SetModelInferWithTrace() *ModelService {
// m.isTraceDuration = true
// return m
// }
//
//// UnsetModelInferWithTrace unset model infer trace obj.
// func (m *ModelService) UnsetModelInferWithTrace() *ModelService {
// m.isTraceDuration = false
// return m
// }

// SetMaxSeqLength Set model infer max sequence length.
func (m *ModelService) SetMaxSeqLength(maxSeqLen int) *ModelService {
m.maxSeqLength = maxSeqLen
Expand Down Expand Up @@ -134,6 +119,18 @@ func (m *ModelService) SetSecondaryServerURL(url string) *ModelService {
return m
}

// SetJsonEncoder installs a custom JSON marshal function on the underlying
// Triton client service. It returns the receiver so calls can be chained.
func (m *ModelService) SetJsonEncoder(marshalFn utils.JSONMarshal) *ModelService {
	m.tritonService.SetJSONEncoder(marshalFn)
	return m
}

// SetJsonDecoder installs a custom JSON unmarshal function on the underlying
// Triton client service. It returns the receiver so calls can be chained.
func (m *ModelService) SetJsonDecoder(unmarshalFn utils.JSONUnmarshal) *ModelService {
	m.tritonService.SetJsonDecoder(unmarshalFn)
	return m
}

////////////////////////////////////////////////// Flag Switch API //////////////////////////////////////////////////

///////////////////////////////////////// Bert Service Pre-Process Function /////////////////////////////////////////
Expand Down Expand Up @@ -265,7 +262,7 @@ func (m *ModelService) generateHTTPRequest(
) ([]byte, []*InputObjects, error) {
// Generate batch request json body
requestInputBody, modelInputObj := m.generateHTTPInputs(inferDataArr, inferInputs)
jsonBody, jsonEncodeErr := json.Marshal(&HTTPRequestBody{
jsonBody, jsonEncodeErr := m.tritonService.JsonMarshal(&HTTPRequestBody{
Inputs: requestInputBody,
Outputs: m.generateHTTPOutputs(inferOutputs),
})
Expand Down
1 change: 1 addition & 0 deletions models/bert/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func NewBaseTokenizer(opts ...OptionV1) *BaseTokenizer {
// The resulting tokens preserve the alignment with the portion of the original text they belong to.
func (t *BaseTokenizer) Tokenize(text string) []StringOffsetsPair {
splitTokens := make([]StringOffsetsPair, 0)
text = utils.Clean(text)
spaceTokens := t.splitOn(text, utils.IsWhitespace, false)

for i := range spaceTokens {
Expand Down
86 changes: 61 additions & 25 deletions nvidia_inferenceserver/triton_service_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package nvidia_inferenceserver

import (
"context"
"encoding/json"
"errors"
"strconv"
"time"

"github.com/goccy/go-json"
"github.com/sunhailin-Leo/triton-service-go/utils"
"github.com/valyala/fasthttp"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -122,6 +123,12 @@ type TritonClientService struct {
grpcConn *grpc.ClientConn
grpcClient GRPCInferenceServiceClient
httpClient *fasthttp.Client

// Default: json.Marshal
JSONEncoder utils.JSONMarshal

// Default: json.Unmarshal
JSONDecoder utils.JSONUnmarshal
}

// disconnectToTritonWithGRPC Disconnect GRPC Connection.
Expand Down Expand Up @@ -247,6 +254,28 @@ func (t *TritonClientService) decodeFuncErrorHandler(err error, isGRPC bool) err

///////////////////////////////////////////// expose API below /////////////////////////////////////////////

// JsonMarshal serializes v to JSON bytes using the client's configured
// JSONEncoder (json.Marshal by default, per the struct field comment).
func (t *TritonClientService) JsonMarshal(v interface{}) ([]byte, error) {
	encoded, encodeErr := t.JSONEncoder(v)
	return encoded, encodeErr
}

// JsonUnmarshal deserializes JSON bytes into v using the client's configured
// JSONDecoder (json.Unmarshal by default, per the struct field comment).
func (t *TritonClientService) JsonUnmarshal(data []byte, v interface{}) error {
	decodeErr := t.JSONDecoder(data, v)
	return decodeErr
}

// SetJSONEncoder replaces the client's JSON marshal function and returns the
// receiver so calls can be chained.
func (t *TritonClientService) SetJSONEncoder(marshalFn utils.JSONMarshal) *TritonClientService {
	t.JSONEncoder = marshalFn
	return t
}

// SetJsonDecoder replaces the client's JSON unmarshal function and returns the
// receiver so calls can be chained.
//
// NOTE(review): the name is inconsistent with SetJSONEncoder ("Json" vs
// "JSON"); renaming to SetJSONDecoder would break existing callers, so it is
// left unchanged and only flagged here.
func (t *TritonClientService) SetJsonDecoder(unmarshalFn utils.JSONUnmarshal) *TritonClientService {
	t.JSONDecoder = unmarshalFn
	return t
}

// ModelHTTPInfer Call Triton Infer with HTTP.
func (t *TritonClientService) ModelHTTPInfer(
requestBody []byte,
Expand Down Expand Up @@ -379,7 +408,7 @@ func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMeta
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
serverMetadataResponse := new(ServerMetadataResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &serverMetadataResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &serverMetadataResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return serverMetadataResponse, nil
Expand All @@ -406,7 +435,7 @@ func (t *TritonClientService) ModelMetadataRequest(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
modelMetadataResponse := new(ModelMetadataResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelMetadataResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelMetadataResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return modelMetadataResponse, nil
Expand All @@ -425,7 +454,7 @@ func (t *TritonClientService) ModelIndex(
ctx, &RepositoryIndexRequest{RepositoryName: repoName, Ready: isReady})
return repositoryIndexResponse, t.grpcErrorHandler(modelIndexErr)
}
reqBody, jsonEncodeErr := json.Marshal(&ModelIndexRequestHTTPObj{repoName, isReady})
reqBody, jsonEncodeErr := t.JsonMarshal(&ModelIndexRequestHTTPObj{repoName, isReady})
if jsonEncodeErr != nil {
return nil, jsonEncodeErr
}
Expand All @@ -435,7 +464,7 @@ func (t *TritonClientService) ModelIndex(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
repositoryIndexResponse := new(RepositoryIndexResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryIndexResponse.Models); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryIndexResponse.Models); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return repositoryIndexResponse, nil
Expand All @@ -461,7 +490,7 @@ func (t *TritonClientService) ModelConfiguration(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
modelConfigResponse := new(ModelConfigResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelConfigResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelConfigResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return modelConfigResponse, nil
Expand All @@ -487,7 +516,7 @@ func (t *TritonClientService) ModelInferStats(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
modelStatisticsResponse := new(ModelStatisticsResponse)
jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelStatisticsResponse)
jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelStatisticsResponse)
if jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
Expand All @@ -507,7 +536,7 @@ func (t *TritonClientService) ModelLoadWithHTTP(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
repositoryModelLoadResponse := new(RepositoryModelLoadResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryModelLoadResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryModelLoadResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return repositoryModelLoadResponse, nil
Expand Down Expand Up @@ -539,7 +568,7 @@ func (t *TritonClientService) ModelUnloadWithHTTP(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
repositoryModelUnloadResponse := new(RepositoryModelUnloadResponse)
jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryModelUnloadResponse)
jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryModelUnloadResponse)
if jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
Expand Down Expand Up @@ -601,13 +630,13 @@ func (t *TritonClientService) ShareMemoryStatus(
// Parse Response
if isCUDA {
cudaSharedMemoryStatusResponse := new(CudaSharedMemoryStatusResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryStatusResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryStatusResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return cudaSharedMemoryStatusResponse, nil
}
systemSharedMemoryStatusResponse := new(SystemSharedMemoryStatusResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryStatusResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryStatusResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return systemSharedMemoryStatusResponse, nil
Expand All @@ -632,7 +661,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
)
return cudaSharedMemoryRegisterResponse, t.grpcErrorHandler(registerErr)
}
reqBody, jsonEncodeErr := json.Marshal(
reqBody, jsonEncodeErr := t.JsonMarshal(
&CudaMemoryRegisterBodyHTTPObj{cudaRawHandle, cudaDeviceID, byteSize})
if jsonEncodeErr != nil {
return nil, jsonEncodeErr
Expand All @@ -644,7 +673,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
cudaSharedMemoryRegisterResponse := new(CudaSharedMemoryRegisterResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryRegisterResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryRegisterResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return cudaSharedMemoryRegisterResponse, nil
Expand All @@ -670,7 +699,7 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
cudaSharedMemoryUnregisterResponse := new(CudaSharedMemoryUnregisterResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return cudaSharedMemoryUnregisterResponse, nil
Expand All @@ -695,7 +724,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
)
return systemSharedMemoryRegisterResponse, t.grpcErrorHandler(registerErr)
}
reqBody, jsonEncodeErr := json.Marshal(
reqBody, jsonEncodeErr := t.JsonMarshal(
&SystemMemoryRegisterBodyHTTPObj{cpuMemRegionKey, cpuMemOffset, byteSize})
if jsonEncodeErr != nil {
return nil, jsonEncodeErr
Expand All @@ -707,7 +736,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
systemSharedMemoryRegisterResponse := new(SystemSharedMemoryRegisterResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryRegisterResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryRegisterResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return systemSharedMemoryRegisterResponse, nil
Expand All @@ -733,7 +762,7 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
systemSharedMemoryUnregisterResponse := new(SystemSharedMemoryUnregisterResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryUnregisterResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return systemSharedMemoryUnregisterResponse, nil
Expand All @@ -759,7 +788,7 @@ func (t *TritonClientService) GetModelTracingSetting(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
traceSettingResponse := new(TraceSettingResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return traceSettingResponse, nil
Expand All @@ -780,7 +809,7 @@ func (t *TritonClientService) SetModelTracingSetting(
return traceSettingResponse, t.grpcErrorHandler(setTraceSettingErr)
}
// Experimental
reqBody, jsonEncodeErr := json.Marshal(&TraceSettingRequestHTTPObj{settingMap})
reqBody, jsonEncodeErr := t.JsonMarshal(&TraceSettingRequestHTTPObj{settingMap})
if jsonEncodeErr != nil {
return nil, jsonEncodeErr
}
Expand All @@ -791,7 +820,7 @@ func (t *TritonClientService) SetModelTracingSetting(
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
traceSettingResponse := new(TraceSettingResponse)
if jsonDecodeErr := json.Unmarshal(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil {
return nil, jsonDecodeErr
}
return traceSettingResponse, nil
Expand All @@ -818,7 +847,7 @@ func (t *TritonClientService) ShutdownTritonConnection() (disconnectionErr error

// NewTritonClientWithOnlyHTTP builds a TritonClientService that talks to the
// server at uri over HTTP only. JSONEncoder/JSONDecoder default to the
// standard library's json.Marshal/json.Unmarshal; override them with
// SetJSONEncoder/SetJsonDecoder to plug in another JSON parser.
func NewTritonClientWithOnlyHTTP(uri string, httpClient *fasthttp.Client) *TritonClientService {
	// Fix: the pre-change line `client := &TritonClientService{serverURL: uri}`
	// was left in alongside the corrected one; keeping both redeclares `client`
	// (compile error), and the old form left JSONEncoder/JSONDecoder nil, which
	// would make every JSON call panic. Only the corrected initialization stays.
	client := &TritonClientService{serverURL: uri, JSONEncoder: json.Marshal, JSONDecoder: json.Unmarshal}
	client.setHTTPConnection(httpClient)
	return client
}
Expand All @@ -828,7 +857,12 @@ func NewTritonClientWithOnlyGRPC(grpcConn *grpc.ClientConn) *TritonClientService
if grpcConn == nil {
return nil
}
client := &TritonClientService{grpcConn: grpcConn, grpcClient: NewGRPCInferenceServiceClient(grpcConn)}
client := &TritonClientService{
grpcConn: grpcConn,
grpcClient: NewGRPCInferenceServiceClient(grpcConn),
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
}
return client
}

Expand All @@ -837,9 +871,11 @@ func NewTritonClientForAll(
httpServerURL string, httpClient *fasthttp.Client, grpcConn *grpc.ClientConn,
) *TritonClientService {
client := &TritonClientService{
serverURL: httpServerURL,
grpcConn: grpcConn,
grpcClient: NewGRPCInferenceServiceClient(grpcConn),
serverURL: httpServerURL,
grpcConn: grpcConn,
grpcClient: NewGRPCInferenceServiceClient(grpcConn),
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
}
client.setHTTPConnection(httpClient)

Expand Down
3 changes: 1 addition & 2 deletions test/bert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ import (
"testing"

"github.com/sunhailin-Leo/triton-service-go/models/bert"
"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
"github.com/valyala/fasthttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
)

const (
Expand Down
3 changes: 1 addition & 2 deletions test/triton_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"errors"
"testing"

"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
"github.com/valyala/fasthttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver"
)

func TestTritonHTTPClientInit(_ *testing.T) {
Expand Down
Loading

0 comments on commit cc518ce

Please sign in to comment.