Skip to content

Enable async modelset + refactor of model creation #676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ADD_LIBRARY(redisai_obj OBJECT
execution/command_parser.c
execution/run_info.c
execution/background_workers.c
execution/background_modelset.c
config/config.c
execution/DAG/dag.c
execution/DAG/dag_parser.c
Expand Down
13 changes: 4 additions & 9 deletions src/backends/backends.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) {
init_backend(RedisModule_GetApi);

backend.model_create_with_nodes =
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, size_t, const char **, size_t,
const char **, const char *, size_t,
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF");
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF");
if (backend.model_create_with_nodes == NULL) {
dlclose(handle);
RedisModule_Log(ctx, "warning",
Expand Down Expand Up @@ -180,8 +178,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) {
init_backend(RedisModule_GetApi);

backend.model_create =
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite");
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite");
if (backend.model_create == NULL) {
dlclose(handle);
RedisModule_Log(ctx, "warning",
Expand Down Expand Up @@ -272,8 +269,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
init_backend(RedisModule_GetApi);

backend.model_create =
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch");
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch");
if (backend.model_create == NULL) {
dlclose(handle);
RedisModule_Log(ctx, "warning",
Expand Down Expand Up @@ -396,8 +392,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) {
init_backend(RedisModule_GetApi);

backend.model_create =
(RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT");
(int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT");
if (backend.model_create == NULL) {
dlclose(handle);
RedisModule_Log(ctx, "warning",
Expand Down
11 changes: 4 additions & 7 deletions src/backends/backends.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@
*/
typedef struct RAI_LoadedBackend {
// ** model_create_with_nodes **: A callback function pointer that creates a
// model given the RAI_ModelOpts and input and output nodes
RAI_Model *(*model_create_with_nodes)(RAI_Backend, const char *, RAI_ModelOpts, size_t,
const char **, size_t, const char **, const char *,
size_t, RAI_Error *);
// model given the RAI_ModelOpts and input and output nodes (which are stored in the model).
int (*model_create_with_nodes)(RAI_Model *, RAI_Error *);

// ** model_create **: A callback function pointer that creates a model given
// the RAI_ModelOpts
RAI_Model *(*model_create)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t,
RAI_Error *);
// the RAI_ModelOpts (which are stored in the model).
int (*model_create)(RAI_Model *, RAI_Error *);

// ** model_free **: A callback function pointer that frees a model given the
// RAI_Model pointer
Expand Down
82 changes: 40 additions & 42 deletions src/backends/onnxruntime.c
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,13 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
return NULL;
}

RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
const char *modeldef, size_t modellen, RAI_Error *error) {
int RAI_ModelCreateORT(RAI_Model *model, RAI_Error *error) {

const OrtApi *ort = OrtGetApiBase()->GetApi(1);
char **inputs_ = NULL;
char **outputs_ = NULL;
size_t ninputs;
size_t noutputs;
OrtSessionOptions *session_options = NULL;
OrtSession *session = NULL;
OrtStatus *status = NULL;
Expand All @@ -348,7 +349,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
}

ONNX_VALIDATE_STATUS(ort->CreateSessionOptions(&session_options))
if (strcasecmp(devicestr, "CPU") == 0) {
if (strcasecmp(model->devicestr, "CPU") == 0) {
// These are required to ensure that onnx will use the registered REDIS allocator (for
// a model that defined to run on CPU).
ONNX_VALIDATE_STATUS(
Expand All @@ -359,24 +360,31 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
// TODO: these options could be configured at the AI.CONFIG level
ONNX_VALIDATE_STATUS(ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC))
ONNX_VALIDATE_STATUS(
ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism))
ort->SetIntraOpNumThreads(session_options, (int)model->opts.backends_intra_op_parallelism))
ONNX_VALIDATE_STATUS(
ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism))
ort->SetInterOpNumThreads(session_options, (int)model->opts.backends_inter_op_parallelism))

// If the model is set for GPU, this will set CUDA provider for the session,
// so that onnx will use its own allocator for CUDA (not Redis allocator)
if (!setDeviceId(devicestr, session_options, error)) {
if (!setDeviceId(model->devicestr, session_options, error)) {
ort->ReleaseSessionOptions(session_options);
return NULL;
return REDISMODULE_ERR;
}

ONNX_VALIDATE_STATUS(
ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session))
ort->CreateSessionFromArray(env, model->data, model->datalen, session_options, &session))
ort->ReleaseSessionOptions(session_options);

model->session = session;

size_t n_input_nodes;
ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes))
size_t n_output_nodes;

// We save the model's inputs and outputs only in the first time that we create the model.
// We might create the model again when loading from RDB, in this case the inputs and outputs
// are already loaded from RDB.
// if (!model->inputs) {
ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes))
ONNX_VALIDATE_STATUS(ort->SessionGetOutputCount(session, &n_output_nodes))

inputs_ = array_new(char *, n_input_nodes);
Expand All @@ -393,27 +401,13 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
outputs_ = array_append(outputs_, output_name);
}

// Since ONNXRuntime doesn't have a re-serialization function,
// we cache the blob in order to re-serialize it.
// Not optimal for storage purposes, but again, it may be temporary
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
ret->model = NULL;
ret->session = session;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->refCount = 1;
ret->opts = opts;
ret->data = buffer;
ret->datalen = modellen;
ret->ninputs = n_input_nodes;
ret->noutputs = n_output_nodes;
ret->inputs = inputs_;
ret->outputs = outputs_;

return ret;
model->ninputs = n_input_nodes;
model->noutputs = n_output_nodes;
model->inputs = inputs_;
model->outputs = outputs_;
//}

return REDISMODULE_OK;

error:
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
Expand All @@ -438,28 +432,32 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
ort->ReleaseSession(session);
}
ort->ReleaseStatus(status);
return NULL;
return REDISMODULE_ERR;
}

void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
const OrtApi *ort = OrtGetApiBase()->GetApi(1);
OrtStatus *status = NULL;

for (uint32_t i = 0; i < model->ninputs; i++) {
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
if (model->inputs) {
for (uint32_t i = 0; i < model->ninputs; i++) {
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
}
array_free(model->inputs);
model->inputs = NULL;
}
array_free(model->inputs);

for (uint32_t i = 0; i < model->noutputs; i++) {
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
if (model->outputs) {
for (uint32_t i = 0; i < model->noutputs; i++) {
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
}
array_free(model->outputs);
model->outputs = NULL;
}
array_free(model->outputs);

RedisModule_Free(model->devicestr);
RedisModule_Free(model->data);
ort->ReleaseSession(model->session);
model->model = NULL;
model->session = NULL;
if (model->session) {
ort->ReleaseSession(model->session);
}
return;

error:
Expand Down
3 changes: 1 addition & 2 deletions src/backends/onnxruntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ unsigned long long RAI_GetMemoryAccessORT(void);

int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *));

RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
const char *modeldef, size_t modellen, RAI_Error *err);
int RAI_ModelCreateORT(RAI_Model *model, RAI_Error *err);

void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error);

Expand Down
Loading