From 0616eeb670da8e1c857a188ad116ed1f0bb08dea Mon Sep 17 00:00:00 2001 From: alonre24 Date: Fri, 9 Apr 2021 15:37:59 +0300 Subject: [PATCH 1/4] Enable async modelset + refactor of model creation (draft) --- src/CMakeLists.txt | 1 + src/backends/backends.c | 13 +- src/backends/backends.h | 11 +- src/backends/onnxruntime.c | 82 +++--- src/backends/onnxruntime.h | 3 +- src/backends/tensorflow.c | 120 +++----- src/backends/tensorflow.h | 5 +- src/backends/tflite.c | 83 ++---- src/backends/tflite.h | 3 +- src/backends/torch.c | 104 +++---- src/backends/torch.h | 3 +- src/backends/util.c | 3 +- src/execution/background_modelset.c | 154 +++++++++++ src/execution/background_modelset.h | 33 +++ src/execution/background_workers.c | 18 -- src/execution/background_workers.h | 20 +- src/execution/command_parser.c | 224 ++++++++++++++- src/execution/command_parser.h | 2 + src/redis_ai_objects/model.c | 139 +++++++--- src/redis_ai_objects/model.h | 2 + src/redis_ai_objects/stats.c | 4 +- src/redis_ai_objects/stats.h | 4 +- src/redisai.c | 258 ++---------------- .../RDB/decoder/current/v1/decode_v1.c | 110 ++++---- src/util/queue.c | 2 + src/util/queue.h | 1 + tests/flow/tests_pytorch.py | 6 +- tests/flow/tests_tensorflow.py | 2 +- tests/flow/tests_tflite.py | 2 +- 29 files changed, 778 insertions(+), 634 deletions(-) create mode 100644 src/execution/background_modelset.c create mode 100644 src/execution/background_modelset.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b6272351d..dd1d46daa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -27,6 +27,7 @@ ADD_LIBRARY(redisai_obj OBJECT execution/command_parser.c execution/run_info.c execution/background_workers.c + execution/background_modelset.c config/config.c execution/DAG/dag.c execution/DAG/dag_parser.c diff --git a/src/backends/backends.c b/src/backends/backends.c index f647fa585..468757dfb 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -88,9 +88,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { init_backend(RedisModule_GetApi); backend.model_create_with_nodes = - (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, size_t, const char **, size_t, - const char **, const char *, size_t, - RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF"); + (int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF"); if (backend.model_create_with_nodes == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", @@ -180,8 +178,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { init_backend(RedisModule_GetApi); backend.model_create = - (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, - RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite"); + (int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite"); if (backend.model_create == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", @@ -272,8 +269,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { init_backend(RedisModule_GetApi); backend.model_create = - (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, - RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch"); + (int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch"); if (backend.model_create == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", @@ -396,8 +392,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { 
init_backend(RedisModule_GetApi); backend.model_create = - (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, - RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT"); + (int (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT"); if (backend.model_create == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", diff --git a/src/backends/backends.h b/src/backends/backends.h index 6fd4d1e80..467fc0343 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -40,15 +40,12 @@ */ typedef struct RAI_LoadedBackend { // ** model_create_with_nodes **: A callback function pointer that creates a - // model given the RAI_ModelOpts and input and output nodes - RAI_Model *(*model_create_with_nodes)(RAI_Backend, const char *, RAI_ModelOpts, size_t, - const char **, size_t, const char **, const char *, - size_t, RAI_Error *); + // model given the RAI_ModelOpts and input and output nodes (which are stored in the model). + int (*model_create_with_nodes)(RAI_Model *, RAI_Error *); // ** model_create **: A callback function pointer that creates a model given - // the RAI_ModelOpts - RAI_Model *(*model_create)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, - RAI_Error *); + // the RAI_ModelOpts (which are stored in the model). + int (*model_create)(RAI_Model *, RAI_Error *); // ** model_free **: A callback function pointer that frees a model given the // RAI_Model pointer diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 805332864..24d5633ed 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -326,12 +326,13 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long return NULL; } -RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, - const char *modeldef, size_t modellen, RAI_Error *error) { +int RAI_ModelCreateORT(RAI_Model *model, RAI_Error *error) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); char **inputs_ = NULL; char **outputs_ = NULL; + size_t ninputs; + size_t noutputs; OrtSessionOptions *session_options = NULL; OrtSession *session = NULL; OrtStatus *status = NULL; @@ -348,7 +349,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo } ONNX_VALIDATE_STATUS(ort->CreateSessionOptions(&session_options)) - if (strcasecmp(devicestr, "CPU") == 0) { + if (strcasecmp(model->devicestr, "CPU") == 0) { // These are required to ensure that onnx will use the registered REDIS allocator (for // a model that defined to run on CPU). 
ONNX_VALIDATE_STATUS( @@ -359,24 +360,31 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo // TODO: these options could be configured at the AI.CONFIG level ONNX_VALIDATE_STATUS(ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC)) ONNX_VALIDATE_STATUS( - ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism)) + ort->SetIntraOpNumThreads(session_options, (int)model->opts.backends_intra_op_parallelism)) ONNX_VALIDATE_STATUS( - ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism)) + ort->SetInterOpNumThreads(session_options, (int)model->opts.backends_inter_op_parallelism)) // If the model is set for GPU, this will set CUDA provider for the session, // so that onnx will use its own allocator for CUDA (not Redis allocator) - if (!setDeviceId(devicestr, session_options, error)) { + if (!setDeviceId(model->devicestr, session_options, error)) { ort->ReleaseSessionOptions(session_options); - return NULL; + return REDISMODULE_ERR; } ONNX_VALIDATE_STATUS( - ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session)) + ort->CreateSessionFromArray(env, model->data, model->datalen, session_options, &session)) ort->ReleaseSessionOptions(session_options); + model->session = session; + size_t n_input_nodes; - ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes)) size_t n_output_nodes; + + // We save the model's inputs and outputs only in the first time that we create the model. + // We might create the model again when loading from RDB, in this case the inputs and outputs + // are already loaded from RDB. + // if (!model->inputs) { + ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes)) ONNX_VALIDATE_STATUS(ort->SessionGetOutputCount(session, &n_output_nodes)) inputs_ = array_new(char *, n_input_nodes); @@ -393,27 +401,13 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo outputs_ = array_append(outputs_, output_name); } - // Since ONNXRuntime doesn't have a re-serialization function, - // we cache the blob in order to re-serialize it. 
-    // Not optimal for storage purposes, but again, it may be temporary
-    char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
-    memcpy(buffer, modeldef, modellen);
-
-    RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
-    ret->model = NULL;
-    ret->session = session;
-    ret->backend = backend;
-    ret->devicestr = RedisModule_Strdup(devicestr);
-    ret->refCount = 1;
-    ret->opts = opts;
-    ret->data = buffer;
-    ret->datalen = modellen;
-    ret->ninputs = n_input_nodes;
-    ret->noutputs = n_output_nodes;
-    ret->inputs = inputs_;
-    ret->outputs = outputs_;
-
-    return ret;
+    model->ninputs = n_input_nodes;
+    model->noutputs = n_output_nodes;
+    model->inputs = inputs_;
+    model->outputs = outputs_;
+    //}
+
+    return REDISMODULE_OK;
 
 error:
     RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
@@ -438,28 +432,32 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
         ort->ReleaseSession(session);
     }
     ort->ReleaseStatus(status);
-    return NULL;
+    return REDISMODULE_ERR;
 }
 
 void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
     const OrtApi *ort = OrtGetApiBase()->GetApi(1);
     OrtStatus *status = NULL;
-    for (uint32_t i = 0; i < model->ninputs; i++) {
-        ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
+    if (model->inputs) {
+        for (uint32_t i = 0; i < model->ninputs; i++) {
+            ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
+        }
+        array_free(model->inputs);
+        model->inputs = NULL;
     }
-    array_free(model->inputs);
 
-    for (uint32_t i = 0; i < model->noutputs; i++) {
-        ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
+    if (model->outputs) {
+        for (uint32_t i = 0; i < model->noutputs; i++) {
+            ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
+        }
+        array_free(model->outputs);
+        model->outputs = NULL;
     }
-    array_free(model->outputs);
 
-    RedisModule_Free(model->devicestr);
-    RedisModule_Free(model->data);
-    ort->ReleaseSession(model->session);
-    model->model = NULL;
-    model->session = NULL;
+    if (model->session) {
+        ort->ReleaseSession(model->session);
+    }
 
     return;
 
 error:
diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h
index ec282bac3..1411348f2 100644
--- a/src/backends/onnxruntime.h
+++ b/src/backends/onnxruntime.h
@@ -11,8 +11,7 @@ unsigned long long RAI_GetMemoryAccessORT(void);
 
 int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *));
 
-RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
-                              const char *modeldef, size_t modellen, RAI_Error *err);
+int RAI_ModelCreateORT(RAI_Model *model, RAI_Error *err);
 
 void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error);
 
diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c
index 1774aeed2..8356f7219 100644
--- a/src/backends/tensorflow.c
+++ b/src/backends/tensorflow.c
@@ -212,18 +212,15 @@ TF_Tensor *RAI_TFTensorFromTensors(RAI_Tensor **ts, size_t count) {
     return out;
 }
 
-RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
-                             size_t ninputs, const char **inputs, size_t noutputs,
-                             const char **outputs, const char *modeldef, size_t modellen,
-                             RAI_Error *error) {
+int RAI_ModelCreateTF(RAI_Model *model, RAI_Error *error) {
     RAI_Device device;
     int64_t deviceid;
 
-    if (!parseDeviceStr(devicestr, &device, &deviceid)) {
+    if (!parseDeviceStr(model->devicestr, &device, &deviceid)) {
         RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device");
+        return REDISMODULE_ERR;
     }
 
-    TF_Graph *model = TF_NewGraph();
+    TF_Graph *graph =
TF_NewGraph();
     TF_Status *status = TF_NewStatus();
     TF_Buffer *tfbuffer = TF_NewBuffer();
     TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
@@ -232,36 +229,37 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     TF_Status *sessionStatus = NULL;
     TF_Session *session = NULL;
 
-    tfbuffer->length = modellen;
-    tfbuffer->data = modeldef;
+    tfbuffer->length = model->datalen;
+    tfbuffer->data = model->data;
 
-    TF_GraphImportGraphDef(model, tfbuffer, options, status);
+    TF_GraphImportGraphDef(graph, tfbuffer, options, status);
 
     if (TF_GetCode(status) != TF_OK) {
         char *errorMessage = RedisModule_Strdup(TF_Message(status));
         RAI_SetError(error, RAI_EMODELIMPORT, errorMessage);
         RedisModule_Free(errorMessage);
-        return NULL;
+        return REDISMODULE_ERR;
     }
 
-    for (size_t i = 0; i < ninputs; ++i) {
-        TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]);
+    // Validate that the given inputs and outputs exist in the graph.
+    for (size_t i = 0; i < model->ninputs; ++i) {
+        TF_Operation *oper = TF_GraphOperationByName(graph, model->inputs[i]);
         if (oper == NULL || strcmp(TF_OperationOpType(oper), "Placeholder") != 0) {
-            size_t len = strlen(inputs[i]);
+            size_t len = strlen(model->inputs[i]);
             char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
-            sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
+            sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", model->inputs[i]);
             RAI_SetError(error, RAI_EMODELIMPORT, msg);
             RedisModule_Free(msg);
             goto cleanup;
         }
     }
 
-    for (size_t i = 0; i < noutputs; ++i) {
-        TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]);
+    for (size_t i = 0; i < model->noutputs; ++i) {
+        TF_Operation *oper = TF_GraphOperationByName(graph, model->outputs[i]);
         if (oper == NULL) {
-            size_t len = strlen(outputs[i]);
+            size_t len = strlen(model->outputs[i]);
             char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
-            sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]);
+            sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", model->outputs[i]);
             RAI_SetError(error, RAI_EMODELIMPORT, msg);
             RedisModule_Free(msg);
             goto cleanup;
@@ -297,6 +295,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
         goto cleanup;
     }
 
+    RAI_ModelOpts opts = model->opts;
     if (opts.backends_intra_op_parallelism > 0) {
         uint8_t proto[] = {0x10, (uint8_t)opts.backends_intra_op_parallelism};
         TF_SetConfig(sessionOptions, proto, sizeof(proto), optionsStatus);
@@ -338,23 +337,22 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     }
     TF_DeleteStatus(optionsStatus);
     optionsStatus = NULL;
 
     sessionStatus = TF_NewStatus();
-    session = TF_NewSession(model, sessionOptions, sessionStatus);
+    session = TF_NewSession(graph, sessionOptions, sessionStatus);
 
     TF_Status *deviceListStatus = TF_NewStatus();
     TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
     const int num_devices = TF_DeviceListCount(deviceList);
-    int foundNoGPU = 1;
+    bool foundNoGPU = true;
     for (int i = 0; i < num_devices; ++i) {
         const char *device_type = TF_DeviceListType(deviceList, i, deviceListStatus);
-        int cmp = strcmp(device_type, "GPU");
-        if (cmp == 0) {
-            foundNoGPU = 0;
+        if (strcmp(device_type, "GPU") == 0) {
+            foundNoGPU = false;
             break;
         }
     }
-    if (foundNoGPU == 1 && device == RAI_DEVICE_GPU) {
+    if (foundNoGPU && device == RAI_DEVICE_GPU) {
         RAI_SetError(error, RAI_EMODELCREATE, "ERR GPU requested but TF couldn't find CUDA");
TF_DeleteDeviceList(deviceList); TF_DeleteStatus(deviceListStatus); @@ -371,37 +369,12 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod TF_DeleteSessionOptions(sessionOptions); TF_DeleteStatus(sessionStatus); - char **inputs_ = array_new(char *, ninputs); - for (long long i = 0; i < ninputs; i++) { - inputs_ = array_append(inputs_, RedisModule_Strdup(inputs[i])); - } - - char **outputs_ = array_new(char *, noutputs); - for (long long i = 0; i < noutputs; i++) { - outputs_ = array_append(outputs_, RedisModule_Strdup(outputs[i])); - } - - char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); - memcpy(buffer, modeldef, modellen); - - RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret)); - ret->model = model; - ret->session = session; - ret->backend = backend; - ret->devicestr = RedisModule_Strdup(devicestr); - ret->ninputs = ninputs; - ret->inputs = inputs_; - ret->noutputs = noutputs; - ret->outputs = outputs_; - ret->opts = opts; - ret->refCount = 1; - ret->data = buffer; - ret->datalen = modellen; - - return ret; + model->model = graph; + model->session = session; + return REDISMODULE_OK; cleanup: - TF_DeleteGraph(model); + TF_DeleteGraph(graph); if (options) TF_DeleteImportGraphDefOptions(options); if (tfbuffer) @@ -412,51 +385,30 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod TF_DeleteSessionOptions(sessionOptions); if (sessionStatus) TF_DeleteStatus(sessionStatus); - return NULL; + return REDISMODULE_ERR; } void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) { + + // If we got an error before we created the session, there's nothing to free. + if (model->session == NULL) { + return; + } + TF_Status *status = TF_NewStatus(); TF_CloseSession(model->session, status); - if (TF_GetCode(status) != TF_OK) { - RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status))); + RAI_SetError(error, RAI_EMODELFREE, TF_Message(status)); return; } TF_DeleteSession(model->session, status); - model->session = NULL; - if (TF_GetCode(status) != TF_OK) { - RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status))); + RAI_SetError(error, RAI_EMODELFREE, TF_Message(status)); return; } TF_DeleteGraph(model->model); - model->model = NULL; - - RedisModule_Free(model->devicestr); - - if (model->inputs) { - size_t ninputs = array_len(model->inputs); - for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free(model->inputs[i]); - } - array_free(model->inputs); - } - - if (model->outputs) { - size_t noutputs = array_len(model->outputs); - for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free(model->outputs[i]); - } - array_free(model->outputs); - } - - if (model->data) { - RedisModule_Free(model->data); - } - TF_DeleteStatus(status); } diff --git a/src/backends/tensorflow.h b/src/backends/tensorflow.h index 7bf9805a2..1a6bb14a4 100644 --- a/src/backends/tensorflow.h +++ b/src/backends/tensorflow.h @@ -7,10 +7,7 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)); -RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, - size_t ninputs, const char **inputs, size_t noutputs, - const char **outputs, const char *modeldef, size_t modellen, - RAI_Error *error); +int RAI_ModelCreateTF(RAI_Model *model, RAI_Error *error); void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error); diff --git a/src/backends/tflite.c b/src/backends/tflite.c index 0f9e002d4..fb9831715 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -15,16 +15,17 @@ int 
RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) { return REDISMODULE_OK; } -RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, - const char *modeldef, size_t modellen, RAI_Error *error) { +int RAI_ModelCreateTFLite(RAI_Model *model, RAI_Error *error) { DLDeviceType dl_device; RAI_Device device; int64_t deviceid; char **inputs_ = NULL; char **outputs_ = NULL; - if (!parseDeviceStr(devicestr, &device, &deviceid)) { + size_t ninputs; + size_t noutputs; + if (!parseDeviceStr(model->devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device"); - return NULL; + return REDISMODULE_ERR; } switch (device) { @@ -34,26 +35,25 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI case RAI_DEVICE_GPU: dl_device = kDLGPU; break; - default: - RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Error configuring model: unsupported device"); - return NULL; } char *error_descr = NULL; - void *model = tfliteLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr); - - if (model == NULL) { - RAI_SetError(error, RAI_EMODELCREATE, error_descr); - RedisModule_Free(error_descr); - return NULL; + void *tf_model = + tfliteLoadModel(model->data, model->datalen, dl_device, deviceid, &error_descr); + if (tf_model == NULL) { + goto cleanup; } + model->model = tf_model; - size_t ninputs = tfliteModelNumInputs(model, &error_descr); + // We save the model's inputs and outputs only in the first time that we create the model. + // We might create the model again when loading from RDB, in this case the inputs and outputs + // are already loaded from RDB. + // if (!model->inputs) { + ninputs = tfliteModelNumInputs(tf_model, &error_descr); if (error_descr) { goto cleanup; } - - size_t noutputs = tfliteModelNumOutputs(model, &error_descr); + noutputs = tfliteModelNumOutputs(tf_model, &error_descr); if (error_descr) { goto cleanup; } @@ -62,43 +62,33 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI outputs_ = array_new(char *, noutputs); for (size_t i = 0; i < ninputs; i++) { - const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr); + const char *input = tfliteModelInputNameAtIndex(tf_model, i, &error_descr); if (error_descr) { goto cleanup; } inputs_ = array_append(inputs_, RedisModule_Strdup(input)); } - for (size_t i = 0; i < noutputs; i++) { - const char *output = tfliteModelOutputNameAtIndex(model, i, &error_descr); - ; + const char *output = tfliteModelOutputNameAtIndex(tf_model, i, &error_descr); if (error_descr) { goto cleanup; } outputs_ = array_append(outputs_, RedisModule_Strdup(output)); } + model->ninputs = ninputs; + model->noutputs = noutputs; + model->inputs = inputs_; + model->outputs = outputs_; + //} - char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); - memcpy(buffer, modeldef, modellen); - - RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret)); - ret->model = model; - ret->session = NULL; - ret->backend = backend; - ret->devicestr = RedisModule_Strdup(devicestr); - ret->ninputs = ninputs; - ret->inputs = inputs_; - ret->noutputs = noutputs; - ret->outputs = outputs_; - ret->refCount = 1; - ret->opts = opts; - ret->data = buffer; - ret->datalen = modellen; - return ret; + return REDISMODULE_OK; cleanup: RAI_SetError(error, RAI_EMODELCREATE, error_descr); RedisModule_Free(error_descr); + if (tf_model) { + tfliteDeallocContext(tf_model); + } if (inputs_) { ninputs = array_len(inputs_); for (size_t i = 0; i < ninputs; 
i++) { @@ -113,26 +103,13 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI } array_free(outputs_); } - return NULL; + return REDISMODULE_ERR; } void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) { - RedisModule_Free(model->data); - RedisModule_Free(model->devicestr); - tfliteDeallocContext(model->model); - size_t ninputs = model->ninputs; - for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free(model->inputs[i]); - } - array_free(model->inputs); - - size_t noutputs = model->noutputs; - for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free(model->outputs[i]); + if (model->model) { + tfliteDeallocContext(model->model); } - array_free(model->outputs); - - model->model = NULL; } int RAI_ModelRunTFLite(RAI_ModelRunCtx **mctxs, RAI_Error *error) { diff --git a/src/backends/tflite.h b/src/backends/tflite.h index d90148312..64bb95a08 100644 --- a/src/backends/tflite.h +++ b/src/backends/tflite.h @@ -7,8 +7,7 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)); -RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, - const char *modeldef, size_t modellen, RAI_Error *err); +int RAI_ModelCreateTFLite(RAI_Model *model, RAI_Error *err); void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error); diff --git a/src/backends/torch.c b/src/backends/torch.c index a856c56d2..1d54e8632 100644 --- a/src/backends/torch.c +++ b/src/backends/torch.c @@ -29,19 +29,19 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) { return REDISMODULE_OK; } -RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, - const char *modeldef, size_t modellen, RAI_Error *error) { - DLDeviceType dl_device; +int RAI_ModelCreateTorch(RAI_Model *model, RAI_Error *error) { + DLDeviceType dl_device; RAI_Device device = RAI_DEVICE_CPU; int64_t deviceid = 0; - + size_t ninputs; + size_t noutputs; char **inputs_ = NULL; char **outputs_ = NULL; - if (!parseDeviceStr(devicestr, &device, &deviceid)) { + if (!parseDeviceStr(model->devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device"); - return NULL; + return REDISMODULE_ERR; } switch (device) { @@ -51,44 +51,40 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ case RAI_DEVICE_GPU: dl_device = kDLGPU; break; - default: - RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Error configuring model: unsupported device"); - return NULL; } char *error_descr = NULL; - if (opts.backends_inter_op_parallelism > 0) { - torchSetInterOpThreads(opts.backends_inter_op_parallelism, &error_descr, RedisModule_Alloc); - } - - if (error_descr != NULL) { - RAI_SetError(error, RAI_EMODELCREATE, error_descr); - RedisModule_Free(error_descr); - return NULL; + void *torch_model = torchLoadModel(model->data, model->datalen, dl_device, deviceid, + &error_descr, RedisModule_Alloc); + if (error_descr) { + goto cleanup; } + model->model = torch_model; - if (opts.backends_intra_op_parallelism > 0) { - torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc); + if (model->opts.backends_inter_op_parallelism > 0) { + torchSetInterOpThreads((int)model->opts.backends_inter_op_parallelism, &error_descr, + RedisModule_Alloc); } if (error_descr) { - RAI_SetError(error, RAI_EMODELCREATE, error_descr); - RedisModule_Free(error_descr); - return NULL; + goto cleanup; + } + if (model->opts.backends_intra_op_parallelism > 0) { + 
torchSetIntraOpThreads((int)model->opts.backends_intra_op_parallelism, &error_descr, + RedisModule_Alloc); } - - void *model = - torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc); - if (error_descr) { goto cleanup; } - size_t ninputs = torchModelNumInputs(model, &error_descr); + // We save the model's inputs and outputs only in the first time that we create the model. + // We might create the model again when loading from RDB, in this case the inputs and outputs + // are already loaded from RDB. + // if (!model->inputs) { + ninputs = torchModelNumInputs(torch_model, &error_descr); if (error_descr) { goto cleanup; } - - size_t noutputs = torchModelNumOutputs(model, &error_descr); + noutputs = torchModelNumOutputs(torch_model, &error_descr); if (error_descr) { goto cleanup; } @@ -97,13 +93,12 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ outputs_ = array_new(char *, noutputs); for (size_t i = 0; i < ninputs; i++) { - const char *input = torchModelInputNameAtIndex(model, i, &error_descr); + const char *input = torchModelInputNameAtIndex(torch_model, i, &error_descr); if (error_descr) { goto cleanup; } inputs_ = array_append(inputs_, RedisModule_Strdup(input)); } - for (size_t i = 0; i < noutputs; i++) { const char *output = ""; if (error_descr) { @@ -112,26 +107,19 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ outputs_ = array_append(outputs_, RedisModule_Strdup(output)); } - char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); - memcpy(buffer, modeldef, modellen); + model->ninputs = ninputs; + model->noutputs = noutputs; + model->inputs = inputs_; + model->outputs = outputs_; + //} - RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret)); - ret->model = model; - ret->session = NULL; - ret->backend = backend; - ret->devicestr = RedisModule_Strdup(devicestr); - ret->ninputs = ninputs; - ret->inputs = inputs_; - ret->noutputs = noutputs; - ret->outputs = outputs_; - ret->opts = opts; - ret->refCount = 1; - ret->data = buffer; - ret->datalen = modellen; - return ret; + return REDISMODULE_OK; cleanup: RAI_SetError(error, RAI_EMODELCREATE, error_descr); + if (torch_model) { + torchDeallocContext(torch_model); + } RedisModule_Free(error_descr); if (inputs_) { ninputs = array_len(inputs_); @@ -147,29 +135,13 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ } array_free(outputs_); } - return NULL; + return REDISMODULE_ERR; } void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) { - if (model->devicestr) { - RedisModule_Free(model->devicestr); + if (model->model) { + torchDeallocContext(model->model); } - if (model->data) { - RedisModule_Free(model->data); - } - size_t ninputs = model->ninputs; - for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free(model->inputs[i]); - } - array_free(model->inputs); - - size_t noutputs = model->noutputs; - for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free(model->outputs[i]); - } - array_free(model->outputs); - - torchDeallocContext(model->model); } int RAI_ModelRunTorch(RAI_ModelRunCtx **mctxs, RAI_Error *error) { diff --git a/src/backends/torch.h b/src/backends/torch.h index cde36e588..4f9c2ab72 100644 --- a/src/backends/torch.h +++ b/src/backends/torch.h @@ -8,8 +8,7 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)); -RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, - const char *modeldef, size_t modellen, RAI_Error *err); +int 
RAI_ModelCreateTorch(RAI_Model *model, RAI_Error *err);
 
 void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error);
 
diff --git a/src/backends/util.c b/src/backends/util.c
index e12731e38..ca474f27d 100644
--- a/src/backends/util.c
+++ b/src/backends/util.c
@@ -1,7 +1,8 @@
 #include "backends/util.h"
+#include "string.h"
 
 int parseDeviceStr(const char *devicestr, RAI_Device *device, int64_t *deviceid) {
-    if (strcasecmp(devicestr, "CPU") == 0) {
+    if (strncasecmp(devicestr, "CPU", 3) == 0) {
         *device = RAI_DEVICE_CPU;
         *deviceid = -1;
diff --git a/src/execution/background_modelset.c b/src/execution/background_modelset.c
new file mode 100644
index 000000000..9a16691d7
--- /dev/null
+++ b/src/execution/background_modelset.c
@@ -0,0 +1,154 @@
+#include "background_modelset.h"
+#include "command_parser.h"
+#include "backends/backends.h"
+#include <pthread.h>
+#include "backends/util.h"
+
+#define BG_MODELSET_THREADS_NUM 1
+
+// Single definition of the queue declared extern in background_modelset.h.
+RunQueueInfo *modelSet_QueueInfo;
+
+int Init_BG_ModelSet(void) {
+
+    modelSet_QueueInfo = RedisModule_Alloc(sizeof(RunQueueInfo));
+    modelSet_QueueInfo->run_queue = queueCreate();
+    modelSet_QueueInfo->devicestr = "";
+    pthread_cond_init(&(modelSet_QueueInfo)->queue_condition_var, NULL);
+    pthread_mutex_init(&(modelSet_QueueInfo)->run_queue_mutex, NULL);
+    modelSet_QueueInfo->threads =
+        (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * BG_MODELSET_THREADS_NUM);
+
+    /* create thread(s) */
+    for (int i = 0; i < BG_MODELSET_THREADS_NUM; i++) {
+        if (pthread_create(&((modelSet_QueueInfo)->threads[i]), NULL, RedisAI_ModelSet_ThreadMain,
+                           modelSet_QueueInfo) != 0) {
+            freeRunQueueInfo(modelSet_QueueInfo);
+            return REDISMODULE_ERR;
+        }
+    }
+    return REDISMODULE_OK;
+}
+
+void ModelSet_Execute(void *args) {
+    ModelSetCtx *model_ctx = (ModelSetCtx *)args;
+
+    RedisModuleString **argv = model_ctx->args;
+    model_ctx->model = RedisModule_Calloc(1, sizeof(*(model_ctx->model)));
+    RAI_InitError(&model_ctx->err);
+    RAI_Error *err = model_ctx->err;
+    RAI_Model *model = model_ctx->model;
+    model->refCount = 1;
+
+    // If we fail, we unblock and the model_ctx internals will be freed.
+    int status = ParseModelSetCommand(argv, array_len(argv), model, err);
+    if (status != REDISMODULE_OK) {
+        RedisModule_UnblockClient(model_ctx->client, model_ctx);
+        return;
+    }
+
+    const char *backend_str = RAI_BackendName(model->backend);
+
+    if (ModelCreateBE(model, err) != REDISMODULE_OK) {
+        // If we got an error *not* because of lazy loading, we fail and unblock.
+        if (RAI_GetErrorCode(err) != RAI_EBACKENDNOTLOADED) {
+            RedisModule_UnblockClient(model_ctx->client, model_ctx);
+            return;
+        }
+        RedisModule_Log(NULL, "warning", "backend %s not loaded, will try loading default backend",
+                        backend_str);
+        int ret = RAI_LoadDefaultBackend(NULL, model->backend);
+        if (ret != REDISMODULE_OK) {
+            RedisModule_Log(NULL, "error", "could not load %s default backend", backend_str);
+            RedisModule_UnblockClient(model_ctx->client, model_ctx);
+            return;
+        }
+        // Try creating the model for the backend again.
+        RAI_ClearError(err);
+        ModelCreateBE(model, err);
+    }
+    RedisModule_UnblockClient(model_ctx->client, model_ctx);
+}
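The callbacks in this file cooperate through the standard Redis module blocked-client pattern. A sketch of the intended hand-off, assuming the usual RedisModule_BlockClient contract (the enqueue itself appears at the end of this patch, in redisai.c):

/*
 * Main thread, AI.MODELSET command handler:
 *     client = RedisModule_BlockClient(ctx, RedisAI_ModelSet_Reply,
 *                                      NULL,               // no timeout callback
 *                                      ModelSet_FreeData, 0);
 *     ... push a Job wrapping ModelSet_Execute to modelSet_QueueInfo->run_queue ...
 *
 * Background thread, RedisAI_ModelSet_ThreadMain:
 *     job->Execute(job->args);               // parse args and create the model,
 *                                            // recording any failure in model_ctx->err
 *     RedisModule_UnblockClient(client, model_ctx);
 *
 * Main thread again:
 *     RedisAI_ModelSet_Reply(...);           // reply OK or error, store model in keyspace
 *     ModelSet_FreeData(...);                // always runs last, releases model_ctx
 */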
+int RedisAI_ModelSet_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
+    REDISMODULE_NOT_USED(argv);
+    REDISMODULE_NOT_USED(argc);
+
+    ModelSetCtx *model_ctx = RedisModule_GetBlockedClientPrivateData(ctx);
+
+    // If at some point we got an error, we return it (model_ctx is freed).
+    if (RAI_GetErrorCode(model_ctx->err) != RAI_OK) {
+        return RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(model_ctx->err));
+    }
+
+    // Save the model in the keyspace.
+    RAI_Model *model = model_ctx->model;
+    RedisModuleString *key_str = model->infokey;
+    RedisModuleKey *key = RedisModule_OpenKey(ctx, key_str, REDISMODULE_READ | REDISMODULE_WRITE);
+    int type = RedisModule_KeyType(key);
+
+    // Two valid scenarios: 1. We create a new key, 2. The key is already holding
+    // a RedisAI model type (in this case we update the key's value).
+    if (type != REDISMODULE_KEYTYPE_EMPTY &&
+        !(type == REDISMODULE_KEYTYPE_MODULE &&
+          RedisModule_ModuleTypeGetType(key) == RedisAI_ModelType)) {
+        RedisModule_CloseKey(key);
+        return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
+    }
+    RedisModule_ModuleTypeSetValue(key, RedisAI_ModelType, model);
+    RedisModule_CloseKey(key);
+
+    // Save this model in the stats global dict.
+    RAI_AddStatsEntry(NULL, key_str, RAI_MODEL, model->backend, model->devicestr, model->tag);
+
+    // Take a shallow copy so the free callback won't delete the model.
+    RAI_ModelGetShallowCopy(model_ctx->model);
+
+    RedisModule_ReplyWithSimpleString(ctx, "OK");
+    RedisModule_ReplicateVerbatim(ctx);
+    return REDISMODULE_OK;
+}
+
+void ModelSet_FreeData(RedisModuleCtx *ctx, void *private_data) {
+    ModelSetCtx *model_ctx = (ModelSetCtx *)private_data;
+
+    RAI_FreeError(model_ctx->err);
+
+    // This is a "dummy" error: we do not need it here, since we only decrease
+    // the model's ref_count on success, and on failure a different error has
+    // already been returned.
+    RAI_Error err = {0};
+    if (model_ctx->model) {
+        RAI_ModelFree(model_ctx->model, &err);
+    }
+
+    for (size_t i = 0; i < array_len(model_ctx->args); i++) {
+        RedisModule_FreeString(NULL, model_ctx->args[i]);
+    }
+    array_free(model_ctx->args);
+
+    RedisModule_Free(model_ctx);
+}
+
+void *RedisAI_ModelSet_ThreadMain(void *arg) {
+    RunQueueInfo *run_queue_info = (RunQueueInfo *)arg;
+    RAI_PTHREAD_SETNAME("redisai_modelset_bthread");
+    pthread_mutex_lock(&run_queue_info->run_queue_mutex);
+
+    while (true) {
+        // Pop first, and wait only while the queue is empty: this guards against
+        // both spurious wakeups and jobs pushed before this thread started.
+        queueItem *item;
+        while ((item = queuePop(run_queue_info->run_queue)) == NULL) {
+            pthread_cond_wait(&run_queue_info->queue_condition_var,
+                              &run_queue_info->run_queue_mutex);
+        }
+        pthread_mutex_unlock(&run_queue_info->run_queue_mutex);
+
+        // Currently the job's callback is always MODELSET.
+        Job *job = queueItemGetValue(item);
+        RedisModule_Free(item);
+        job->Execute(job->args);
+        JobFree(job);
+        pthread_mutex_lock(&run_queue_info->run_queue_mutex);
+    }
+}
+
+Job *JobCreate(void *args, void (*CallBack)(void *)) {
+    Job *job = RedisModule_Alloc(sizeof(*job));
+    job->args = args;
+    job->Execute = CallBack;
+    return job;
+}
+
+void JobFree(Job *job) { RedisModule_Free(job); }
\ No newline at end of file
diff --git a/src/execution/background_modelset.h b/src/execution/background_modelset.h
new file mode 100644
index 000000000..791ad7559
--- /dev/null
+++ b/src/execution/background_modelset.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "background_workers.h"
+
+// Defined in background_modelset.c.
+extern RunQueueInfo *modelSet_QueueInfo;
+
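The queue and its worker thread are expected to be created once at module load. A minimal sketch of a plausible call site, assuming it is wired into RedisAI_OnLoad (the actual hook is outside this excerpt):

    // Hypothetical call site during module initialization (ctx is the module context).
    if (Init_BG_ModelSet() != REDISMODULE_OK) {
        RedisModule_Log(ctx, "warning", "RedisAI: failed to start the background MODELSET worker");
        return REDISMODULE_ERR;
    }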
+// We use this generic struct to enable future extensions - these BG workers may
+// serve more purposes (monitoring, statistics, etc.).
+typedef struct Job {
+    void (*Execute)(void *);
+    void *args;
+} Job;
+
+typedef struct ModelSetCtx {
+    RedisModuleString **args;
+    RedisModuleBlockedClient *client;
+    RAI_Model *model;
+    RAI_Error *err;
+} ModelSetCtx;
+
+Job *JobCreate(void *args, void (*CallBack)(void *));
+
+void JobFree(Job *job);
+
+int Init_BG_ModelSet(void);
+
+void ModelSet_FreeData(RedisModuleCtx *ctx, void *private_data);
+
+int RedisAI_ModelSet_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
+
+void *RedisAI_ModelSet_ThreadMain(void *arg);
+
+void ModelSet_Execute(void *args);
diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c
index b1b394c72..58819debf 100644
--- a/src/execution/background_workers.c
+++ b/src/execution/background_workers.c
@@ -20,24 +20,6 @@
 #include "run_info.h"
 #include "background_workers.h"
 
-/* Define for RedisAI thread name setter */
-#ifdef __linux__
-#define RAI_PTHREAD_SETNAME(name) pthread_setname_np(pthread_self(), name)
-#else
-#if (defined __NetBSD__ || defined __FreeBSD__ || defined __OpenBSD__)
-#include <pthread_np.h>
-#define RAI_PTHREAD_SETNAME(name) pthread_set_name_np(pthread_self(), name)
-#else
-#if (defined __APPLE__ && defined(MAC_OS_X_VERSION_10_7))
-int pthread_setname_np(const char *name);
-#include <pthread.h>
-#define RAI_PTHREAD_SETNAME(name) pthread_setname_np(name)
-#else
-#define RAI_PTHREAD_SETNAME(name)
-#endif
-#endif
-#endif
-
 int freeRunQueueInfo(RunQueueInfo *info) {
     int result = REDISMODULE_OK;
     if (info->run_queue) {
diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h
index 43d7099b0..36e45c16f 100644
--- a/src/execution/background_workers.h
+++ b/src/execution/background_workers.h
@@ -15,8 +15,6 @@
 #define _GNU_SOURCE
 #endif
 
-#include <pthread.h>
-
 #include "config/config.h"
 #include "DAG/dag.h"
 #include "redis_ai_objects/model.h"
@@ -30,6 +28,24 @@
 #include "util/dict.h"
 #include "util/queue.h"
 
+/* Define for RedisAI thread name setter */
+#ifdef __linux__
+#define RAI_PTHREAD_SETNAME(name) pthread_setname_np(pthread_self(), name)
+#else
+#if (defined __NetBSD__ || defined __FreeBSD__ || defined __OpenBSD__)
+#include <pthread_np.h>
+#define RAI_PTHREAD_SETNAME(name) pthread_set_name_np(pthread_self(), name)
+#else
+#if (defined __APPLE__ && defined(MAC_OS_X_VERSION_10_7))
+int pthread_setname_np(const char *name);
+#include <pthread.h>
+#define RAI_PTHREAD_SETNAME(name) pthread_setname_np(name)
+#else
+#define RAI_PTHREAD_SETNAME(name)
+#endif
+#endif
+#endif
+
 AI_dict *run_queues;
 long long perqueueThreadPoolSize;
 
diff --git a/src/execution/command_parser.c b/src/execution/command_parser.c
index cc85adb0e..f11f4b93a 100644
--- a/src/execution/command_parser.c
+++ b/src/execution/command_parser.c
@@ -1,4 +1,3 @@
-
 #include "redismodule.h"
 #include "run_info.h"
 #include "command_parser.h"
@@ -6,6 +5,8 @@
 #include "DAG/dag_parser.h"
 #include "util/string_utils.h"
 #include "execution/modelRun_ctx.h"
+#include "rmutil/args.h"
+#include "redis_ai_objects/stats.h"
 
 static int _parseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) {
 
@@ -341,6 +342,227 @@ int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisMod
     return res;
 }
 
+int _ModelSetCommand_ParseBatchingArgs(ArgsCursor *ac, RAI_ModelOpts *opts, int backend,
+                                       RAI_Error *err) {
+    unsigned long long batchsize = 0;
+    if (AC_AdvanceIfMatch(ac, "BATCHSIZE")) {
+        if (backend == RAI_BACKEND_TFLITE) {
+            RAI_SetError(err, RAI_EMODELCREATE,
+                         "ERR Auto-batching not supported by the TFLITE backend");
+            return REDISMODULE_ERR;
+        }
+        if
(AC_GetUnsignedLongLong(ac, &batchsize, 0) != AC_OK) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR Invalid argument for BATCHSIZE"); + return REDISMODULE_ERR; + } + } + + unsigned long long minbatchsize = 0; + if (AC_AdvanceIfMatch(ac, "MINBATCHSIZE")) { + if (batchsize == 0) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR MINBATCHSIZE specified without BATCHSIZE"); + return REDISMODULE_ERR; + } + if (AC_GetUnsignedLongLong(ac, &minbatchsize, 0) != AC_OK) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR Invalid argument for MINBATCHSIZE"); + return REDISMODULE_ERR; + } + } + + unsigned long long minbatchtimeout = 0; + if (AC_AdvanceIfMatch(ac, "MINBATCHTIMEOUT")) { + if (batchsize == 0) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR MINBATCHTIMEOUT specified without BATCHSIZE"); + return REDISMODULE_ERR; + } + if (minbatchsize == 0) { + RAI_SetError(err, RAI_EMODELCREATE, + "ERR MINBATCHTIMEOUT specified without MINBATCHSIZE"); + return REDISMODULE_ERR; + } + if (AC_GetUnsignedLongLong(ac, &minbatchtimeout, 0) != AC_OK) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR Invalid argument for MINBATCHTIMEOUT"); + return REDISMODULE_ERR; + } + } + + opts->batchsize = batchsize; + opts->minbatchsize = minbatchsize; + opts->minbatchtimeout = minbatchtimeout; + + return REDISMODULE_OK; +} + +int _ModelSetCommand_ParseIOArgs(RAI_Model *model, ArgsCursor *ac, int backend, RAI_Error *err) { + + ArgsCursor optionsac; + const char *blob_matches[1] = {"BLOB"}; + AC_GetSliceUntilMatches(ac, &optionsac, 1, blob_matches); + + if (optionsac.argc == 0) { + RAI_SetError(err, RAI_EMODELCREATE, + "ERR Insufficient arguments, INPUTS and OUTPUTS not specified for TF model"); + return REDISMODULE_ERR; + } + ArgsCursor inac = {0}; + ArgsCursor outac = {0}; + if (optionsac.argc > 0 && backend == RAI_BACKEND_TENSORFLOW) { + if (!AC_AdvanceIfMatch(&optionsac, "INPUTS")) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR INPUTS not specified for TF model"); + return REDISMODULE_ERR; + } + + const char *matches[1] = {"OUTPUTS"}; + AC_GetSliceUntilMatches(&optionsac, &inac, 1, matches); + if (!AC_IsAtEnd(&optionsac)) { + if (!AC_AdvanceIfMatch(&optionsac, "OUTPUTS")) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR OUTPUTS not specified for TF model"); + return REDISMODULE_ERR; + } + AC_GetSliceToEnd(&optionsac, &outac); + } + } + + model->ninputs = inac.argc; + model->inputs = array_new(char *, model->ninputs); + for (size_t i = 0; i < model->ninputs; i++) { + const char *input_str; + AC_GetString(&inac, &input_str, NULL, 0); + model->inputs = array_append(model->inputs, RedisModule_Strdup(input_str)); + } + model->noutputs = outac.argc; + model->outputs = array_new(char *, model->noutputs); + for (size_t i = 0; i < model->noutputs; i++) { + const char *output_str; + AC_GetString(&outac, &output_str, NULL, 0); + model->outputs = array_append(model->outputs, RedisModule_Strdup(output_str)); + } + return REDISMODULE_OK; +} + +void _ModelSetCommand_ParseBlob(RAI_Model *model, ArgsCursor *ac) { + ArgsCursor blobsac; + AC_GetSliceToEnd(ac, &blobsac); + size_t model_len; + char *model_def; + char *model_data; + + if (blobsac.argc == 1) { + AC_GetString(&blobsac, (const char **)&model_def, &model_len, 0); + model_data = RedisModule_Alloc(model_len); + memcpy(model_data, model_def, model_len); + } else { + // Blobs of large models are chunked, in this case we go over and copy the chunks. 
+        const char *chunks[blobsac.argc];
+        size_t chunk_lens[blobsac.argc];
+        model_len = 0;
+        while (!AC_IsAtEnd(&blobsac)) {
+            AC_GetString(&blobsac, &chunks[blobsac.offset], &chunk_lens[blobsac.offset], 0);
+            model_len += chunk_lens[blobsac.offset - 1];
+        }
+        model_data = RedisModule_Alloc(model_len);
+        size_t offset = 0;
+        for (size_t i = 0; i < blobsac.argc; i++) {
+            memcpy(model_data + offset, chunks[i], chunk_lens[i]);
+            offset += chunk_lens[i];
+        }
+    }
+    model->data = model_data;
+    model->datalen = model_len;
+}
+
+int ParseModelSetCommand(RedisModuleString **argv, int argc, RAI_Model *model, RAI_Error *err) {
+
+    // Use an args cursor object to go over and parse the command args.
+    ArgsCursor ac;
+    ArgsCursor_InitRString(&ac, argv, argc);
+
+    // Parse the model key.
+    RedisModuleString *key_str;
+    AC_GetRString(&ac, &key_str, 0);
+    model->infokey = RAI_HoldString(NULL, key_str);
+
+    // Parse the <backend> argument.
+    const char *backend_str;
+    int backend;
+    AC_GetString(&ac, &backend_str, NULL, 0);
+    if (strcasecmp(backend_str, "TF") == 0) {
+        backend = RAI_BACKEND_TENSORFLOW;
+    } else if (strcasecmp(backend_str, "TFLITE") == 0) {
+        backend = RAI_BACKEND_TFLITE;
+    } else if (strcasecmp(backend_str, "TORCH") == 0) {
+        backend = RAI_BACKEND_TORCH;
+    } else if (strcasecmp(backend_str, "ONNX") == 0) {
+        backend = RAI_BACKEND_ONNXRUNTIME;
+    } else {
+        RAI_SetError(err, RAI_EMODELCREATE, "ERR unsupported backend");
+        return REDISMODULE_ERR;
+    }
+    model->backend = backend;
+
+    // Parse the <device> argument: check that the device string is "CPU", "GPU" or
+    // "GPU:<n>" where <n> is a number (contains digits only).
+    const char *device_str;
+    AC_GetString(&ac, &device_str, NULL, 0);
+    bool valid_device = false;
+    if (strcasecmp(device_str, "CPU") == 0 || strcasecmp(device_str, "GPU") == 0) {
+        valid_device = true;
+    } else if (strncasecmp(device_str, "GPU:", 4) == 0 && strlen(device_str) <= 10) {
+        bool digits_only = true;
+        for (size_t i = 4; i < strlen(device_str); i++) {
+            if (device_str[i] < '0' || device_str[i] > '9') {
+                digits_only = false;
+                break;
+            }
+        }
+        valid_device = digits_only;
+    }
+    if (!valid_device) {
+        RAI_SetError(err, RAI_EMODELCREATE, "ERR Invalid DEVICE");
+        return REDISMODULE_ERR;
+    }
+    model->devicestr = RedisModule_Strdup(device_str);
+
+    // Parse the optional <tag> argument.
+    RedisModuleString *tag = NULL;
+    if (AC_AdvanceIfMatch(&ac, "TAG")) {
+        AC_GetRString(&ac, &tag, 0);
+    }
+    if (tag) {
+        model->tag = RAI_HoldString(NULL, tag);
+    } else {
+        model->tag = RedisModule_CreateString(NULL, "", 0);
+    }
+
+    // Parse the optional args of BATCHSIZE, MINBATCHSIZE and MINBATCHTIMEOUT, and set model opts.
+    RAI_ModelOpts opts;
+    if (_ModelSetCommand_ParseBatchingArgs(&ac, &opts, backend, err) != REDISMODULE_OK) {
+        return REDISMODULE_ERR;
+    }
+    opts.backends_intra_op_parallelism = getBackendsIntraOpParallelism();
+    opts.backends_inter_op_parallelism = getBackendsInterOpParallelism();
+    model->opts = opts;
+
+    // Parse input and output names (these arguments are mandatory only for TF models)
+    // and store them in the model object.
+    if (backend == RAI_BACKEND_TENSORFLOW) {
+        if (_ModelSetCommand_ParseIOArgs(model, &ac, backend, err) != REDISMODULE_OK) {
+            return REDISMODULE_ERR;
+        }
+    }
+
+    // Parse the model blob (the final argument), and store it in the model.
+ const char *blob_matches[1] = {"BLOB"}; + AC_AdvanceUntilMatches(&ac, 1, blob_matches); + if (AC_IsAtEnd(&ac)) { + RAI_SetError(err, RAI_EMODELCREATE, "ERR Insufficient arguments, missing model BLOB"); + return REDISMODULE_ERR; + } + AC_Advance(&ac); + _ModelSetCommand_ParseBlob(model, &ac); + return REDISMODULE_OK; +} + int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, RunCommand command, bool ro_dag) { diff --git a/src/execution/command_parser.h b/src/execution/command_parser.h index 74fd6c792..0149cc252 100644 --- a/src/execution/command_parser.h +++ b/src/execution/command_parser.h @@ -25,6 +25,8 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModu int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, int argc); +int ParseModelSetCommand(RedisModuleString **argv, int argc, RAI_Model *model, RAI_Error *err); + /** * @brief Parse and execute RedisAI run command. After parsing and validation, the resulted * runInfo (DAG) is queued and the client is blocked until the execution is complete (async diff --git a/src/redis_ai_objects/model.c b/src/redis_ai_objects/model.c index a47edc00e..b96cecf90 100644 --- a/src/redis_ai_objects/model.c +++ b/src/redis_ai_objects/model.c @@ -43,49 +43,96 @@ int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RA return REDISMODULE_OK; } -RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag, - RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, - const char **outputs, const char *modeldef, size_t modellen, - RAI_Error *err) { - RAI_Model *model; +int ModelCreateBE(RAI_Model *model, RAI_Error *err) { + + int backend = model->backend; + if (backend == RAI_BACKEND_TENSORFLOW) { if (!RAI_backends.tf.model_create_with_nodes) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TF"); - return NULL; + return REDISMODULE_ERR; } - model = RAI_backends.tf.model_create_with_nodes(backend, devicestr, opts, ninputs, inputs, - noutputs, outputs, modeldef, modellen, err); + return RAI_backends.tf.model_create_with_nodes(model, err); + } else if (backend == RAI_BACKEND_TFLITE) { if (!RAI_backends.tflite.model_create) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TFLITE"); - return NULL; + return REDISMODULE_ERR; } - model = RAI_backends.tflite.model_create(backend, devicestr, opts, modeldef, modellen, err); + return RAI_backends.tflite.model_create(model, err); + } else if (backend == RAI_BACKEND_TORCH) { if (!RAI_backends.torch.model_create) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH"); - return NULL; + return REDISMODULE_ERR; } - model = RAI_backends.torch.model_create(backend, devicestr, opts, modeldef, modellen, err); + return RAI_backends.torch.model_create(model, err); + } else if (backend == RAI_BACKEND_ONNXRUNTIME) { if (!RAI_backends.onnx.model_create) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: ONNX"); - return NULL; + return REDISMODULE_ERR; } - model = RAI_backends.onnx.model_create(backend, devicestr, opts, modeldef, modellen, err); + return RAI_backends.onnx.model_create(model, err); } else { RAI_SetError(err, RAI_EUNSUPPORTEDBACKEND, "ERR Unsupported backend"); - return NULL; + return REDISMODULE_ERR; + } +} + +RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag, + RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t 
noutputs,
+                           const char **outputs, const char *modeldef, size_t modellen,
+                           RAI_Error *err) {
+
+    RAI_Model *model = RedisModule_Calloc(1, sizeof(*model));
+    model->backend = backend;
+    model->devicestr = RedisModule_Strdup(devicestr);
+    if (tag) {
+        model->tag = RAI_HoldString(NULL, tag);
+    } else {
+        model->tag = RedisModule_CreateString(NULL, "", 0);
     }
+    model->opts = opts;
+    model->datalen = modellen;
+    model->data = RedisModule_Alloc(modellen);
+    memcpy(model->data, modeldef, modellen);
 
-    if (model) {
-        if (tag) {
-            model->tag = RAI_HoldString(NULL, tag);
-        } else {
-            model->tag = RedisModule_CreateString(NULL, "", 0);
+    if (backend == RAI_BACKEND_TENSORFLOW) {
+        model->ninputs = ninputs;
+        model->noutputs = noutputs;
+        model->inputs = array_new(char *, ninputs);
+        model->outputs = array_new(char *, noutputs);
+        for (size_t i = 0; i < ninputs; i++) {
+            model->inputs = array_append(model->inputs, RedisModule_Strdup(inputs[i]));
+        }
+        for (size_t i = 0; i < noutputs; i++) {
+            model->outputs = array_append(model->outputs, RedisModule_Strdup(outputs[i]));
         }
     }
 
+    const char *backend_str = RAI_BackendName(model->backend);
+    if (ModelCreateBE(model, err) != REDISMODULE_OK) {
+        // If we got an error *not* because of lazy loading, we fail and return NULL.
+        if (RAI_GetErrorCode(err) != RAI_EBACKENDNOTLOADED) {
+            RAI_ModelFree(model, err);
+            return NULL;
+        }
+        RedisModule_Log(NULL, "warning", "backend %s not loaded, will try loading default backend",
+                        backend_str);
+        int ret = RAI_LoadDefaultBackend(NULL, model->backend);
+        if (ret != REDISMODULE_OK) {
+            RedisModule_Log(NULL, "error", "could not load %s default backend", backend_str);
+            RAI_ModelFree(model, err);
+            return NULL;
+        }
+        // Try creating the model for the backend again.
+        RAI_ClearError(err);
+        if (ModelCreateBE(model, err) != REDISMODULE_OK) {
+            RAI_ModelFree(model, err);
+            return NULL;
+        }
+    }
     return model;
 }
 
@@ -94,38 +141,40 @@ void RAI_ModelFree(RAI_Model *model, RAI_Error *err) {
         return;
     }
 
-    if (model->backend == RAI_BACKEND_TENSORFLOW) {
-        if (!RAI_backends.tf.model_free) {
-            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TF");
-            return;
-        }
+    RedisModule_Free(model->devicestr);
+    if (model->tag) {
+        RedisModule_FreeString(NULL, model->tag);
+    }
+    if (model->data) {
+        RedisModule_Free(model->data);
+    }
+
+    if (model->backend == RAI_BACKEND_TENSORFLOW && RAI_backends.tf.model_free) {
         RAI_backends.tf.model_free(model, err);
-    } else if (model->backend == RAI_BACKEND_TFLITE) {
-        if (!RAI_backends.tflite.model_free) {
-            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TFLITE");
-            return;
-        }
+    } else if (model->backend == RAI_BACKEND_TFLITE && RAI_backends.tflite.model_free) {
         RAI_backends.tflite.model_free(model, err);
-    } else if (model->backend == RAI_BACKEND_TORCH) {
-        if (!RAI_backends.torch.model_free) {
-            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
-            return;
-        }
+    } else if (model->backend == RAI_BACKEND_TORCH && RAI_backends.torch.model_free) {
         RAI_backends.torch.model_free(model, err);
-    } else if (model->backend == RAI_BACKEND_ONNXRUNTIME) {
-        if (!RAI_backends.onnx.model_free) {
-            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: ONNX");
-            return;
-        }
+    } else if (model->backend == RAI_BACKEND_ONNXRUNTIME && RAI_backends.onnx.model_free) {
         RAI_backends.onnx.model_free(model, err);
-    } else {
-        RAI_SetError(err, RAI_EUNSUPPORTEDBACKEND, "Unsupported backend");
-        return;
     }
 
-    RedisModule_FreeString(NULL, model->tag);
-
RAI_RemoveStatsEntry(model->infokey); + if (model->inputs) { + for (size_t i = 0; i < model->ninputs; i++) { + RedisModule_Free(model->inputs[i]); + } + array_free(model->inputs); + } + if (model->outputs) { + for (size_t i = 0; i < model->noutputs; i++) { + RedisModule_Free(model->outputs[i]); + } + array_free(model->outputs); + } + if (model->infokey) { + RedisModule_FreeString(NULL, model->infokey); + } RedisModule_Free(model); } diff --git a/src/redis_ai_objects/model.h b/src/redis_ai_objects/model.h index a4b2e573c..c2a928466 100644 --- a/src/redis_ai_objects/model.h +++ b/src/redis_ai_objects/model.h @@ -150,6 +150,8 @@ size_t ModelGetNumInputs(RAI_Model *model); * @brief Returns the number of outputs in the model definition. */ size_t ModelGetNumOutputs(RAI_Model *model); + +int ModelCreateBE(RAI_Model *model, RAI_Error *err); /** * Insert the ModelRunCtx to the run queues so it will run asynchronously. * diff --git a/src/redis_ai_objects/stats.c b/src/redis_ai_objects/stats.c index 6e1c7c9d2..0e597ec79 100644 --- a/src/redis_ai_objects/stats.c +++ b/src/redis_ai_objects/stats.c @@ -25,8 +25,8 @@ mstime_t mstime(void) { return ustime() / 1000; } void *RAI_AddStatsEntry(RedisModuleCtx *ctx, RedisModuleString *key, RAI_RunType runtype, RAI_Backend backend, const char *devicestr, RedisModuleString *tag) { - struct RedisAI_RunStats *rstats = NULL; - rstats = RedisModule_Calloc(1, sizeof(struct RedisAI_RunStats)); + + RedisAI_RunStats *rstats = RedisModule_Calloc(1, sizeof(*rstats)); rstats->key = RAI_HoldString(NULL, key); rstats->type = runtype; rstats->backend = backend; diff --git a/src/redis_ai_objects/stats.h b/src/redis_ai_objects/stats.h index 27edea14e..5b046a40f 100644 --- a/src/redis_ai_objects/stats.h +++ b/src/redis_ai_objects/stats.h @@ -13,7 +13,7 @@ #include "redismodule.h" #include "util/dict.h" -struct RedisAI_RunStats { +typedef struct RedisAI_RunStats { RedisModuleString *key; RAI_RunType type; RAI_Backend backend; @@ -23,7 +23,7 @@ struct RedisAI_RunStats { long long samples; long long calls; long long nerrors; -}; +} RedisAI_RunStats; AI_dict *run_stats; diff --git a/src/redisai.c b/src/redisai.c index e244073f5..17c08b922 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -8,8 +8,8 @@ #include "redis_ai_objects/tensor.h" #include "execution/command_parser.h" #include "backends/backends.h" -#include "backends/util.h" #include "execution/background_workers.h" +#include "execution/background_modelset.h" #include "execution/DAG/dag.h" #include "execution/DAG/dag_builder.h" #include "execution/DAG/dag_execute.h" @@ -166,236 +166,23 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, if (argc < 4) return RedisModule_WrongArity(ctx); - ArgsCursor ac; - ArgsCursor_InitRString(&ac, argv + 1, argc - 1); - - RedisModuleString *keystr; - AC_GetRString(&ac, &keystr, 0); - - const char *bckstr; - int backend; - AC_GetString(&ac, &bckstr, NULL, 0); - if (strcasecmp(bckstr, "TF") == 0) { - backend = RAI_BACKEND_TENSORFLOW; - } else if (strcasecmp(bckstr, "TFLITE") == 0) { - backend = RAI_BACKEND_TFLITE; - } else if (strcasecmp(bckstr, "TORCH") == 0) { - backend = RAI_BACKEND_TORCH; - } else if (strcasecmp(bckstr, "ONNX") == 0) { - backend = RAI_BACKEND_ONNXRUNTIME; - } else { - return RedisModule_ReplyWithError(ctx, "ERR unsupported backend"); - } - - const char *devicestr; - AC_GetString(&ac, &devicestr, NULL, 0); - - if (strlen(devicestr) > 10 || strcasecmp(devicestr, "INPUTS") == 0 || - strcasecmp(devicestr, "OUTPUTS") == 0 || 
strcasecmp(devicestr, "TAG") == 0 || - strcasecmp(devicestr, "BATCHSIZE") == 0 || strcasecmp(devicestr, "MINBATCHSIZE") == 0 || - strcasecmp(devicestr, "MINBATCHTIMEOUT") == 0 || strcasecmp(devicestr, "BLOB") == 0) { - return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE"); - } - - RedisModuleString *tag = NULL; - if (AC_AdvanceIfMatch(&ac, "TAG")) { - AC_GetRString(&ac, &tag, 0); - } - - unsigned long long batchsize = 0; - if (AC_AdvanceIfMatch(&ac, "BATCHSIZE")) { - if (backend == RAI_BACKEND_TFLITE) { - return RedisModule_ReplyWithError( - ctx, "ERR Auto-batching not supported by the TFLITE backend"); - } - if (AC_GetUnsignedLongLong(&ac, &batchsize, 0) != AC_OK) { - return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for BATCHSIZE"); - } - } - - unsigned long long minbatchsize = 0; - if (AC_AdvanceIfMatch(&ac, "MINBATCHSIZE")) { - if (batchsize == 0) { - return RedisModule_ReplyWithError(ctx, "ERR MINBATCHSIZE specified without BATCHSIZE"); - } - if (AC_GetUnsignedLongLong(&ac, &minbatchsize, 0) != AC_OK) { - return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for MINBATCHSIZE"); - } - } - - unsigned long long minbatchtimeout = 0; - if (AC_AdvanceIfMatch(&ac, "MINBATCHTIMEOUT")) { - if (batchsize == 0) { - return RedisModule_ReplyWithError(ctx, - "ERR MINBATCHTIMEOUT specified without BATCHSIZE"); - } - if (minbatchsize == 0) { - return RedisModule_ReplyWithError(ctx, - "ERR MINBATCHTIMEOUT specified without MINBATCHSIZE"); - } - if (AC_GetUnsignedLongLong(&ac, &minbatchtimeout, 0) != AC_OK) { - return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for MINBATCHTIMEOUT"); - } - } - - if (AC_IsAtEnd(&ac)) { - return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing model BLOB"); - } - - ArgsCursor optionsac; - const char *blob_matches[] = {"BLOB"}; - AC_GetSliceUntilMatches(&ac, &optionsac, 1, blob_matches); - - if (optionsac.argc == 0 && backend == RAI_BACKEND_TENSORFLOW) { - return RedisModule_ReplyWithError( - ctx, "ERR Insufficient arguments, INPUTS and OUTPUTS not specified"); - } - - ArgsCursor inac = {0}; - ArgsCursor outac = {0}; - if (optionsac.argc > 0 && backend == RAI_BACKEND_TENSORFLOW) { - if (!AC_AdvanceIfMatch(&optionsac, "INPUTS")) { - return RedisModule_ReplyWithError(ctx, "ERR INPUTS not specified"); - } - - const char *matches[] = {"OUTPUTS"}; - AC_GetSliceUntilMatches(&optionsac, &inac, 1, matches); - - if (!AC_IsAtEnd(&optionsac)) { - if (!AC_AdvanceIfMatch(&optionsac, "OUTPUTS")) { - return RedisModule_ReplyWithError(ctx, "ERR OUTPUTS not specified"); - } - - AC_GetSliceToEnd(&optionsac, &outac); - } - } - - size_t ninputs = inac.argc; - const char *inputs[ninputs]; - for (size_t i = 0; i < ninputs; i++) { - AC_GetString(&inac, inputs + i, NULL, 0); - } - - size_t noutputs = outac.argc; - const char *outputs[noutputs]; - for (size_t i = 0; i < noutputs; i++) { - AC_GetString(&outac, outputs + i, NULL, 0); - } - - RAI_ModelOpts opts = { - .batchsize = batchsize, - .minbatchsize = minbatchsize, - .minbatchtimeout = minbatchtimeout, - .backends_intra_op_parallelism = getBackendsIntraOpParallelism(), - .backends_inter_op_parallelism = getBackendsInterOpParallelism(), - }; - - RAI_Model *model = NULL; - - AC_AdvanceUntilMatches(&ac, 1, blob_matches); - - if (AC_IsAtEnd(&ac)) { - return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing model BLOB"); - } - - AC_Advance(&ac); - - ArgsCursor blobsac; - AC_GetSliceToEnd(&ac, &blobsac); - - size_t modellen; - char *modeldef; - - if (blobsac.argc == 1) { - 
AC_GetString(&blobsac, (const char **)&modeldef, &modellen, 0); - } else { - const char *chunks[blobsac.argc]; - size_t chunklens[blobsac.argc]; - modellen = 0; - while (!AC_IsAtEnd(&blobsac)) { - AC_GetString(&blobsac, &chunks[blobsac.offset], &chunklens[blobsac.offset], 0); - modellen += chunklens[blobsac.offset - 1]; - } - - modeldef = RedisModule_Calloc(modellen, sizeof(char)); - size_t offset = 0; - for (size_t i = 0; i < blobsac.argc; i++) { - memcpy(modeldef + offset, chunks[i], chunklens[i]); - offset += chunklens[i]; - } - } - - RAI_Error err = {0}; - - model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs, - modeldef, modellen, &err); - - if (err.code == RAI_EBACKENDNOTLOADED) { - RedisModule_Log(ctx, "warning", "backend %s not loaded, will try loading default backend", - bckstr); - int ret = RAI_LoadDefaultBackend(ctx, backend); - if (ret == REDISMODULE_ERR) { - RedisModule_Log(ctx, "error", "could not load %s default backend", bckstr); - int ret = RedisModule_ReplyWithError(ctx, "ERR Could not load backend"); - RAI_ClearError(&err); - return ret; - } - RAI_ClearError(&err); - model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs, - modeldef, modellen, &err); - } - - if (blobsac.argc > 1) { - RedisModule_Free(modeldef); - } - - if (err.code != RAI_OK) { - RedisModule_Log(ctx, "error", "%s", err.detail); - int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline); - RAI_ClearError(&err); - return ret; - } - - // TODO: if backend loaded, make sure there's a queue - RunQueueInfo *run_queue_info = NULL; - if (ensureRunQueue(devicestr, &run_queue_info) != REDISMODULE_OK) { - RAI_ModelFree(model, &err); - if (err.code != RAI_OK) { - RedisModule_Log(ctx, "error", "%s", err.detail); - int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline); - RAI_ClearError(&err); - return ret; - } - return RedisModule_ReplyWithError(ctx, - "ERR Could not initialize queue on requested device"); - } - - RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE); - int type = RedisModule_KeyType(key); - if (type != REDISMODULE_KEYTYPE_EMPTY && - !(type == REDISMODULE_KEYTYPE_MODULE && - RedisModule_ModuleTypeGetType(key) == RedisAI_ModelType)) { - RedisModule_CloseKey(key); - RAI_ModelFree(model, &err); - if (err.code != RAI_OK) { - RedisModule_Log(ctx, "error", "%s", err.detail); - int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline); - RAI_ClearError(&err); - return ret; - } - return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); - } - - RedisModule_ModuleTypeSetValue(key, RedisAI_ModelType, model); - - model->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_MODEL, backend, devicestr, tag); - - RedisModule_CloseKey(key); - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - - RedisModule_ReplicateVerbatim(ctx); - + // Save the command args in the context and send it to be processed in the background thread. + RedisModuleString **args = array_new(RedisModuleString *, argc - 1); + for (size_t i = 1; i < argc; i++) { + RAI_HoldString(NULL, argv[i]); + args = array_append(args, argv[i]); + } + ModelSetCtx *model_ctx = RedisModule_Alloc(sizeof(*model_ctx)); + model_ctx->args = args; + model_ctx->client = + RedisModule_BlockClient(ctx, RedisAI_ModelSet_Reply, NULL, ModelSet_FreeData, 0); + + // Create a "modelset" job and push it to the queue. 
+ Job *job = JobCreate(model_ctx, ModelSet_Execute); + pthread_mutex_lock(&modelSet_QueueInfo->run_queue_mutex); + queuePush(modelSet_QueueInfo->run_queue, job); + pthread_cond_signal(&modelSet_QueueInfo->queue_condition_var); + pthread_mutex_unlock(&modelSet_QueueInfo->run_queue_mutex); return REDISMODULE_OK; } @@ -531,9 +318,11 @@ int RedisAI_ModelDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, return REDISMODULE_ERR; } - RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE); + RedisModuleString *model_key_str = argv[1]; + RedisModuleKey *key = RedisModule_OpenKey(ctx, model_key_str, REDISMODULE_WRITE); RedisModule_DeleteKey(key); RedisModule_CloseKey(key); + RAI_RemoveStatsEntry(model_key_str); RedisModule_ReplicateVerbatim(ctx); return RedisModule_ReplyWithSimpleString(ctx, "OK"); @@ -1335,7 +1124,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RedisModule_Log(ctx, "warning", "Queue not initialized for device CPU"); return REDISMODULE_ERR; } - + if (Init_BG_ModelSet() != REDISMODULE_OK) { + RedisModule_Log(ctx, "warning", "Background ModelSet queue was not initialized"); + return REDISMODULE_ERR; + } run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); return REDISMODULE_OK; diff --git a/src/serialization/RDB/decoder/current/v1/decode_v1.c b/src/serialization/RDB/decoder/current/v1/decode_v1.c index 0947d6aca..a3101ea82 100644 --- a/src/serialization/RDB/decoder/current/v1/decode_v1.c +++ b/src/serialization/RDB/decoder/current/v1/decode_v1.c @@ -74,36 +74,40 @@ void *RAI_RDBLoadModel_v1(RedisModuleIO *io) { char *devicestr = NULL; RedisModuleString *tag = NULL; size_t ninputs = 0; - const char **inputs = NULL; + char **inputs = NULL; size_t noutputs = 0; - const char **outputs = NULL; + char **outputs = NULL; char *buffer = NULL; - + RAI_Error err = {0}; + char *error_str = "Experienced a short read while reading a model from RDB"; + + RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io); + RedisModuleString *key_str = + RedisModule_CreateStringFromString(NULL, RedisModule_GetKeyNameFromIO(io)); + if (!key_str) { + RedisModule_LogIOError(io, "error", "Couldn't get model key name from RDB"); + return NULL; + } RAI_Backend backend = RedisModule_LoadUnsigned(io); devicestr = RedisModule_LoadStringBuffer(io, NULL); tag = RedisModule_LoadString(io); - const size_t batchsize = RedisModule_LoadUnsigned(io); const size_t minbatchsize = RedisModule_LoadUnsigned(io); ninputs = RedisModule_LoadUnsigned(io); if (RedisModule_IsIOError(io)) goto cleanup; - - inputs = RedisModule_Alloc(ninputs * sizeof(char *)); - + inputs = array_new(char *, ninputs); for (size_t i = 0; i < ninputs; i++) { - inputs[i] = RedisModule_LoadStringBuffer(io, NULL); + inputs = array_append(inputs, RedisModule_LoadStringBuffer(io, NULL)); } noutputs = RedisModule_LoadUnsigned(io); if (RedisModule_IsIOError(io)) goto cleanup; - - outputs = RedisModule_Alloc(noutputs * sizeof(char *)); - + outputs = array_new(char *, noutputs); for (size_t i = 0; i < noutputs; i++) { - outputs[i] = RedisModule_LoadStringBuffer(io, NULL); + outputs = array_append(outputs, RedisModule_LoadStringBuffer(io, NULL)); } RAI_ModelOpts opts = { @@ -130,48 +134,41 @@ void *RAI_RDBLoadModel_v1(RedisModuleIO *io) { RedisModule_Free(chunk_buffer); } - RAI_Error err = {0}; - RAI_Model *model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, - outputs, buffer, len, &err); - - if (err.code == RAI_EBACKENDNOTLOADED) { - RedisModuleCtx *ctx = 
RedisModule_GetContextFromIO(io); - int ret = RAI_LoadDefaultBackend(ctx, backend); - if (ret == REDISMODULE_ERR) { - RedisModule_Log(ctx, "error", "Could not load default backend"); - RAI_ClearError(&err); + RAI_Model *model = RedisModule_Calloc(1, sizeof(*model)); + model->infokey = RAI_HoldString(NULL, key_str); + model->backend = backend; + model->devicestr = devicestr; + model->tag = tag; + model->inputs = inputs; + model->ninputs = ninputs; + model->outputs = outputs; + model->noutputs = noutputs; + model->opts = opts; + model->data = buffer; + model->datalen = len; + + const char *backend_str = RAI_BackendName(backend); + if (ModelCreateBE(model, &err) != REDISMODULE_OK) { + // If we got an error *not* because of lazy loading, we fail and abort the load. + if (RAI_GetErrorCode(&err) != RAI_EBACKENDNOTLOADED) { + error_str = (char *)RAI_GetError(&err); goto cleanup; } + RedisModule_Log(ctx, "warning", "backend %s not loaded, will try loading default backend", + backend_str); + int ret = RAI_LoadDefaultBackend(NULL, model->backend); + if (ret != REDISMODULE_OK) { + error_str = "Could not load default backend"; + goto cleanup; + } + // Try creating model for backend again. RAI_ClearError(&err); - model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs, - buffer, len, &err); - } - - if (err.code != RAI_OK) { - RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io); - RedisModule_Log(ctx, "error", "%s", err.detail); - RAI_ClearError(&err); - goto cleanup; - } - - RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io); - RedisModuleString *stats_keystr = - RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io)); - - model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag); - - for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free((void *)inputs[i]); - } - RedisModule_Free(inputs); - for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free((void *)outputs[i]); + if (ModelCreateBE(model, &err) != REDISMODULE_OK) { + error_str = (char *)RAI_GetError(&err); + goto cleanup; + } } - RedisModule_Free(outputs); - RedisModule_Free(buffer); - RedisModule_Free(devicestr); - RedisModule_FreeString(NULL, stats_keystr); - RedisModule_FreeString(NULL, tag); + RAI_AddStatsEntry(ctx, key_str, RAI_MODEL, backend, devicestr, tag); return model; @@ -182,22 +179,25 @@ void *RAI_RDBLoadModel_v1(RedisModuleIO *io) { RedisModule_FreeString(NULL, tag); if (inputs) { for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free((void *)inputs[i]); + RedisModule_Free(inputs[i]); } - RedisModule_Free(inputs); + array_free(inputs); } if (outputs) { for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free((void *)outputs[i]); + RedisModule_Free(outputs[i]); } - RedisModule_Free(outputs); + array_free(outputs); } if (buffer) RedisModule_Free(buffer); - RedisModule_LogIOError(io, "error", "Experienced a short read while reading a model from RDB"); + RedisModule_LogIOError(io, "error", "%s", error_str); + if (RAI_GetErrorCode(&err) != RAI_OK) { + RAI_ClearError(&err); + } return NULL; } diff --git a/src/util/queue.c b/src/util/queue.c index 3c22488e0..ec452c804 100644 --- a/src/util/queue.c +++ b/src/util/queue.c @@ -92,6 +92,8 @@ queueItem *queueEvict(queue *queue, queueItem *item) { long long queueLength(queue *queue) { return queue->len; } +void *queueItemGetValue(queueItem *item) { return item->value; } + void queueRelease(queue *queue) { unsigned long len; queueItem *current; diff --git
a/src/util/queue.h b/src/util/queue.h index af755181d..acce2256a 100644 --- a/src/util/queue.h +++ b/src/util/queue.h @@ -27,6 +27,7 @@ void queuePushFront(queue *queue, void *value); queueItem *queuePop(queue *queue); queueItem *queueFront(queue *queue); queueItem *queueNext(queueItem *item); +void *queueItemGetValue(queueItem *item); queueItem *queueEvict(queue *queue, queueItem *item); long long queueLength(queue *queue); void queueRelease(queue *queue); diff --git a/tests/flow/tests_pytorch.py b/tests/flow/tests_pytorch.py index ed7c0f1b9..91d2a0fe0 100644 --- a/tests/flow/tests_pytorch.py +++ b/tests/flow/tests_pytorch.py @@ -924,8 +924,9 @@ def test_pytorch_model_rdb_save_load(env): env.start() con = env.getConnection() model_serialized_after_rdbload = con.execute_command('AI.MODELGET', 'm{1}', 'BLOB') - con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}') - _, dtype_after_rdbload, _, shape_after_rdbload, _, data_after_rdbload = con.execute_command('AI.TENSORGET', 'c{1}', 'META', 'VALUES') + + con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'd{1}') + _, dtype_after_rdbload, _, shape_after_rdbload, _, data_after_rdbload = con.execute_command('AI.TENSORGET', 'd{1}', 'META', 'VALUES') # Assert in memory model metadata is equal to loaded model metadata env.assertTrue(model_serialized_memory[1:6] == model_serialized_after_rdbload[1:6]) @@ -967,6 +968,7 @@ def test_parallelism(): env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2") env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2") + def test_modelget_for_tuple_output(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) diff --git a/tests/flow/tests_tensorflow.py b/tests/flow/tests_tensorflow.py index c7e000f0e..10ea60122 100644 --- a/tests/flow/tests_tensorflow.py +++ b/tests/flow/tests_tensorflow.py @@ -426,7 +426,7 @@ def test_run_tf_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Insufficient arguments, INPUTS and OUTPUTS not specified", exception.__str__()) + env.assertEqual("Insufficient arguments, INPUTS and OUTPUTS not specified for TF model", exception.__str__()) try: con.execute_command('AI.MODELSET', 'm_8{1}', 'TF', DEVICE, diff --git a/tests/flow/tests_tflite.py b/tests/flow/tests_tflite.py index 29c824eeb..ba8acff7e 100644 --- a/tests/flow/tests_tflite.py +++ b/tests/flow/tests_tflite.py @@ -106,7 +106,7 @@ def test_run_tflite_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Insufficient arguments, missing model BLOB", exception.__str__()) + env.assertEqual("Invalid DEVICE", exception.__str__()) try: con.execute_command('AI.MODELSET', 'm_2{1}', 'BLOB', model_pb) From 2df73fd940db041c07af92f2df997aad3bd3fbf5 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sat, 10 Apr 2021 15:40:05 +0300 Subject: [PATCH 2/4] Fix in rdb loading --- src/execution/command_parser.c | 8 +++++--- src/serialization/RDB/decoder/current/v1/decode_v1.c | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/execution/command_parser.c b/src/execution/command_parser.c index f11f4b93a..60db506a5 100644 --- a/src/execution/command_parser.c +++ b/src/execution/command_parser.c @@ -500,14 +500,16 @@ int ParseModelSetCommand(RedisModuleString **argv, int argc, RAI_Model *model, R } model->backend = backend; - 
// Parse <device> argument: check that the device string is "CPU", "GPU" or - // "GPU:<n>" where <n> is a number (contains digits only). + // Parse <device> argument: check that the device string is "CPU", "GPU", + // "CPU:<n>" or "GPU:<n>", where <n> is a number (contains digits only). const char *device_str; AC_GetString(&ac, &device_str, NULL, 0); bool valid_device = false; if (strcasecmp(device_str, "CPU") == 0 || strcasecmp(device_str, "GPU") == 0) { valid_device = true; - } else if (strncasecmp(device_str, "GPU:", 4) == 0 && strlen(device_str) <= 10) { + } else if ((strncasecmp(device_str, "GPU:", 4) == 0 || + strncasecmp(device_str, "CPU:", 4) == 0) && + strlen(device_str) <= 10) { bool digits_only = true; for (size_t i = 5; i < strlen(device_str); i++) { if (device_str[i] < '0' || device_str[i] > '9') { diff --git a/src/serialization/RDB/decoder/current/v1/decode_v1.c b/src/serialization/RDB/decoder/current/v1/decode_v1.c index a3101ea82..cac7e04a3 100644 --- a/src/serialization/RDB/decoder/current/v1/decode_v1.c +++ b/src/serialization/RDB/decoder/current/v1/decode_v1.c @@ -135,6 +135,7 @@ void *RAI_RDBLoadModel_v1(RedisModuleIO *io) { } RAI_Model *model = RedisModule_Calloc(1, sizeof(*model)); + model->refCount = 1; model->infokey = RAI_HoldString(NULL, key_str); model->backend = backend; model->devicestr = devicestr; From 6e613af148c6ffc58b9a57b8257acb4d878243ce Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 11 Apr 2021 12:54:28 +0300 Subject: [PATCH 3/4] Free stats when releasing model type. --- src/execution/background_modelset.c | 2 ++ src/redis_ai_types/model_type.c | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/execution/background_modelset.c b/src/execution/background_modelset.c index 9a16691d7..6eb137a1b 100644 --- a/src/execution/background_modelset.c +++ b/src/execution/background_modelset.c @@ -102,6 +102,8 @@ int RedisAI_ModelSet_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int ar RAI_ModelGetShallowCopy(model_ctx->model); RedisModule_ReplyWithSimpleString(ctx, "OK"); + + // TODO: this should be replaced with RedisModule_Replicate (currently not replicating properly) RedisModule_ReplicateVerbatim(ctx); return REDISMODULE_OK; } diff --git a/src/redis_ai_types/model_type.c b/src/redis_ai_types/model_type.c index 817bd2eb2..0278c37d8 100644 --- a/src/redis_ai_types/model_type.c +++ b/src/redis_ai_types/model_type.c @@ -28,7 +28,9 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi static void RAI_Model_DTFree(void *value) { RAI_Error err = {0}; - RAI_ModelFree(value, &err); + RAI_Model *model = (RAI_Model *)value; + RAI_RemoveStatsEntry(model->infokey); + RAI_ModelFree(model, &err); if (err.code != RAI_OK) { printf("ERR: %s\n", err.detail); RAI_ClearError(&err); From 989bd6e1efb37b1a071d52467cd48302fa54c454 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 12 Apr 2021 18:37:10 +0300 Subject: [PATCH 4/4] Fix rdb decode v0 to fit the new changes.
--- .../RDB/decoder/previous/v0/decode_v0.c | 119 +++++++++--------- 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/src/serialization/RDB/decoder/previous/v0/decode_v0.c b/src/serialization/RDB/decoder/previous/v0/decode_v0.c index 810438771..27385fa91 100644 --- a/src/serialization/RDB/decoder/previous/v0/decode_v0.c +++ b/src/serialization/RDB/decoder/previous/v0/decode_v0.c @@ -66,16 +66,25 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) { char *devicestr = NULL; RedisModuleString *tag = NULL; size_t ninputs = 0; - const char **inputs = NULL; + char **inputs = NULL; size_t noutputs = 0; - const char **outputs = NULL; + char **outputs = NULL; char *buffer = NULL; - + RAI_Error err = {0}; + char *error_str = "Experienced a short read while reading a model from RDB"; + + RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io); + RedisModuleString *key_str = + RedisModule_CreateStringFromString(NULL, RedisModule_GetKeyNameFromIO(io)); + if (!key_str) { + RedisModule_LogIOError(io, "error", "Couldn't get model key name from RDB"); + return NULL; + } RAI_Backend backend = RedisModule_LoadUnsigned(io); devicestr = RedisModule_LoadStringBuffer(io, NULL); - size_t len; - char *cstr_tag = RedisModule_LoadStringBuffer(io, &len); - tag = RedisModule_CreateString(NULL, cstr_tag, len - 1); + size_t tag_len; + char *cstr_tag = RedisModule_LoadStringBuffer(io, &tag_len); + tag = RedisModule_CreateString(NULL, cstr_tag, tag_len - 1); RedisModule_Free(cstr_tag); const size_t batchsize = RedisModule_LoadUnsigned(io); @@ -84,21 +93,17 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) { ninputs = RedisModule_LoadUnsigned(io); if (RedisModule_IsIOError(io)) goto cleanup; - - inputs = RedisModule_Alloc(ninputs * sizeof(char *)); - + inputs = array_new(char *, ninputs); for (size_t i = 0; i < ninputs; i++) { - inputs[i] = RedisModule_LoadStringBuffer(io, NULL); + inputs = array_append(inputs, RedisModule_LoadStringBuffer(io, NULL)); } noutputs = RedisModule_LoadUnsigned(io); if (RedisModule_IsIOError(io)) goto cleanup; - - outputs = RedisModule_Alloc(noutputs * sizeof(char *)); - + outputs = array_new(char *, noutputs); for (size_t i = 0; i < noutputs; i++) { - outputs[i] = RedisModule_LoadStringBuffer(io, NULL); + outputs = array_append(outputs, RedisModule_LoadStringBuffer(io, NULL)); } RAI_ModelOpts opts = { @@ -108,52 +113,47 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) { .backends_inter_op_parallelism = getBackendsInterOpParallelism(), }; + size_t len; buffer = RedisModule_LoadStringBuffer(io, &len); if (RedisModule_IsIOError(io)) goto cleanup; - RAI_Error err = {0}; - RAI_Model *model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, - outputs, buffer, len, &err); - - if (err.code == RAI_EBACKENDNOTLOADED) { - RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io); - int ret = RAI_LoadDefaultBackend(ctx, backend); - if (ret == REDISMODULE_ERR) { - RedisModule_Log(ctx, "error", "Could not load default backend"); - RAI_ClearError(&err); + RAI_Model *model = RedisModule_Calloc(1, sizeof(*model)); + model->refCount = 1; + model->infokey = RAI_HoldString(NULL, key_str); + model->backend = backend; + model->devicestr = devicestr; + model->tag = tag; + model->inputs = inputs; + model->ninputs = ninputs; + model->outputs = outputs; + model->noutputs = noutputs; + model->opts = opts; + model->data = buffer; + model->datalen = len; + + const char *backend_str = RAI_BackendName(backend); + if (ModelCreateBE(model, &err) != REDISMODULE_OK) { + // If we got an error *not* 
because of lazy loading, we fail and abort the load. + if (RAI_GetErrorCode(&err) != RAI_EBACKENDNOTLOADED) { + error_str = (char *)RAI_GetError(&err); goto cleanup; } + RedisModule_Log(ctx, "warning", "backend %s not loaded, will try loading default backend", + backend_str); + int ret = RAI_LoadDefaultBackend(NULL, model->backend); + if (ret != REDISMODULE_OK) { + error_str = "Could not load default backend"; + goto cleanup; + } + // Try creating model for backend again. RAI_ClearError(&err); - model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs, - buffer, len, &err); - } - - if (err.code != RAI_OK) { - RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io); - RedisModule_Log(ctx, "error", "%s", err.detail); - RAI_ClearError(&err); - goto cleanup; - } - - RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io); - RedisModuleString *stats_keystr = - RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io)); - - model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag); - - for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free((void *)inputs[i]); - } - RedisModule_Free(inputs); - for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free((void *)outputs[i]); + if (ModelCreateBE(model, &err) != REDISMODULE_OK) { + error_str = (char *)RAI_GetError(&err); + goto cleanup; + } } - RedisModule_Free(outputs); - RedisModule_Free(buffer); - RedisModule_Free(devicestr); - RedisModule_FreeString(NULL, stats_keystr); - RedisModule_FreeString(NULL, tag); + RAI_AddStatsEntry(ctx, key_str, RAI_MODEL, backend, devicestr, tag); return model; @@ -161,25 +161,28 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) { if (devicestr) RedisModule_Free(devicestr); if (tag) - RedisModule_Free(tag); + RedisModule_FreeString(NULL, tag); if (inputs) { for (size_t i = 0; i < ninputs; i++) { - RedisModule_Free((void *)inputs[i]); + RedisModule_Free(inputs[i]); } - RedisModule_Free(inputs); + array_free(inputs); } if (outputs) { for (size_t i = 0; i < noutputs; i++) { - RedisModule_Free((void *)outputs[i]); + RedisModule_Free(outputs[i]); } - RedisModule_Free(outputs); + array_free(outputs); } if (buffer) RedisModule_Free(buffer); - RedisModule_LogIOError(io, "error", "Experienced a short read while reading a model from RDB"); + RedisModule_LogIOError(io, "error", "%s", error_str); + if (RAI_GetErrorCode(&err) != RAI_OK) { + RAI_ClearError(&err); + } return NULL; }
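
For readers tracing the refactor: RAI_ModelCreate, RAI_RDBLoadModel_v1, and RAI_RDBLoadModel_v0 now all share the same create/lazy-load/retry flow around ModelCreateBE. The following is a minimal, self-contained C sketch of that flow, not RedisAI source; create_in_backend and load_default_backend are hypothetical stand-ins for ModelCreateBE and RAI_LoadDefaultBackend.

#include <stdio.h>

/* Hypothetical status codes standing in for RAI_OK / RAI_EBACKENDNOTLOADED. */
typedef enum { OK = 0, ERR_BACKEND_NOT_LOADED = 1, ERR_OTHER = 2 } status;

static int backend_loaded = 0;

/* Stand-in for ModelCreateBE: fails until some backend has been loaded. */
static status create_in_backend(void) {
    return backend_loaded ? OK : ERR_BACKEND_NOT_LOADED;
}

/* Stand-in for RAI_LoadDefaultBackend. */
static status load_default_backend(void) {
    backend_loaded = 1;
    return OK;
}

/* Try to create the model; if the only problem is that the backend is not
 * loaded yet, load the default backend and retry exactly once. */
static status create_with_fallback(void) {
    status s = create_in_backend();
    if (s != ERR_BACKEND_NOT_LOADED)
        return s; /* success, or a real failure that is propagated as-is */
    if (load_default_backend() != OK)
        return ERR_OTHER; /* could not load the default backend */
    return create_in_backend(); /* second and final attempt */
}

int main(void) {
    printf("create_with_fallback -> %d\n", create_with_fallback());
    return 0;
}

Retrying exactly once keeps the failure path simple: a second ERR_BACKEND_NOT_LOADED after the default backend has supposedly been loaded is reported as a hard error rather than looping.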