Skip to content

Commit

Permalink
Sax client api for video_to_token in python, cc, and go.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657740442
Change-Id: I120fa174fa66addf219b01ddc1a25bd86124adc2
  • Loading branch information
bignamehyp authored and copybara-github committed Jul 30, 2024
1 parent af3f4e8 commit ac13ce8
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 1 deletion.
46 changes: 46 additions & 0 deletions saxml/client/cc/sax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ using ::sax::server::vision::ImageToTextResponse;
using ::sax::server::vision::TextAndImageToImageResponse;
using ::sax::server::vision::TextToImageResponse;
using ::sax::server::vision::VideoToTextResponse;
using ::sax::server::vision::VideoToTokenResponse;
using VmEmbedResponse = ::sax::server::vision::EmbedResponse;
using AsrResponse = ::sax::server::audio::AsrResponse;

Expand Down Expand Up @@ -913,6 +914,51 @@ absl::Status VisionModel::VideoToText(
return absl::OkStatus();
}

// Convenience overload of VideoToToken that uses default-constructed
// ModelOptions. The frames and the 'tokens' output are forwarded unchanged to
// the overload that takes explicit options, and its status is returned.
absl::Status VisionModel::VideoToToken(
    const std::vector<absl::string_view>& image_frames,
    std::vector<double>* tokens) const {
  const ModelOptions default_options{};
  return VideoToToken(default_options, image_frames, tokens);
}

// Tokenizes a video, given as one byte buffer per frame in 'image_frames',
// by calling go_vm_video_to_token on this model's handle.
//
// 'options' is serialized as an ExtraInputs proto and sent with the request;
// its timeout is passed to the call as well. On success the tokens of the
// returned VideoToTokenResponse are appended to '*tokens' (existing elements
// are kept).
//
// Returns the error reported by the call if its error code is non-zero, an
// internal error if the response bytes cannot be parsed, and OK otherwise.
absl::Status VisionModel::VideoToToken(
    const ModelOptions& options,
    const std::vector<absl::string_view>& image_frames,
    std::vector<double>* tokens) const {
  ExtraInputs extra;
  options.ToProto(&extra);
  std::string extraStr = "";
  extra.SerializeToString(&extraStr);

  // The C interface takes parallel arrays of frame pointers and frame sizes.
  std::vector<int> frame_sizes;
  std::vector<char*> frame_buffers;
  frame_sizes.reserve(image_frames.size());
  frame_buffers.reserve(image_frames.size());
  for (const auto& frame : image_frames) {
    // The C interface uses int sizes, so the narrowing is made explicit here.
    frame_sizes.push_back(static_cast<int>(frame.size()));
    // The C signature takes char*, hence the const_cast on the view's data.
    frame_buffers.push_back(const_cast<char*>(frame.data()));
  }

  char* outputStr = nullptr;
  int outputSize = 0;
  char* errMsgStr = nullptr;
  int errCode = 0;
  go_vm_video_to_token(model_handle_, options.GetTimeout(),
                       frame_buffers.data(), frame_sizes.data(),
                       static_cast<int>(image_frames.size()),
                       const_cast<char*>(extraStr.data()),
                       static_cast<int>(extraStr.size()), &outputStr,
                       &outputSize, &errMsgStr, &errCode);
  if (errCode != 0) {
    return CreateErrorAndFree(errCode, errMsgStr);
  }

  VideoToTokenResponse output;
  if (outputStr != nullptr) {
    const bool parsed = output.ParseFromArray(outputStr, outputSize);
    // Release the returned buffer before any early return.
    free(outputStr);
    if (!parsed) {
      // Do not report success with empty/partial tokens for a corrupt reply.
      return absl::InternalError("Failed to parse VideoToTokenResponse.");
    }
  }
  tokens->insert(tokens->end(), output.tokens().begin(),
                 output.tokens().end());
  return absl::OkStatus();
}

// Calls go_release_model on model_handle_ when the VisionModel is destroyed.
VisionModel::~VisionModel() { go_release_model(model_handle_); }

absl::Status Model::Open(absl::string_view id, const Options* options,
Expand Down
6 changes: 6 additions & 0 deletions saxml/client/cc/sax.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,12 @@ class VisionModel {
absl::string_view text,
std::vector<ScoredText>* result) const;

// VideoToToken tokenizes a video given as 'image_frames' (one entry holding
// the bytes of each frame) and appends the resulting token values to
// 'tokens'; elements already in 'tokens' are kept.
//
// This overload uses default-constructed ModelOptions.
absl::Status VideoToToken(const std::vector<absl::string_view>& image_frames,
std::vector<double>* tokens) const;
// Same operation, but the settings in 'options' (including its timeout) are
// used for the call. Returns a non-OK status if the call reports an error.
absl::Status VideoToToken(const ModelOptions& options,
const std::vector<absl::string_view>& image_frames,
std::vector<double>* tokens) const;

private:
explicit VisionModel(int64_t model_handle) : model_handle_(model_handle) {}
friend class Model;
Expand Down
52 changes: 52 additions & 0 deletions saxml/client/cc/saxwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,58 @@ func go_vm_video_to_text(ptr C.long, timeout C.float, imageFramesData **C.char,
buildReturnValues(outData, outSize, errMsg, errCode, &content, nil)
}

// go_vm_video_to_token is the cgo entry point for VisionModel.VideoToToken.
//
// ptr is a handle resolving to a *sax.VisionModel. imageFramesData and
// perFrameSizes are parallel C arrays with numFrames entries holding each
// frame's data pointer and byte length. optionsData/optionsSize hold a
// serialized ExtraInputs proto.
//
// Results are reported through buildReturnValues: on success the serialized
// VideoToTokenResponse is passed as content, otherwise the error is passed.
//
//export go_vm_video_to_token
func go_vm_video_to_token(ptr C.long, timeout C.float, imageFramesData **C.char,
	perFrameSizes *C.int,
	numFrames C.int,
	optionsData *C.char, optionsSize C.int, outData **C.char, outSize *C.int, errMsg **C.char, errCode *C.int) {
	vm := rcgo.Handle(ptr).Value().(*sax.VisionModel)
	if vm == nil {
		// This is not expected.
		log.Fatalf("video_to_token() called on nil vision model.")
	}

	optionsByte := C.GoBytes(unsafe.Pointer(optionsData), optionsSize)
	options := &cpb.ExtraInputs{}
	if err := proto.Unmarshal(optionsByte, options); err != nil {
		buildReturnValues(outData, outSize, errMsg, errCode, nil, err)
		return
	}

	// Copy every frame out of the caller's arrays into Go-owned byte slices.
	imageFrames := [][]byte{}
	if n := int(numFrames); n > 0 {
		imageFrames = make([][]byte, 0, n)
	}
	framePtrSize := unsafe.Sizeof(*imageFramesData)
	frameSizesPtrSize := unsafe.Sizeof(*perFrameSizes)
	for i := 0; i < int(numFrames); i++ {
		// The uintptr arithmetic and the conversion back to unsafe.Pointer are
		// kept in a single expression, as the unsafe.Pointer rules require.
		frame := *(**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(imageFramesData)) + uintptr(i)*framePtrSize))
		frameSize := *(*C.int)(unsafe.Pointer(uintptr(unsafe.Pointer(perFrameSizes)) + uintptr(i)*frameSizesPtrSize))
		imageFrames = append(imageFrames, C.GoBytes(unsafe.Pointer(frame), frameSize))
	}

	ctx, cancel := createContextWithTimeout(timeout)
	if cancel != nil {
		defer cancel()
	}

	res, err := vm.VideoToToken(ctx, imageFrames, protoOptionToSetter(options)...)
	if err != nil {
		buildReturnValues(outData, outSize, errMsg, errCode, nil, err)
		return
	}

	ret := &vmpb.VideoToTokenResponse{
		Tokens: res,
	}
	content, err := proto.Marshal(ret)
	if err != nil {
		buildReturnValues(outData, outSize, errMsg, errCode, nil, err)
		return
	}
	buildReturnValues(outData, outSize, errMsg, errCode, &content, nil)
}

//export go_custom
func go_custom(ptr C.long, timeout C.float, requestData unsafe.Pointer, requestSize C.int, methodNameData *C.char, methodNameSize C.int, optionsData *C.char, optionsSize C.int, outData **C.char, outSize *C.int, errMsg **C.char, errCode *C.int) {
custom := rcgo.Handle(ptr).Value().(*sax.CustomModel)
Expand Down
23 changes: 23 additions & 0 deletions saxml/client/go/sax_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,26 @@ func (v *VisionModel) VideoToText(ctx context.Context, imageFrames [][]byte, tex
res := extractVideoToTextResponse(resp)
return res, nil
}

// VideoToToken sends the given image frames to the vision model's
// VideoToToken RPC and returns the tokens from the response.
//
// The request carries the model ID, the frames, and the extra inputs derived
// from 'options'. The returned tokens are float64 values; on failure the
// error from the call is returned with a nil slice.
func (v *VisionModel) VideoToToken(ctx context.Context, imageFrames [][]byte, options ...ModelOptionSetter) ([]float64, error) {
	modelOpts := NewModelOptions(options...)
	request := &pb.VideoToTokenRequest{
		ModelKey:    v.model.modelID,
		ImageFrames: imageFrames,
		ExtraInputs: modelOpts.ExtraInputs(),
	}

	var response *pb.VideoToTokenResponse
	invoke := func(conn *grpc.ClientConn) (callErr error) {
		response, callErr = pbgrpc.NewVisionServiceClient(conn).VideoToToken(ctx, request)
		return callErr
	}
	if err := v.model.run(ctx, "VideoToToken", invoke); err != nil {
		return nil, err
	}
	return response.GetTokens(), nil
}
11 changes: 10 additions & 1 deletion saxml/client/python/sax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,16 @@ PYBIND11_MODULE(sax, m) {
return vm.VideoToText(image_frames, text, options);
},
py::arg("image_frames"), py::arg("text") = "",
py::arg("options") = nullptr);
py::arg("options") = nullptr)
// Python binding for VisionModel.VideoToToken(image_frames, options=None):
// forwards both arguments to the wrapper's VideoToToken and returns its
// StatusOr<std::vector<double>> result.
.def(
"VideoToToken",
[](sax::client::pybind::VisionModel& vm,
const std::vector<absl::string_view>& image_frames,
const sax::client::ModelOptions* options)
-> absl::StatusOr<std::vector<double>> {
return vm.VideoToToken(image_frames, options);
},
py::arg("image_frames"), py::arg("options") = nullptr);

py::class_<sax::client::pybind::Model>(m, "Model")
.def(py::init<absl::string_view, const sax::client::Options*>())
Expand Down
1 change: 1 addition & 0 deletions saxml/client/python/sax.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class VisionModel:
# Takes a text string, image bytes and optional ModelOptions; returns a list of
# (bytes, float) tuples -- presumably (image, score) pairs; confirm against the binding.
def TextAndImageToImage(self, text: str, image_bytes: str, options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
# Takes a text string and optional ModelOptions; returns a list of
# (bytes, float) tuples -- presumably (image, score) pairs; confirm against the binding.
def TextToImage(self, text: str, options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
# Takes a list of image frames, an optional text string and optional ModelOptions;
# returns a list of (bytes, float) tuples -- presumably (text, score) pairs; confirm against the binding.
def VideoToText(self, image_frames: list[str], text: str = ..., options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
# Takes a list of image frames and optional ModelOptions; returns the tokens
# produced for the video as a list of floats.
def VideoToToken(self, image_frames: list[str], options: ModelOptions = ...) -> list[float]: ...

# Takes an id string and optional AdminOptions; returns a (str, str, int) tuple.
# NOTE(review): the meaning of the tuple fields is not visible here -- confirm against the binding.
def List(id: str, options: AdminOptions = ...) -> tuple[str,str,int]: ...
# Takes an id string and optional AdminOptions; returns a list of strings.
# NOTE(review): what the strings identify is not visible here -- confirm against the binding.
def ListAll(id: str, options: AdminOptions = ...) -> list[str]: ...
Expand Down
16 changes: 16 additions & 0 deletions saxml/client/python/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,22 @@ VisionModel::VideoToText(const std::vector<absl::string_view>& image_frames,
return result;
}

// Python-facing VideoToToken. Returns the stored status_ if it is not OK.
// Otherwise calls the underlying model's VideoToToken with the GIL released,
// using the overload without options when 'options' is nullptr, and returns
// the collected tokens or the error from that call.
absl::StatusOr<std::vector<double>> VisionModel::VideoToToken(
    const std::vector<absl::string_view>& image_frames,
    const ModelOptions* options) const {
  if (!status_.ok()) {
    return status_;
  }

  std::vector<double> tokens;
  {
    // Release the GIL for the duration of the model call.
    pybind11::gil_scoped_release release_gil;
    if (options != nullptr) {
      RETURN_IF_ERROR(model_->VideoToToken(*options, image_frames, &tokens));
    } else {
      RETURN_IF_ERROR(model_->VideoToToken(image_frames, &tokens));
    }
  }
  return tokens;
}

// Constructs a Model by calling ::sax::client::Model::Open with 'id' and
// 'options'. Open receives &base_ as its output argument, and the status it
// returns is stored in status_ rather than reported from the constructor.
Model::Model(absl::string_view id, const Options* options) {
status_ = ::sax::client::Model::Open(id, options, &base_);
}
Expand Down
7 changes: 7 additions & 0 deletions saxml/client/python/wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ class VisionModel {
const std::vector<absl::string_view>& image_frames,
absl::string_view text = "", const ModelOptions* options = nullptr) const;

// VideoToToken produces a list of tokens given 'image_frames'.
//
// 'options' may be nullptr, in which case the underlying model's overload
// that takes no options is used.
//
// Returns a vector of tokens on success, or a non-OK status if this
// VisionModel holds a non-OK status or the underlying call fails.
absl::StatusOr<std::vector<double>> VideoToToken(
const std::vector<absl::string_view>& image_frames,
const ModelOptions* options = nullptr) const;

private:
explicit VisionModel(::sax::client::Model* base, const absl::Status& status);
::sax::client::Model* base_ = nullptr;
Expand Down

0 comments on commit ac13ce8

Please sign in to comment.