Skip to content

Commit

Permalink
Add vm.token_to_video API to SAX.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660534687
Change-Id: I3fce0a302aac8d822b8e4b292ba8382542311df5
  • Loading branch information
bignamehyp authored and copybara-github committed Aug 7, 2024
1 parent 3b623b8 commit 5b61e5e
Show file tree
Hide file tree
Showing 12 changed files with 457 additions and 138 deletions.
43 changes: 43 additions & 0 deletions saxml/client/cc/sax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ using ::sax::server::vision::ImageToImageResponse;
using ::sax::server::vision::ImageToTextResponse;
using ::sax::server::vision::TextAndImageToImageResponse;
using ::sax::server::vision::TextToImageResponse;
using ::sax::server::vision::TokenToVideoResponse;
using ::sax::server::vision::VideoToTextResponse;
using ::sax::server::vision::VideoToTokenResponse;
using VmEmbedResponse = ::sax::server::vision::EmbedResponse;
Expand Down Expand Up @@ -959,6 +960,48 @@ absl::Status VisionModel::VideoToToken(
return absl::OkStatus();
}

absl::Status VisionModel::TokenToVideo(
const std::vector<double>& tokens,
std::vector<absl::string_view>* image_frames) const {
return VisionModel::TokenToVideo(ModelOptions(), tokens, image_frames);
}

absl::Status VisionModel::TokenToVideo(
const ModelOptions& options, const std::vector<double>& tokens,
std::vector<absl::string_view>* image_frames) const {
ExtraInputs extra;
options.ToProto(&extra);
std::string extraStr = "";
extra.SerializeToString(&extraStr);

char* outputStr = nullptr;
int outputSize = 0;
char* errMsgStr = nullptr;
int errCode = 0;
std::vector<char*> token_buffers;
for (const auto& token : tokens) {
token_buffers.push_back(
const_cast<char*>(reinterpret_cast<const char*>(&token)));
}
go_vm_token_to_video(model_handle_, options.GetTimeout(),
const_cast<char**>(token_buffers.data()), tokens.size(),
const_cast<char*>(extraStr.data()), extraStr.size(),
&outputStr, &outputSize, &errMsgStr, &errCode);
if (errCode != 0) {
return CreateErrorAndFree(errCode, errMsgStr);
}
TokenToVideoResponse output;
if (outputStr != nullptr) {
output.ParseFromArray(outputStr, outputSize);
free(outputStr);
}
image_frames->reserve(output.image_frames_size());
for (const auto& frame : output.image_frames()) {
image_frames->push_back(frame);
}
return absl::OkStatus();
}

VisionModel::~VisionModel() { go_release_model(model_handle_); }

absl::Status Model::Open(absl::string_view id, const Options* options,
Expand Down
11 changes: 11 additions & 0 deletions saxml/client/cc/sax.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,17 @@ class VisionModel {
const std::vector<absl::string_view>& image_frames,
std::vector<double>* tokens) const;

// TokenToVideo produces a video (a list of image frames) in bytes given a
// list of 'tokens' in type double.
//
// On success, returns OK and fills in image frames computed by
// the model. Otherwise, returns an error.
absl::Status TokenToVideo(const std::vector<double>& tokens,
std::vector<absl::string_view>* image_frames) const;
absl::Status TokenToVideo(const ModelOptions& options,
const std::vector<double>& tokens,
std::vector<absl::string_view>* image_frames) const;

private:
explicit VisionModel(int64_t model_handle) : model_handle_(model_handle) {}
friend class Model;
Expand Down
47 changes: 47 additions & 0 deletions saxml/client/cc/saxwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,53 @@ func go_vm_video_to_token(ptr C.long, timeout C.float, imageFramesData **C.char,
buildReturnValues(outData, outSize, errMsg, errCode, &content, nil)
}

//export go_vm_token_to_video
func go_vm_token_to_video(ptr C.long, timeout C.float, tokensData **C.char, tokensSize C.int,
optionsData *C.char, optionsSize C.int, outData **C.char, outSize *C.int,
errMsg **C.char, errCode *C.int) {
vm := rcgo.Handle(ptr).Value().(*sax.VisionModel)
if vm == nil {
// This is not expected.
log.Fatalf("token_to_video() called on nil vision model.")
}
optionsByte := C.GoBytes(unsafe.Pointer(optionsData), optionsSize)
options := &cpb.ExtraInputs{}
if err := proto.Unmarshal(optionsByte, options); err != nil {
buildReturnValues(outData, outSize, errMsg, errCode, nil, err)
return
}

tokens := make([]float64, tokensSize)
tokenDataPtrStart := unsafe.Pointer(tokensData)
for i := 0; i < int(tokensSize); i++ {
tokenDataPtr := uintptr(tokenDataPtrStart) + uintptr(i)*unsafe.Sizeof("float64")
token := *((*float64)(unsafe.Pointer(tokenDataPtr)))
tokens[i] = token
}

ctx, cancel := createContextWithTimeout(timeout)
if cancel != nil {
defer cancel()
}
res, err := vm.TokenToVideo(ctx, tokens, protoOptionToSetter(options)...)
if err != nil {
buildReturnValues(outData, outSize, errMsg, errCode, nil, err)
return
}

ret := &vmpb.TokenToVideoResponse{}
for _, imageBytes := range res {
ret.ImageFrames = append(ret.GetImageFrames(), imageBytes)
}
content, err := proto.Marshal(ret)
if err != nil {
buildReturnValues(outData, outSize, errMsg, errCode, nil, err)
return
}
buildReturnValues(outData, outSize, errMsg, errCode, &content, nil)

}

//export go_custom
func go_custom(ptr C.long, timeout C.float, requestData unsafe.Pointer, requestSize C.int, methodNameData *C.char, methodNameSize C.int, optionsData *C.char, optionsSize C.int, outData **C.char, outSize *C.int, errMsg **C.char, errCode *C.int) {
custom := rcgo.Handle(ptr).Value().(*sax.CustomModel)
Expand Down
25 changes: 25 additions & 0 deletions saxml/client/go/sax_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,28 @@ func (v *VisionModel) VideoToToken(ctx context.Context, imageFrames [][]byte, op
}
return resp.GetTokens(), nil
}

// TokenToVideo performs de-tokenization for a list of tokens against a vision model.
// Specifically:
// - 'tokens' is a list of token with type float64.
// Return:
// - 'imageFrames' is a list of bytes where each element is a serialized image frame.
func (v *VisionModel) TokenToVideo(ctx context.Context, tokens []float64, options ...ModelOptionSetter) ([][]byte, error) {
opts := NewModelOptions(options...)
req := &pb.TokenToVideoRequest{
ModelKey: v.model.modelID,
Tokens: tokens,
ExtraInputs: opts.ExtraInputs(),
}

var resp *pb.TokenToVideoResponse
err := v.model.run(ctx, "TokenToVideo", func(conn *grpc.ClientConn) error {
var tokenToVideoErr error
resp, tokenToVideoErr = pbgrpc.NewVisionServiceClient(conn).TokenToVideo(ctx, req)
return tokenToVideoErr
})
if err != nil {
return nil, err
}
return resp.GetImageFrames(), nil
}
11 changes: 10 additions & 1 deletion saxml/client/python/sax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,16 @@ PYBIND11_MODULE(sax, m) {
-> absl::StatusOr<std::vector<double>> {
return vm.VideoToToken(image_frames, options);
},
py::arg("image_frames"), py::arg("options") = nullptr);
py::arg("image_frames"), py::arg("options") = nullptr)
.def(
"TokenToVideo",
[](sax::client::pybind::VisionModel& vm,
const std::vector<double>& tokens,
const sax::client::ModelOptions* options)
-> absl::StatusOr<std::vector<pybind11::bytes>> {
return vm.TokenToVideo(tokens, options);
},
py::arg("tokens"), py::arg("options") = nullptr);

py::class_<sax::client::pybind::Model>(m, "Model")
.def(py::init<absl::string_view, const sax::client::Options*>())
Expand Down
1 change: 1 addition & 0 deletions saxml/client/python/sax.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class VisionModel:
def ImageToText(self, image_bytes: str, text: str = ..., options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
def TextAndImageToImage(self, text: str, image_bytes: str, options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
def TextToImage(self, text: str, options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
def TokenToVideo(self, tokens: list[float], options: ModelOptions = ...) -> list[bytes]: ...
def VideoToText(self, image_frames: list[str], text: str = ..., options: ModelOptions = ...) -> list[tuple[bytes,float]]: ...
def VideoToToken(self, image_frames: list[str], options: ModelOptions = ...) -> list[float]: ...

Expand Down
20 changes: 20 additions & 0 deletions saxml/client/python/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,26 @@ absl::StatusOr<std::vector<double>> VisionModel::VideoToToken(
return result;
}

absl::StatusOr<std::vector<pybind11::bytes>> VisionModel::TokenToVideo(
const std::vector<double>& tokens, const ModelOptions* options) const {
if (!status_.ok()) return status_;
std::vector<absl::string_view> videos;
{
pybind11::gil_scoped_release release;
if (options == nullptr) {
RETURN_IF_ERROR(model_->TokenToVideo(tokens, &videos));
} else {
RETURN_IF_ERROR(model_->TokenToVideo(*options, tokens, &videos));
}
}
std::vector<pybind11::bytes> result;
result.reserve(videos.size());
for (auto video : videos) {
result.push_back(pybind11::bytes(std::move(video)));
}
return result;
}

Model::Model(absl::string_view id, const Options* options) {
status_ = ::sax::client::Model::Open(id, options, &base_);
}
Expand Down
7 changes: 7 additions & 0 deletions saxml/client/python/wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ class VisionModel {
const std::vector<absl::string_view>& image_frames,
const ModelOptions* options = nullptr) const;

// TokenToVideo produces a list of image frames given a list of 'tokens'.
//
// Returns a vector of image frames with each frame being bytes.
absl::StatusOr<std::vector<pybind11::bytes>> TokenToVideo(
const std::vector<double>& tokens,
const ModelOptions* options = nullptr) const;

private:
explicit VisionModel(::sax::client::Model* base, const absl::Status& status);
::sax::client::Model* base_ = nullptr;
Expand Down
12 changes: 12 additions & 0 deletions saxml/common/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,18 @@ func (s *stubVisionModelServer) VideoToToken(ctx context.Context, in *vmpb.Video
}, nil
}

func (s *stubVisionModelServer) TokenToVideo(ctx context.Context, in *vmpb.TokenToVideoRequest) (*vmpb.TokenToVideoResponse, error) {
value := len(in.GetTokens())
return &vmpb.TokenToVideoResponse{
ImageFrames: [][]byte{
[]byte("frame_0"),
[]byte("frame_1"),
[]byte("frame_2"),
[]byte("frame_" + strconv.Itoa(int(value))),
},
}, nil
}

type stubAudioModelServer struct{}

func (s *stubAudioModelServer) Recognize(ctx context.Context, in *ampb.AsrRequest) (*ampb.AsrResponse, error) {
Expand Down
15 changes: 14 additions & 1 deletion saxml/protobuf/vision.proto
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ message VideoToTokenResponse {
repeated double tokens = 1; // quantized or soft tokens.
}

message TokenToVideoRequest {
string model_key = 1;
repeated double tokens = 2; // quantized or soft tokens.
.sax.ExtraInputs extra_inputs = 3;
}

message TokenToVideoResponse {
repeated bytes image_frames = 2; // Video composed of multiple image frames.
}

service VisionService {
// Returns the score (e.g., log pplx) given the text.
rpc Classify(ClassifyRequest) returns (ClassifyResponse);
Expand Down Expand Up @@ -217,6 +227,9 @@ service VisionService {
// Returns text generation results given video.
rpc VideoToText(VideoToTextRequest) returns (VideoToTextResponse);

// Returns video tokens results given video.
// Returns video tokens results given video (tokenization).
rpc VideoToToken(VideoToTokenRequest) returns (VideoToTokenResponse);

// Returns video bytes results given video tokens (de-tokenization).
rpc TokenToVideo(TokenToVideoRequest) returns (TokenToVideoResponse);
}
Loading

0 comments on commit 5b61e5e

Please sign in to comment.