diff --git a/README.md b/README.md index f008a68..5c04dc0 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,10 @@ func main() { ### Version +* version 1.4.6 - 2023/07/27 + * remove `github.com/goccy/go-json` and set `encoding/json` to default json marshal/unmarshal. + * add `JsonEncoder` and `JsonDecoder` API to adapt other json parser. + * version 1.4.5 - 2023/07/12 * update go.mod * fix Chinese tokenizer error diff --git a/go.mod b/go.mod index e577040..87d05b8 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/sunhailin-Leo/triton-service-go go 1.18 require ( - github.com/goccy/go-json v0.10.2 github.com/valyala/fasthttp v1.48.0 golang.org/x/text v0.11.0 google.golang.org/grpc v1.56.2 diff --git a/go.sum b/go.sum index 36853d0..58e5b17 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= diff --git a/models/bert/model.go b/models/bert/model.go index 75de48c..a65f0fd 100644 --- a/models/bert/model.go +++ b/models/bert/model.go @@ -5,12 +5,10 @@ import ( "strings" "time" - "github.com/goccy/go-json" - "github.com/valyala/fasthttp" - "google.golang.org/grpc" - "github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver" "github.com/sunhailin-Leo/triton-service-go/utils" + "github.com/valyala/fasthttp" + "google.golang.org/grpc" ) const ( @@ -25,7 +23,6 @@ const ( ) type ModelService struct { - // isTraceDuration bool isGRPC bool isChinese bool isChineseCharMode bool @@ -42,18 +39,6 @@ type ModelService struct { ////////////////////////////////////////////////// Flag Switch API ////////////////////////////////////////////////// -//// SetModelInferWithTrace Set model infer trace obj. -// func (m *ModelService) SetModelInferWithTrace() *ModelService { -// m.isTraceDuration = true -// return m -// } -// -//// UnsetModelInferWithTrace unset model infer trace obj. -// func (m *ModelService) UnsetModelInferWithTrace() *ModelService { -// m.isTraceDuration = false -// return m -// } - // SetMaxSeqLength Set model infer max sequence length. func (m *ModelService) SetMaxSeqLength(maxSeqLen int) *ModelService { m.maxSeqLength = maxSeqLen @@ -134,6 +119,18 @@ func (m *ModelService) SetSecondaryServerURL(url string) *ModelService { return m } +// SetJsonEncoder set json encoder +func (m *ModelService) SetJsonEncoder(encoder utils.JSONMarshal) *ModelService { + m.tritonService.SetJSONEncoder(encoder) + return m +} + +// SetJsonDecoder set json decoder +func (m *ModelService) SetJsonDecoder(decoder utils.JSONUnmarshal) *ModelService { + m.tritonService.SetJsonDecoder(decoder) + return m +} + ////////////////////////////////////////////////// Flag Switch API ////////////////////////////////////////////////// ///////////////////////////////////////// Bert Service Pre-Process Function ///////////////////////////////////////// @@ -265,7 +262,7 @@ func (m *ModelService) generateHTTPRequest( ) ([]byte, []*InputObjects, error) { // Generate batch request json body requestInputBody, modelInputObj := m.generateHTTPInputs(inferDataArr, inferInputs) - jsonBody, jsonEncodeErr := json.Marshal(&HTTPRequestBody{ + jsonBody, jsonEncodeErr := m.tritonService.JsonMarshal(&HTTPRequestBody{ Inputs: requestInputBody, Outputs: m.generateHTTPOutputs(inferOutputs), }) diff --git a/models/bert/tokenizer.go b/models/bert/tokenizer.go index 7678620..3196375 100644 --- a/models/bert/tokenizer.go +++ b/models/bert/tokenizer.go @@ -93,6 +93,7 @@ func NewBaseTokenizer(opts ...OptionV1) *BaseTokenizer { // The resulting tokens preserve the alignment with the portion of the original text they belong to. func (t *BaseTokenizer) Tokenize(text string) []StringOffsetsPair { splitTokens := make([]StringOffsetsPair, 0) + text = utils.Clean(text) spaceTokens := t.splitOn(text, utils.IsWhitespace, false) for i := range spaceTokens { diff --git a/nvidia_inferenceserver/triton_service_interface.go b/nvidia_inferenceserver/triton_service_interface.go index 801e6bd..a4a4772 100644 --- a/nvidia_inferenceserver/triton_service_interface.go +++ b/nvidia_inferenceserver/triton_service_interface.go @@ -2,11 +2,12 @@ package nvidia_inferenceserver import ( "context" + "encoding/json" "errors" "strconv" "time" - "github.com/goccy/go-json" + "github.com/sunhailin-Leo/triton-service-go/utils" "github.com/valyala/fasthttp" "google.golang.org/grpc" ) @@ -122,6 +123,12 @@ type TritonClientService struct { grpcConn *grpc.ClientConn grpcClient GRPCInferenceServiceClient httpClient *fasthttp.Client + + // Default: json.Marshal + JSONEncoder utils.JSONMarshal + + // Default: json.Unmarshal + JSONDecoder utils.JSONUnmarshal } // disconnectToTritonWithGRPC Disconnect GRPC Connection. @@ -247,6 +254,28 @@ func (t *TritonClientService) decodeFuncErrorHandler(err error, isGRPC bool) err ///////////////////////////////////////////// expose API below ///////////////////////////////////////////// +// JsonMarshal Json Encoder +func (t *TritonClientService) JsonMarshal(v interface{}) ([]byte, error) { + return t.JSONEncoder(v) +} + +// JsonUnmarshal Json Decoder +func (t *TritonClientService) JsonUnmarshal(data []byte, v interface{}) error { + return t.JSONDecoder(data, v) +} + +// SetJSONEncoder set json encoder +func (t *TritonClientService) SetJSONEncoder(encoder utils.JSONMarshal) *TritonClientService { + t.JSONEncoder = encoder + return t +} + +// SetJsonDecoder set json decoder +func (t *TritonClientService) SetJsonDecoder(decoder utils.JSONUnmarshal) *TritonClientService { + t.JSONDecoder = decoder + return t +} + // ModelHTTPInfer Call Triton Infer with HTTP. func (t *TritonClientService) ModelHTTPInfer( requestBody []byte, @@ -379,7 +408,7 @@ func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMeta return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } serverMetadataResponse := new(ServerMetadataResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &serverMetadataResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &serverMetadataResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return serverMetadataResponse, nil @@ -406,7 +435,7 @@ func (t *TritonClientService) ModelMetadataRequest( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } modelMetadataResponse := new(ModelMetadataResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelMetadataResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelMetadataResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return modelMetadataResponse, nil @@ -425,7 +454,7 @@ func (t *TritonClientService) ModelIndex( ctx, &RepositoryIndexRequest{RepositoryName: repoName, Ready: isReady}) return repositoryIndexResponse, t.grpcErrorHandler(modelIndexErr) } - reqBody, jsonEncodeErr := json.Marshal(&ModelIndexRequestHTTPObj{repoName, isReady}) + reqBody, jsonEncodeErr := t.JsonMarshal(&ModelIndexRequestHTTPObj{repoName, isReady}) if jsonEncodeErr != nil { return nil, jsonEncodeErr } @@ -435,7 +464,7 @@ func (t *TritonClientService) ModelIndex( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } repositoryIndexResponse := new(RepositoryIndexResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryIndexResponse.Models); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryIndexResponse.Models); jsonDecodeErr != nil { return nil, jsonDecodeErr } return repositoryIndexResponse, nil @@ -461,7 +490,7 @@ func (t *TritonClientService) ModelConfiguration( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } modelConfigResponse := new(ModelConfigResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelConfigResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelConfigResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return modelConfigResponse, nil @@ -487,7 +516,7 @@ func (t *TritonClientService) ModelInferStats( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } modelStatisticsResponse := new(ModelStatisticsResponse) - jsonDecodeErr := json.Unmarshal(apiResp.Body(), &modelStatisticsResponse) + jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &modelStatisticsResponse) if jsonDecodeErr != nil { return nil, jsonDecodeErr } @@ -507,7 +536,7 @@ func (t *TritonClientService) ModelLoadWithHTTP( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } repositoryModelLoadResponse := new(RepositoryModelLoadResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryModelLoadResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryModelLoadResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return repositoryModelLoadResponse, nil @@ -539,7 +568,7 @@ func (t *TritonClientService) ModelUnloadWithHTTP( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } repositoryModelUnloadResponse := new(RepositoryModelUnloadResponse) - jsonDecodeErr := json.Unmarshal(apiResp.Body(), &repositoryModelUnloadResponse) + jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &repositoryModelUnloadResponse) if jsonDecodeErr != nil { return nil, jsonDecodeErr } @@ -601,13 +630,13 @@ func (t *TritonClientService) ShareMemoryStatus( // Parse Response if isCUDA { cudaSharedMemoryStatusResponse := new(CudaSharedMemoryStatusResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryStatusResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryStatusResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return cudaSharedMemoryStatusResponse, nil } systemSharedMemoryStatusResponse := new(SystemSharedMemoryStatusResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryStatusResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryStatusResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return systemSharedMemoryStatusResponse, nil @@ -632,7 +661,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister( ) return cudaSharedMemoryRegisterResponse, t.grpcErrorHandler(registerErr) } - reqBody, jsonEncodeErr := json.Marshal( + reqBody, jsonEncodeErr := t.JsonMarshal( &CudaMemoryRegisterBodyHTTPObj{cudaRawHandle, cudaDeviceID, byteSize}) if jsonEncodeErr != nil { return nil, jsonEncodeErr @@ -644,7 +673,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } cudaSharedMemoryRegisterResponse := new(CudaSharedMemoryRegisterResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryRegisterResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryRegisterResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return cudaSharedMemoryRegisterResponse, nil @@ -670,7 +699,7 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } cudaSharedMemoryUnregisterResponse := new(CudaSharedMemoryUnregisterResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &cudaSharedMemoryUnregisterResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &cudaSharedMemoryUnregisterResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return cudaSharedMemoryUnregisterResponse, nil @@ -695,7 +724,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister( ) return systemSharedMemoryRegisterResponse, t.grpcErrorHandler(registerErr) } - reqBody, jsonEncodeErr := json.Marshal( + reqBody, jsonEncodeErr := t.JsonMarshal( &SystemMemoryRegisterBodyHTTPObj{cpuMemRegionKey, cpuMemOffset, byteSize}) if jsonEncodeErr != nil { return nil, jsonEncodeErr @@ -707,7 +736,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } systemSharedMemoryRegisterResponse := new(SystemSharedMemoryRegisterResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryRegisterResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryRegisterResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return systemSharedMemoryRegisterResponse, nil @@ -733,7 +762,7 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } systemSharedMemoryUnregisterResponse := new(SystemSharedMemoryUnregisterResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), &systemSharedMemoryUnregisterResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), &systemSharedMemoryUnregisterResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return systemSharedMemoryUnregisterResponse, nil @@ -759,7 +788,7 @@ func (t *TritonClientService) GetModelTracingSetting( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } traceSettingResponse := new(TraceSettingResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return traceSettingResponse, nil @@ -780,7 +809,7 @@ func (t *TritonClientService) SetModelTracingSetting( return traceSettingResponse, t.grpcErrorHandler(setTraceSettingErr) } // Experimental - reqBody, jsonEncodeErr := json.Marshal(&TraceSettingRequestHTTPObj{settingMap}) + reqBody, jsonEncodeErr := t.JsonMarshal(&TraceSettingRequestHTTPObj{settingMap}) if jsonEncodeErr != nil { return nil, jsonEncodeErr } @@ -791,7 +820,7 @@ func (t *TritonClientService) SetModelTracingSetting( return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr) } traceSettingResponse := new(TraceSettingResponse) - if jsonDecodeErr := json.Unmarshal(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil { + if jsonDecodeErr := t.JSONDecoder(apiResp.Body(), traceSettingResponse); jsonDecodeErr != nil { return nil, jsonDecodeErr } return traceSettingResponse, nil @@ -818,7 +847,7 @@ func (t *TritonClientService) ShutdownTritonConnection() (disconnectionErr error // NewTritonClientWithOnlyHTTP init triton client. func NewTritonClientWithOnlyHTTP(uri string, httpClient *fasthttp.Client) *TritonClientService { - client := &TritonClientService{serverURL: uri} + client := &TritonClientService{serverURL: uri, JSONEncoder: json.Marshal, JSONDecoder: json.Unmarshal} client.setHTTPConnection(httpClient) return client } @@ -828,7 +857,12 @@ func NewTritonClientWithOnlyGRPC(grpcConn *grpc.ClientConn) *TritonClientService if grpcConn == nil { return nil } - client := &TritonClientService{grpcConn: grpcConn, grpcClient: NewGRPCInferenceServiceClient(grpcConn)} + client := &TritonClientService{ + grpcConn: grpcConn, + grpcClient: NewGRPCInferenceServiceClient(grpcConn), + JSONEncoder: json.Marshal, + JSONDecoder: json.Unmarshal, + } return client } @@ -837,9 +871,11 @@ func NewTritonClientForAll( httpServerURL string, httpClient *fasthttp.Client, grpcConn *grpc.ClientConn, ) *TritonClientService { client := &TritonClientService{ - serverURL: httpServerURL, - grpcConn: grpcConn, - grpcClient: NewGRPCInferenceServiceClient(grpcConn), + serverURL: httpServerURL, + grpcConn: grpcConn, + grpcClient: NewGRPCInferenceServiceClient(grpcConn), + JSONEncoder: json.Marshal, + JSONDecoder: json.Unmarshal, } client.setHTTPConnection(httpClient) diff --git a/test/bert_test.go b/test/bert_test.go index 0e7d6a8..0faf416 100644 --- a/test/bert_test.go +++ b/test/bert_test.go @@ -5,11 +5,10 @@ import ( "testing" "github.com/sunhailin-Leo/triton-service-go/models/bert" + "github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver" "github.com/valyala/fasthttp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - - "github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver" ) const ( diff --git a/test/triton_test.go b/test/triton_test.go index 9689b26..0584246 100644 --- a/test/triton_test.go +++ b/test/triton_test.go @@ -4,11 +4,10 @@ import ( "errors" "testing" + "github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver" "github.com/valyala/fasthttp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - - "github.com/sunhailin-Leo/triton-service-go/nvidia_inferenceserver" ) func TestTritonHTTPClientInit(_ *testing.T) { diff --git a/utils/json.go b/utils/json.go new file mode 100644 index 0000000..477c8c3 --- /dev/null +++ b/utils/json.go @@ -0,0 +1,9 @@ +package utils + +// JSONMarshal returns the JSON encoding of v. +type JSONMarshal func(v interface{}) ([]byte, error) + +// JSONUnmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an InvalidUnmarshalError. +type JSONUnmarshal func(data []byte, v interface{}) error