From d082fc2e67fe17531d2b2fa71dad6168e5a4ea85 Mon Sep 17 00:00:00 2001 From: Jeremy Tubongbanua Date: Mon, 29 Jul 2024 10:45:15 -0400 Subject: [PATCH] feat: connection_hooks.h/.c --- .../atclient/include/atclient/connection.h | 125 +-- .../include/atclient/connection_hooks.h | 64 ++ packages/atclient/src/connection.c | 820 ++++++++++++------ packages/atclient/src/connection_hooks.c | 393 +++++++++ .../tests/test_atclient_connection.c | 14 +- 5 files changed, 1051 insertions(+), 365 deletions(-) create mode 100644 packages/atclient/include/atclient/connection_hooks.h create mode 100644 packages/atclient/src/connection_hooks.c diff --git a/packages/atclient/include/atclient/connection.h b/packages/atclient/include/atclient/connection.h index d3930c90..5b48160a 100644 --- a/packages/atclient/include/atclient/connection.h +++ b/packages/atclient/include/atclient/connection.h @@ -7,8 +7,7 @@ #include #include #include - -#define ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE 128 // the size of the buffer for the host name +#include "atclient/connection_hooks.h" // represents the type of connection typedef enum atclient_connection_type { @@ -16,39 +15,21 @@ typedef enum atclient_connection_type { ATCLIENT_CONNECTION_TYPE_ATSERVER // uses 'noop:0\r\n' to check if it is connected } atclient_connection_type; -typedef int(atclient_connection_send_hook)(const unsigned char *src, const size_t src_len, unsigned char *recv, - const size_t recv_size, size_t *recv_len); - -typedef enum atclient_connection_hook_type { - ATCLIENT_CONNECTION_HOOK_TYPE_NONE = 0, - ATCLIENT_CONNECTION_HOOK_TYPE_PRE_SEND, - ATCLIENT_CONNECTION_HOOK_TYPE_POST_SEND, - ATCLIENT_CONNECTION_HOOK_TYPE_PRE_RECV, - ATCLIENT_CONNECTION_HOOK_TYPE_POST_RECV, -} atclient_connection_hook_type; - -typedef struct atclient_connection_hooks { - bool _is_nested_call; // internal variable for preventing infinite recursion (hooks cannot trigger other hooks in - // their nested calls) - atclient_connection_send_hook *pre_send; - atclient_connection_send_hook *post_send; - atclient_connection_send_hook *pre_recv; - atclient_connection_send_hook *post_recv; - bool readonly_src; -} atclient_connection_hooks; - typedef struct atclient_connection { - atclient_connection_type type; + atclient_connection_type type; // set in atclient_connection_init + + bool _is_host_initialized: 1; + char *host; // example: "root.atsign.org" - char host[ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE]; - int port; // example: 64 + bool _is_port_initialized: 1; + uint16_t port; // example: 64 // atclient_connection_connect sets this to true and atclient_connection_disconnect sets this to false // this does not mean that the connection is still alive, it just means that the connection was established at least // once, at some point, check atclient_connection_is_connected for a live status on the connection - // _should_be_connected also serves as an internal boolean to check if the following mbedlts contexts have been + // _is_connection_enabled also serves as an internal boolean to check if the following mbedlts contexts have been // initialized and need to be freed at the end - bool _should_be_connected; + bool _is_connection_enabled: 1; mbedtls_net_context net; mbedtls_ssl_context ssl; mbedtls_ssl_config ssl_config; @@ -56,7 +37,7 @@ typedef struct atclient_connection { mbedtls_entropy_context entropy; mbedtls_ctr_drbg_context ctr_drbg; - bool _is_hooks_enabled; + bool _is_hooks_enabled: 1; atclient_connection_hooks *hooks; } atclient_connection; @@ -70,6 +51,13 @@ typedef struct atclient_connection { */ void atclient_connection_init(atclient_connection *ctx, atclient_connection_type type); +/** + * @brief free memory allocated by the init function + * + * @param ctx the struct which was previously initialized + */ +void atclient_connection_free(atclient_connection *ctx); + /** * @brief after initializing a connection context, connect to a host and port * @@ -78,7 +66,28 @@ void atclient_connection_init(atclient_connection *ctx, atclient_connection_type * @param port the port to connect to * @return int 0 on success, otherwise error */ -int atclient_connection_connect(atclient_connection *ctx, const char *host, const int port); +int atclient_connection_connect(atclient_connection *ctx, const char *host, const uint16_t port); + +/** + * @brief Reads data from the connection + * + * @param ctx the connection initialized and connected using atclient_connection_init and atclient_connection_connect + * @param value a double pointer that will be allocated by the function to the data read + * @param value_len the length of the data read, will be set by the function + * @param value_max_len the maximum length of the data to read, setting this to 0 means no limit + * @return int 0 on success + */ +int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len, const size_t value_max_len); + +/** + * @brief Write data to the connection + * + * @param ctx connection initialized and connected using atclient_connection_init and atclient_connection_connect + * @param value the data to write + * @param value_len the length of the data to write + * @return int 0 on success + */ +int atclient_connection_write(atclient_connection *ctx, const unsigned char *value, const size_t value_len); /** * @brief send data to the connection @@ -112,60 +121,4 @@ int atclient_connection_disconnect(atclient_connection *ctx); */ bool atclient_connection_is_connected(atclient_connection *ctx); -/** - * @brief free memory allocated by the init function - * - * @param ctx the struct which was previously initialized - */ -void atclient_connection_free(atclient_connection *ctx); - -/** - * @brief Initialize the hooks memory allocation - * - * @param ctx the struct for the connection - */ -void atclient_connection_enable_hooks(atclient_connection *ctx); - -/** - * @brief Add a hook to be called during the connection lifecycle - * - * @param ctx the struct for the connection - * @param type the hook type you want to add - * @param hook the hook function itself - * - * @return int 0 on success, otherwise error - */ -int atclient_connection_hooks_set(atclient_connection *ctx, atclient_connection_hook_type type, void *hook); - -/** - * @brief Set whether the readonly_src status for all hooks - * - * @param ctx the struct for the connection - * @param readonly_src the new state for readonly_src - * - * @note For performance, keep readonly_src set to true if you don't need to write access to src - */ -void atclient_connection_hooks_set_readonly_src(atclient_connection *ctx, bool readonly_src); - -/** - * @brief Write data to the connection - * - * @param ctx connection initialized and connected using atclient_connection_init and atclient_connection_connect - * @param value the data to write - * @param value_len the length of the data to write - * @return int 0 on success - */ -int atclient_connection_write(atclient_connection *ctx, const unsigned char *value, const size_t value_len); - -/** - * @brief Reads data from the connection - * - * @param ctx the connection initialized and connected using atclient_connection_init and atclient_connection_connect - * @param value a double pointer that will be allocated by the function to the data read - * @param value_len the length of the data read, will be set by the function - * @param value_max_len the maximum length of the data to read, setting this to 0 means no limit - * @return int 0 on success - */ -int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len, const size_t value_max_len); - #endif diff --git a/packages/atclient/include/atclient/connection_hooks.h b/packages/atclient/include/atclient/connection_hooks.h new file mode 100644 index 00000000..04a48ba4 --- /dev/null +++ b/packages/atclient/include/atclient/connection_hooks.h @@ -0,0 +1,64 @@ +#ifndef ATCLIENT_CONNECTION_HOOKS_H +#define ATCLIENT_CONNECTION_HOOKS_H + +#include +#include +#include + +#define VALUE_INITIALIZED 0b00000001 + +#define ATCLIENT_CONNECTION_HOOKS_PRE_READ_INDEX 0 +#define ATCLIENT_CONNECTION_HOOKS_POST_READ_INDEX 0 +#define ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INDEX 0 +#define ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INDEX 0 + +#define ATCLIENT_CONNECTION_HOOKS_PRE_READ_INITIALIZED (VALUE_INITIALIZED << 0) +#define ATCLIENT_CONNECTION_HOOKS_POST_READ_INITIALIZED (VALUE_INITIALIZED << 1) +#define ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INITIALIZED (VALUE_INITIALIZED << 2) +#define ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INITIALIZED (VALUE_INITIALIZED << 3) + +typedef struct atclient_connection atclient_connection; + +typedef struct atclient_connection_hook_params { + unsigned char *src; + size_t src_len; + unsigned char *recv; + size_t recv_size; + size_t *recv_len; +} atclient_connection_hook_params; + +typedef int(atclient_connection_hook)(atclient_connection_hook_params *params); + +typedef enum atclient_connection_hook_type { + ATCLIENT_CONNECTION_HOOK_TYPE_NONE = 0, + ATCLIENT_CONNECTION_HOOK_TYPE_PRE_READ, + ATCLIENT_CONNECTION_HOOK_TYPE_POST_READ, + ATCLIENT_CONNECTION_HOOK_TYPE_PRE_WRITE, + ATCLIENT_CONNECTION_HOOK_TYPE_POST_WRITE, +} atclient_connection_hook_type; + +typedef struct atclient_connection_hooks { + bool _is_nested_call; + bool readonly_src; + atclient_connection_hook *pre_read; + atclient_connection_hook *post_read; + atclient_connection_hook *pre_write; + atclient_connection_hook *post_write; + uint8_t _initialized_fields[1]; +} atclient_connection_hooks; + +bool atclient_connection_hooks_is_enabled(atclient_connection *ctx); +int atclient_connection_hooks_enable(atclient_connection *conn); +void atclient_connection_hooks_disable(atclient_connection *conn); + +// Q. Why is hook a void pointer? +// A. In case we want to add future hook types which use a different function signature +int atclient_connection_hooks_set(atclient_connection *ctx, const atclient_connection_hook_type type, void *hook); + +bool atclient_connection_hooks_is_pre_read_initialized(const atclient_connection *ctx); +bool atclient_connection_hooks_is_post_read_initialized(const atclient_connection *ctx); +bool atclient_connection_hooks_is_pre_write_initialized(const atclient_connection *ctx); +bool atclient_connection_hooks_is_post_write_initialized(const atclient_connection *ctx); + + +#endif \ No newline at end of file diff --git a/packages/atclient/src/connection.c b/packages/atclient/src/connection.c index 971d187c..64016a4b 100644 --- a/packages/atclient/src/connection.c +++ b/packages/atclient/src/connection.c @@ -16,113 +16,125 @@ #define TAG "connection" /* Concatenation of all available CA certificates in PEM format */ -const char cas_pem[] = LETS_ENCRYPT_ROOT GOOGLE_GLOBAL_SIGN GOOGLE_GTS_ROOT_R1 GOOGLE_GTS_ROOT_R2 GOOGLE_GTS_ROOT_R3 - GOOGLE_GTS_ROOT_R4 ZEROSSL_INTERMEDIATE ""; -const size_t cas_pem_len = sizeof(cas_pem); +static const char cas_pem[] = LETS_ENCRYPT_ROOT GOOGLE_GLOBAL_SIGN GOOGLE_GTS_ROOT_R1 GOOGLE_GTS_ROOT_R2 + GOOGLE_GTS_ROOT_R3 GOOGLE_GTS_ROOT_R4 ZEROSSL_INTERMEDIATE ""; +static const size_t cas_pem_len = sizeof(cas_pem); -static void my_debug(void *ctx, int level, const char *file, int line, const char *str) { - ((void)level); - fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str); - fflush((FILE *)ctx); -} +static void my_debug(void *ctx, int level, const char *file, int line, const char *str); -static void init_contexts(atclient_connection *ctx) { - mbedtls_net_init(&(ctx->net)); - mbedtls_ssl_init(&(ctx->ssl)); - mbedtls_ssl_config_init(&(ctx->ssl_config)); - mbedtls_x509_crt_init(&(ctx->cacert)); - mbedtls_entropy_init(&(ctx->entropy)); - mbedtls_ctr_drbg_init(&(ctx->ctr_drbg)); -} +static void atclient_connection_set_is_connection_enabled(atclient_connection *ctx, const bool should_be_connected); +static bool atclient_connection_is_connection_enabled(const atclient_connection *ctx); +static void atclient_connection_enable_connection(atclient_connection *ctx); +static void atclient_connection_disable_connection(atclient_connection *ctx); -static void free_contexts(atclient_connection *ctx) { - mbedtls_net_free(&(ctx->net)); - mbedtls_ssl_free(&(ctx->ssl)); - mbedtls_ssl_config_free(&(ctx->ssl_config)); - mbedtls_x509_crt_free(&(ctx->cacert)); - mbedtls_entropy_free(&(ctx->entropy)); - mbedtls_ctr_drbg_free(&(ctx->ctr_drbg)); -} +static void atclient_connection_set_is_host_initialized(atclient_connection *ctx, const bool is_host_initialized); +static bool atclient_connection_is_host_initialized(const atclient_connection *ctx); +static int atclient_connection_set_host(atclient_connection *ctx, const char *host); +static void atclient_connection_unset_host(atclient_connection *ctx); -static void free_connection_hooks(atclient_connection *ctx) { - if (ctx->hooks != NULL && ctx->_is_hooks_enabled) { - free(ctx->hooks); - ctx->hooks = NULL; - ctx->_is_hooks_enabled = false; - } -} +static void atclient_connection_set_is_port_initialized(atclient_connection *ctx, const bool is_port_initialized); +static bool atclient_connection_is_port_initialized(const atclient_connection *ctx); +static int atclient_connection_set_port(atclient_connection *ctx, const uint16_t port); +static void atclient_connection_unset_port(atclient_connection *ctx); void atclient_connection_init(atclient_connection *ctx, atclient_connection_type type) { memset(ctx, 0, sizeof(atclient_connection)); ctx->type = type; - memset(ctx->host, 0, ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE); - ctx->port = -1; - ctx->_should_be_connected = false; - ctx->hooks = NULL; + ctx->_is_host_initialized = false; + ctx->host = NULL; + ctx->_is_port_initialized = false; + ctx->port = 0; + ctx->_is_connection_enabled = false; ctx->_is_hooks_enabled = false; + ctx->hooks = NULL; +} + +void atclient_connection_free(atclient_connection *ctx) { + if (atclient_connection_is_connection_enabled(ctx)) { + atclient_connection_disable_connection(ctx); + } + if (atclient_connection_hooks_is_enabled(ctx)) { + atclient_connection_hooks_disable(ctx); + } + if (atclient_connection_is_host_initialized(ctx)) { + atclient_connection_unset_host(ctx); + } + if (atclient_connection_is_port_initialized(ctx)) { + atclient_connection_unset_port(ctx); + } + memset(ctx, 0, sizeof(atclient_connection)); } -int atclient_connection_connect(atclient_connection *ctx, const char *host, const int port) { +int atclient_connection_connect(atclient_connection *ctx, const char *host, const uint16_t port) { int ret = 1; - if (ctx->_should_be_connected) { - if((ret = atclient_connection_disconnect(ctx)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_WARN, "atclient_connection_disconnect failed with exit code: %d. Continuing connection anyways..\n", ret); - } + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return ret; + } + + if (host == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "host is NULL\n"); + return ret; } - init_contexts(ctx); - ctx->_should_be_connected = true; + /* + * 2. Variables + */ + const size_t recv_size = 256; + unsigned char recv[recv_size]; + memset(recv, 0, sizeof(unsigned char) * recv_size); + size_t recv_len = 0; - const size_t readbufsize = 1024; - unsigned char readbuf[readbufsize]; - memset(readbuf, 0, sizeof(unsigned char) * readbufsize); - size_t readbuflen = 0; + const size_t port_str_size = 6; + char port_str[port_str_size]; /* - * 1. Set the ctx->host and ctx->port + * 3. Disable and Reenable connection */ - memcpy(ctx->host, host, strlen(host)); // assume null terminated, example: "root.atsign.org" - ctx->port = port; // example: 64 + if (atclient_connection_is_connection_enabled(ctx)) { + atclient_connection_disable_connection(ctx); + } - char portstr[6]; - sprintf(portstr, "%d", ctx->port); + atclient_connection_enable_connection(ctx); /* - * 2. Parse CA certs + * 3. Parse CA certs */ - ret = mbedtls_x509_crt_parse(&(ctx->cacert), (unsigned char *)cas_pem, cas_pem_len); - if (ret != 0) { + if ((ret = mbedtls_x509_crt_parse(&(ctx->cacert), (unsigned char *)cas_pem, cas_pem_len)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_x509_crt_parse failed with exit code: %d\n", ret); goto exit; } /* - * 3. Seed the random number generator + * 4. Seed the random number generator */ - - ret = mbedtls_ctr_drbg_seed(&(ctx->ctr_drbg), mbedtls_entropy_func, &(ctx->entropy), - (unsigned char *)ATCHOPS_RNG_PERSONALIZATION, strlen(ATCHOPS_RNG_PERSONALIZATION)); - if (ret != 0) { + if ((ret = mbedtls_ctr_drbg_seed(&(ctx->ctr_drbg), mbedtls_entropy_func, &(ctx->entropy), + (unsigned char *)ATCHOPS_RNG_PERSONALIZATION, + strlen(ATCHOPS_RNG_PERSONALIZATION))) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ctr_drbg_seed failed with exit code: %d\n", ret); goto exit; } /* - * 4. Start the socket connection + * 5. Start the socket connection */ - ret = mbedtls_net_connect(&(ctx->net), host, portstr, MBEDTLS_NET_PROTO_TCP); - if (ret != 0) { + snprintf(port_str, port_str_size, "%d", port); + if ((ret = mbedtls_net_connect(&(ctx->net), host, port_str, MBEDTLS_NET_PROTO_TCP)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_net_connect failed with exit code: %d\n", ret); goto exit; } /* - * 5. Prepare the SSL connection + * 6. Prepare the SSL connection */ - ret = mbedtls_ssl_config_defaults(&(ctx->ssl_config), MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, - MBEDTLS_SSL_PRESET_DEFAULT); - if (ret != 0) { + if ((ret = mbedtls_ssl_config_defaults(&(ctx->ssl_config), MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_config_defaults failed with exit code: %d\n", ret); goto exit; } @@ -134,14 +146,12 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons mbedtls_ssl_conf_read_timeout(&(ctx->ssl_config), ATCLIENT_CLIENT_READ_TIMEOUT_MS); // recv will timeout after X seconds - ret = mbedtls_ssl_setup(&(ctx->ssl), &(ctx->ssl_config)); - if (ret != 0) { + if ((ret = mbedtls_ssl_setup(&(ctx->ssl), &(ctx->ssl_config))) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_setup failed with exit code: %d\n", ret); goto exit; } - ret = mbedtls_ssl_set_hostname(&(ctx->ssl), host); - if (ret != 0) { + if ((ret = mbedtls_ssl_set_hostname(&(ctx->ssl), host)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_set_hostname failed with exit code: %d\n", ret); goto exit; } @@ -149,10 +159,9 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons mbedtls_ssl_set_bio(&(ctx->ssl), &(ctx->net), mbedtls_net_send, NULL, mbedtls_net_recv_timeout); /* - * 6. Perform the SSL handshake + * 7. Perform the SSL handshake */ - ret = mbedtls_ssl_handshake(&(ctx->ssl)); - if (ret != 0) { + if ((ret = mbedtls_ssl_handshake(&(ctx->ssl))) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_handshake failed with exit code: %d\n", ret); goto exit; } @@ -160,8 +169,7 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons /* * 7. Verify the server certificate */ - ret = mbedtls_ssl_get_verify_result(&(ctx->ssl)); - if (ret != 0) { + if ((ret = mbedtls_ssl_get_verify_result(&(ctx->ssl))) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_get_verify_result failed with exit code: %d\n", ret); goto exit; } @@ -171,59 +179,167 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons // =============== // read anything that was already sent - ret = mbedtls_ssl_read(&(ctx->ssl), readbuf, readbufsize); - if (ret < 0) { + if ((ret = mbedtls_ssl_read(&(ctx->ssl), recv, recv_size)) <= 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret); goto exit; } // press enter - ret = mbedtls_ssl_write(&(ctx->ssl), (const unsigned char *)"\r\n", 2); - if (ret < 0) { + if ((ret = mbedtls_ssl_write(&(ctx->ssl), (const unsigned char *)"\r\n", strlen("\r\n"))) <= 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); goto exit; } // read anything that was sent - ret = mbedtls_ssl_read(&(ctx->ssl), readbuf, readbufsize); - if (ret < 0) { + memset(recv, 0, sizeof(unsigned char) * recv_size); + if ((ret = mbedtls_ssl_read(&(ctx->ssl), recv, recv_size)) <= 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret); goto exit; } // now we are guaranteed a blank canvas - if (ret > 0) { - ret = 0; // a positive exit code is not an error + if((ret = atclient_connection_set_host(ctx, host)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_connection_set_host failed with exit code: %d\n", ret); + goto exit; + } + + if((ret = atclient_connection_set_port(ctx, port)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_connection_set_port failed with exit code: %d\n", ret); + goto exit; } + + ret = 0; goto exit; exit: { if (ret != 0) { - // undo what we set - memset(ctx->host, 0, ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE); - ctx->port = -1; + atclient_connection_disable_connection(ctx); } return ret; } } +int atclient_connection_write(atclient_connection *ctx, const unsigned char *value, const size_t value_len) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + goto exit; + } + + if (value == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "value is NULL\n"); + goto exit; + } + + if (value_len == 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "value_len is 0\n"); + goto exit; + } + + if (!atclient_connection_is_connection_enabled(ctx)) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is not enabled\n"); + goto exit; + } + + /* + * 2. Write the value + */ + if ((ret = mbedtls_ssl_write(&(ctx->ssl), value, value_len)) <= 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); + goto exit; + } + + /* + * 3. Print debug log + */ + if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG) { + unsigned char *valuecopy = malloc(sizeof(unsigned char) * value_len); + if (valuecopy != NULL) { + memcpy(valuecopy, value, value_len); + atlogger_fix_stdout_buffer((char *)valuecopy, value_len); + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sSENT: %s\"%.*s\"%s\n", BBLU, HCYN, value_len, valuecopy, + reset); + free(valuecopy); + } else { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, + "Failed to allocate memory to pretty print the network sent transmission\n"); + } + } + + /* + * 4. Call hooks, if they exist + */ + bool try_hooks = atclient_connection_hooks_is_enabled(ctx) && !ctx->hooks->_is_nested_call; + if (try_hooks && ctx->hooks->post_write != NULL) { + ctx->hooks->_is_nested_call = true; + atclient_connection_hook_params params; + params.src = (unsigned char *)value; + params.src_len = value_len; + params.recv = NULL; + params.recv_size = 0; + params.recv_len = NULL; + ret = ctx->hooks->post_write(¶ms); + if (ctx->hooks != NULL) { + ctx->hooks->_is_nested_call = false; + } + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "post_recv hook failed with exit code: %d\n", ret); + goto exit; + } + } + + ret = 0; + goto exit; +exit: { return ret; } +} + int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_r, const size_t srclen_r, unsigned char *recv, const size_t recvsize_r, size_t *recvlen) { int ret = 1; + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + goto exit; + } + + if (src_r == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "src is NULL\n"); + goto exit; + } + + if (srclen_r == 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "srclen is 0\n"); + goto exit; + } + + if(!atclient_connection_is_connection_enabled(ctx)) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is not enabled\n"); + goto exit; + } + + /* + * 2. Prep hook stuff + */ // Clone readonly inputs so it is editable by the hooks size_t srclen = srclen_r; size_t recvsize = recvsize_r; - bool try_hooks = ctx->hooks != NULL && !ctx->hooks->_is_nested_call; + bool try_hooks = atclient_connection_hooks_is_enabled(ctx) && !ctx->hooks->_is_nested_call; bool allocate_src = try_hooks && ctx->hooks->readonly_src == false; - unsigned char *src; + unsigned char *src = NULL; if (allocate_src) { - src = malloc(sizeof(unsigned char) * srclen); - if (src == NULL) { + if ((src = malloc(sizeof(unsigned char) * srclen)) == NULL) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory for src\n"); allocate_src = false; // don't try to free since the memory failed to be allocated goto exit; @@ -233,16 +349,18 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ src = (unsigned char *)src_r; } - if (!ctx->_should_be_connected) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, - "ctx->_should_be_connected should be true, but is false. You are trying to send messages to a " - "non-connected connection.\n"); - goto exit; - } - - if (try_hooks && ctx->hooks->pre_send != NULL) { + /* + * 3. Call pre_send hook, if it exists + */ + if (try_hooks && atclient_connection_hooks_is_pre_write_initialized(ctx)) { ctx->hooks->_is_nested_call = true; - ret = ctx->hooks->pre_send(src, srclen, recv, recvsize, recvlen); + atclient_connection_hook_params params; + params.src = src; + params.src_len = srclen; + params.recv = recv; + params.recv_size = recvsize; + params.recv_len = recvlen; + ret = ctx->hooks->pre_write(¶ms); if (ctx->hooks != NULL) { ctx->hooks->_is_nested_call = false; } @@ -252,15 +370,25 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ } } - ret = mbedtls_ssl_write(&(ctx->ssl), src, srclen); - if (ret <= 0) { + /* + * 4. Write the value + */ + if ((ret = mbedtls_ssl_write(&(ctx->ssl), src, srclen)) <= 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); goto exit; } - if (try_hooks && ctx->hooks->post_send != NULL) { + /* + * 5. Call post_send hook, if it exists + */ + if (try_hooks && atclient_connection_hooks_is_post_write_initialized(ctx)) { ctx->hooks->_is_nested_call = true; - ret = ctx->hooks->post_send(src, srclen, recv, recvsize, recvlen); + atclient_connection_hook_params params; + params.src = src; + params.src_len = srclen; + params.recv = recv; + params.recv_size = recvsize; + ret = ctx->hooks->post_write(¶ms); if (ctx->hooks != NULL) { ctx->hooks->_is_nested_call = false; } @@ -270,10 +398,12 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ } } - unsigned char *srccopy; + /* + * 6. Print debug log + */ if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG && ret == srclen) { - srccopy = malloc(sizeof(unsigned char) * srclen); - if (srccopy != NULL) { + unsigned char *srccopy = NULL; + if ((srccopy = malloc(sizeof(unsigned char) * srclen)) != NULL) { memcpy(srccopy, src, srclen); atlogger_fix_stdout_buffer((char *)srccopy, srclen); atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sSENT: %s\"%.*s\"%s\n", BBLU, HCYN, strlen((char *)srccopy), @@ -285,16 +415,27 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ } } + /* + * 7. Exit if recv is NULL + */ if (recv == NULL) { ret = 0; goto exit; } + /* + * 8. Run pre read hook, if it exists + */ memset(recv, 0, sizeof(unsigned char) * recvsize); - - if (try_hooks && ctx->hooks->pre_recv != NULL) { + if (try_hooks && atclient_connection_hooks_is_pre_read_initialized(ctx)) { ctx->hooks->_is_nested_call = true; - ret = ctx->hooks->pre_recv(src, srclen, recv, recvsize, recvlen); + atclient_connection_hook_params params; + params.src = src; + params.src_len = srclen; + params.recv = recv; + params.recv_size = recvsize; + params.recv_len = recvlen; + ret = ctx->hooks->pre_read(¶ms); if (ctx->hooks != NULL) { ctx->hooks->_is_nested_call = false; } @@ -304,28 +445,19 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ } } + /* + * 9. Read the value + */ int tries = 0; bool found = false; size_t l = 0; do { - ret = mbedtls_ssl_read(&(ctx->ssl), recv + l, recvsize - l); - if (ret <= 0) { + if ((ret = mbedtls_ssl_read(&(ctx->ssl), recv + l, recvsize - l)) <= 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret); goto exit; } - if (ret == 0) { - tries++; - if (tries >= ATCLIENT_CONNECTION_MAX_READ_TRIES) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, - "mbedtls_ssl_read tried to read %d times and found nothing: %d\n", tries, ret); - ret = 1; - goto exit; - } - } l = l + ret; - for (int i = l; i >= l - ret && i >= 0; i--) { - // printf("i: %d c: %.2x\n", i, (unsigned char) *(recv + i)); if (*(recv + i) == '\n' || *(recv + i) == '\r') { *recvlen = i; found = true; @@ -335,16 +467,36 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ if (found) { break; } - } while (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == 0 || !found); - - // atlogger_fix_stdout_buffer((char *)recv, *recvlen); recv[*recvlen] = '\0'; // null terminate the string - unsigned char *recvcopy; + /* + * 10. Run post read hook, if it exists + */ + if (try_hooks && atclient_connection_hooks_is_post_read_initialized(ctx)) { + ctx->hooks->_is_nested_call = true; + atclient_connection_hook_params params; + params.src = src; + params.src_len = srclen; + params.recv = recv; + params.recv_size = recvsize; + params.recv_len = recvlen; + ret = ctx->hooks->post_read(¶ms); + if (ctx->hooks != NULL) { + ctx->hooks->_is_nested_call = false; + } + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "post_recv hook failed with exit code: %d\n", ret); + goto exit; + } + } + + /* + * 11. Print debug log + */ if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG) { - recvcopy = malloc(sizeof(unsigned char) * (*recvlen)); - if (recvcopy != NULL) { + unsigned char *recvcopy = NULL; + if ((recvcopy = malloc(sizeof(unsigned char) * (*recvlen))) != NULL) { memcpy(recvcopy, recv, *recvlen); atlogger_fix_stdout_buffer((char *)recvcopy, *recvlen); atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, *recvlen, recvcopy, @@ -356,18 +508,6 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_ } } - if (try_hooks && ctx->hooks->post_recv != NULL) { - ctx->hooks->_is_nested_call = true; - ret = ctx->hooks->post_recv(src, srclen, recv, recvsize, recvlen); - if (ctx->hooks != NULL) { - ctx->hooks->_is_nested_call = false; - } - if (ret != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "post_recv hook failed with exit code: %d\n", ret); - goto exit; - } - } - ret = 0; goto exit; exit: { @@ -381,29 +521,42 @@ exit: { int atclient_connection_disconnect(atclient_connection *ctx) { int ret = 1; - if (!ctx->_should_be_connected) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, - "ctx->_should_be_connected should be true, but is false, it was never connected in the first place!\n"); - goto exit; + /* + * 1. Validate arguments + */ + if(ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return ret; + } + + if(!atclient_connection_is_connection_enabled(ctx)) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is not enabled\n"); + return ret; } do { ret = mbedtls_ssl_close_notify(&(ctx->ssl)); } while (ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == MBEDTLS_ERR_SSL_WANT_READ || ret != 0); - free_contexts(ctx); - ctx->_should_be_connected = false; + atclient_connection_disable_connection(ctx); ret = 0; - goto exit; exit: { return ret; } } bool atclient_connection_is_connected(atclient_connection *ctx) { - if (!ctx->_should_be_connected) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx->_should_be_connected should be true, but is false\n"); + /* + * 1. Validate arguments + */ + if(ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL, of course it's not connected lol\n"); + return false; + } + + if(!atclient_connection_is_connection_enabled(ctx)) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is not enabled\n"); return false; } @@ -438,138 +591,261 @@ bool atclient_connection_is_connected(atclient_connection *ctx) { return true; } -void atclient_connection_free(atclient_connection *ctx) { - if (ctx->_should_be_connected) { - free_contexts(ctx); +int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len, + const size_t value_max_len) { + int ret = 1; +} + +static void my_debug(void *ctx, int level, const char *file, int line, const char *str) { + ((void)level); + fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str); + fflush((FILE *)ctx); +} + +static void atclient_connection_set_is_connection_enabled(atclient_connection *ctx, const bool should_be_connected) { + ctx->_is_connection_enabled = should_be_connected; +} + +static bool atclient_connection_is_connection_enabled(const atclient_connection *ctx) { + return ctx->_is_connection_enabled; +} + +static void atclient_connection_enable_connection(atclient_connection *ctx) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return; } - if (ctx->hooks != NULL) { - free(ctx->hooks); + + /* + * 2. Disable connection, if necessary + */ + if (atclient_connection_is_connection_enabled(ctx)) { + atclient_connection_disable_connection(ctx); } - memset(ctx, 0, sizeof(atclient_connection)); - memset(ctx->host, 0, ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE); - ctx->port = -1; - ctx->_should_be_connected = false; + + /* + * 3. Enable the connection + */ + mbedtls_net_init(&(ctx->net)); + mbedtls_ssl_init(&(ctx->ssl)); + mbedtls_ssl_config_init(&(ctx->ssl_config)); + mbedtls_x509_crt_init(&(ctx->cacert)); + mbedtls_entropy_init(&(ctx->entropy)); + mbedtls_ctr_drbg_init(&(ctx->ctr_drbg)); + + /* + * 4. Set the connection enabled flag + */ + atclient_connection_set_is_connection_enabled(ctx, true); } -void atclient_connection_enable_hooks(atclient_connection *ctx) { - ctx->hooks = malloc(sizeof(atclient_connection_hooks)); // TODO handle malloc failure - memset(ctx->hooks, 0, sizeof(atclient_connection_hooks)); - ctx->hooks->readonly_src = true; - ctx->_is_hooks_enabled = true; +static void atclient_connection_disable_connection(atclient_connection *ctx) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return; + } + + /* + * 2. Free the contexts + */ + if (atclient_connection_is_connection_enabled(ctx)) { + mbedtls_net_free(&(ctx->net)); + mbedtls_ssl_free(&(ctx->ssl)); + mbedtls_ssl_config_free(&(ctx->ssl_config)); + mbedtls_x509_crt_free(&(ctx->cacert)); + mbedtls_entropy_free(&(ctx->entropy)); + mbedtls_ctr_drbg_free(&(ctx->ctr_drbg)); + } + + /* + * 3. Set the connection disabled flag + */ + atclient_connection_set_is_connection_enabled(ctx, false); } -// Q. Why is hook a void pointer? -// A. In case we want to add future hook types which use a different function signature -int atclient_connection_hooks_set(atclient_connection *ctx, atclient_connection_hook_type type, void *hook) { - atclient_connection_hooks *hooks = ctx->hooks; - if (hooks == NULL || !ctx->_is_hooks_enabled) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, - "Make sure to enable hooks struct before trying to set a hook\n"); - return -1; - } - - switch (type) { - case ATCLIENT_CONNECTION_HOOK_TYPE_NONE: - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Received 'NONE' hook as hook set input type\n"); - return 1; - case ATCLIENT_CONNECTION_HOOK_TYPE_PRE_SEND: - hooks->pre_send = (atclient_connection_send_hook *)hook; - break; - case ATCLIENT_CONNECTION_HOOK_TYPE_POST_SEND: - hooks->post_send = (atclient_connection_send_hook *)hook; - break; - case ATCLIENT_CONNECTION_HOOK_TYPE_PRE_RECV: - hooks->pre_recv = (atclient_connection_send_hook *)hook; - break; - case ATCLIENT_CONNECTION_HOOK_TYPE_POST_RECV: - hooks->post_recv = (atclient_connection_send_hook *)hook; - break; - } - - return 0; +static void atclient_connection_set_is_host_initialized(atclient_connection *ctx, const bool is_host_initialized) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return; + } + + /* + * 2. Set the host initialized flag + */ + ctx->_is_host_initialized = is_host_initialized; } -void atclient_connection_hooks_set_readonly_src(atclient_connection *ctx, bool readonly_src) { - if (ctx->hooks == NULL) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, - "Make sure to enable hooks struct before trying to set readonly_src\n"); +static bool atclient_connection_is_host_initialized(const atclient_connection *ctx) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return false; + } + + /* + * 2. Return the host initialized flag + */ + return ctx->_is_host_initialized; +} + +static int atclient_connection_set_host(atclient_connection *ctx, const char *host) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return ret; + } + + if (host == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "host is NULL\n"); + return ret; + } + + /* + * 2. Allocate memory for the host + */ + const size_t host_len = strlen(host); + const size_t host_size = host_len + 1; + if ((ctx->host = malloc(sizeof(char) * host_size)) == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory for host\n"); + goto exit; + } + + /* + * 3. Copy the host + */ + memcpy(ctx->host, host, host_len); + ctx->host[host_len] = '\0'; + + /* + * 4. Set the host initialized flag + */ + atclient_connection_set_is_host_initialized(ctx, true); + + ret = 0; + goto exit; +exit: { return ret; } +} + +static void atclient_connection_unset_host(atclient_connection *ctx) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return; + } + + /* + * 2. Free the host + */ + if (atclient_connection_is_host_initialized(ctx)) { + free(ctx->host); + } + ctx->host = NULL; + + /* + * 3. Unset the host initialized flag + */ + atclient_connection_set_is_host_initialized(ctx, false); +} + +static void atclient_connection_set_is_port_initialized(atclient_connection *ctx, const bool is_port_initialized) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); return; } - ctx->hooks->readonly_src = readonly_src; + + /* + * 2. Set the port initialized flag + */ + ctx->_is_port_initialized = is_port_initialized; } -// int atclient_connection_write(atclient_connection *ctx, const unsigned char *value, const size_t value_len) { -// int ret = 1; - -// /* -// * 1. Validate arguments -// */ -// if(ctx == NULL) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); -// goto exit; -// } - -// if(value == NULL) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "value is NULL\n"); -// goto exit; -// } - -// if(value_len == 0) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "value_len is 0\n"); -// goto exit; -// } - -// if (!ctx->_should_be_connected) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, -// "ctx->_should_be_connected should be true, but is false, you are trying to write to a non-connected " -// "connection\n"); -// goto exit; -// } - -// /* -// * 2. Write the value -// */ -// if ((ret = mbedtls_ssl_write(&(ctx->ssl), value, value_len)) <= 0) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); -// goto exit; -// } - -// /* -// * 3. Print debug log -// */ -// if(atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG) { -// unsigned char *valuecopy = malloc(sizeof(unsigned char) * value_len); -// if (valuecopy != NULL) { -// memcpy(valuecopy, value, value_len); -// atlogger_fix_stdout_buffer((char *)valuecopy, value_len); -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sSENT: %s\"%.*s\"%s\n", BBLU, HCYN, value_len, valuecopy, reset); -// free(valuecopy); -// } else { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory to pretty print the network sent transmission\n"); -// } -// } - -// /* -// * 4. Call hooks, if they exist -// */ -// bool try_hooks = ctx->hooks != NULL && !ctx->hooks->_is_nested_call; -// if (try_hooks && ctx->hooks->post_recv != NULL) { -// ctx->hooks->_is_nested_call = true; -// ret = ctx->hooks->post_recv(src, srclen, recv, recvsize, recvlen) -// if (ctx->hooks != NULL) { -// ctx->hooks->_is_nested_call = false; -// } -// if (ret != 0) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "post_recv hook failed with exit code: %d\n", ret); -// goto exit; -// } -// } - -// ret = 0; -// goto exit; -// exit: { -// return ret; -// } -// } - -// int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len, const size_t value_max_len); \ No newline at end of file +static bool atclient_connection_is_port_initialized(const atclient_connection *ctx) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return false; + } + + /* + * 2. Return the port initialized flag + */ + return ctx->_is_port_initialized; +} + +static int atclient_connection_set_port(atclient_connection *ctx, const uint16_t port) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return ret; + } + + if (port < 0) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "port is less than 0\n"); + return ret; + } + + /* + * 2. Set the port + */ + ctx->port = port; + + /* + * 3. Set the port initialized flag + */ + atclient_connection_set_is_port_initialized(ctx, true); + + ret = 0; + goto exit; +exit: { return ret; } +} + +static void atclient_connection_unset_port(atclient_connection *ctx) { + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx is NULL\n"); + return; + } + + /* + * 2. Unset the port + */ + ctx->port = 0; + + /* + * 3. Unset the port initialized flag + */ + atclient_connection_set_is_port_initialized(ctx, false); +} diff --git a/packages/atclient/src/connection_hooks.c b/packages/atclient/src/connection_hooks.c new file mode 100644 index 00000000..64e7b2e0 --- /dev/null +++ b/packages/atclient/src/connection_hooks.c @@ -0,0 +1,393 @@ +#include "atclient/connection.h" +#include +#include +#include +#include "atclient/connection_hooks.h" +#include + +#define TAG "connection_hooks" + +static void atclient_connection_hooks_set_is_enabled(atclient_connection *ctx, const bool enabled); + +static void atclient_connection_hooks_set_is_pre_read_initialized(atclient_connection *ctx, const bool initialized); +static int atclient_connection_hooks_set_pre_read(atclient_connection *ctx, atclient_connection_hook *hook); +static void atclient_connection_hooks_unset_pre_read(atclient_connection *ctx); + +static void atclient_connection_hooks_set_is_post_read_initialized(atclient_connection *ctx, const bool initialized); +static int atclient_connection_hooks_set_post_read(atclient_connection *ctx, atclient_connection_hook *hook); +static void atclient_connection_hooks_unset_post_read(atclient_connection *ctx); + +static void atclient_connection_hooks_set_is_pre_write_initialized(atclient_connection *ctx, const bool initialized); +static int atclient_connection_hooks_set_pre_write(atclient_connection *ctx, atclient_connection_hook *hook); +static void atclient_connection_hooks_unset_pre_write(atclient_connection *ctx); + +static void atclient_connection_hooks_set_is_post_write_initialized(atclient_connection *ctx, const bool initialized); +static int atclient_connection_hooks_set_post_write(atclient_connection *ctx, atclient_connection_hook *hook); +static void atclient_connection_hooks_unset_post_write(atclient_connection *ctx); + +bool atclient_connection_hooks_is_enabled(atclient_connection *ctx) { + if (ctx->hooks == NULL) { + return false; + } + return ctx->_is_hooks_enabled; +} + +int atclient_connection_hooks_enable(atclient_connection *conn) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (conn == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return ret; + } + + /* + * 2. Disable hooks if they are already enabled + */ + if (atclient_connection_hooks_is_enabled(conn)) { + atclient_connection_hooks_disable(conn); + } + + /* + * 3. Allocate memory for the hooks struct + */ + if ((conn->hooks = malloc(sizeof(atclient_connection_hooks))) == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory for connection hooks\n"); + goto exit; + } + memset(conn->hooks, 0, sizeof(atclient_connection_hooks)); + atclient_connection_hooks_set_is_enabled(conn, true); + + /* + * 4. Set any defaults + */ + conn->hooks->readonly_src = true; + + ret = 0; + goto exit; +exit: { return ret; } +} + +void atclient_connection_hooks_disable(atclient_connection *conn) { + /* + * 1. Validate arguments + */ + if (conn == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return; + } + + /* + * 2. Free the hooks struct + */ + if (atclient_connection_hooks_is_enabled(conn)) { + free(conn->hooks); + } + atclient_connection_hooks_set_is_enabled(conn, false); + conn->hooks = NULL; +} + +int atclient_connection_hooks_set(atclient_connection *ctx, const atclient_connection_hook_type type, void *hook) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return ret; + } + + if (type == ATCLIENT_CONNECTION_HOOK_TYPE_NONE) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Received 'NONE' hook as hook set input type\n"); + return ret; + } + + if (hook == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Hook is NULL\n"); + return ret; + } + + if (!atclient_connection_hooks_is_enabled(ctx)) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Make sure to enable hooks struct before trying to set a hook\n"); + return ret; + } + + /* + * 2. Set the hook + */ + switch (type) { + case ATCLIENT_CONNECTION_HOOK_TYPE_NONE: + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Received 'NONE' hook as hook set input type\n"); + goto exit; + case ATCLIENT_CONNECTION_HOOK_TYPE_PRE_READ: { + if((ret = atclient_connection_hooks_set_pre_read(ctx, (atclient_connection_hook *)hook)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set pre read hook\n"); + goto exit; + } + break; + } + case ATCLIENT_CONNECTION_HOOK_TYPE_POST_READ: { + if((ret = atclient_connection_hooks_set_post_read(ctx, (atclient_connection_hook *)hook)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set post read hook\n"); + goto exit; + } + break; + } + case ATCLIENT_CONNECTION_HOOK_TYPE_PRE_WRITE: { + if((ret = atclient_connection_hooks_set_pre_write(ctx, (atclient_connection_hook *)hook)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set pre write hook\n"); + goto exit; + } + break; + } + case ATCLIENT_CONNECTION_HOOK_TYPE_POST_WRITE: { + if((ret = atclient_connection_hooks_set_post_write(ctx, (atclient_connection_hook *)hook)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set post write hook\n"); + goto exit; + } + break; + } + } + + ret = 0; + goto exit; +exit: { return ret; } +} + +bool atclient_connection_hooks_is_pre_read_initialized(const atclient_connection *ctx) { + return ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_PRE_READ_INDEX] & + ATCLIENT_CONNECTION_HOOKS_PRE_READ_INITIALIZED; +} + +bool atclient_connection_hooks_is_post_read_initialized(const atclient_connection *ctx) { + return ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_POST_READ_INDEX] & + ATCLIENT_CONNECTION_HOOKS_POST_READ_INITIALIZED; +} + +bool atclient_connection_hooks_is_pre_write_initialized(const atclient_connection *ctx) { + return ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INDEX] & + ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INITIALIZED; +} + +bool atclient_connection_hooks_is_post_write_initialized(const atclient_connection *ctx) { + return ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INDEX] & + ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INITIALIZED; +} + +static void atclient_connection_hooks_set_is_enabled(atclient_connection *ctx, const bool enabled) { + ctx->_is_hooks_enabled = enabled; +} + +static void atclient_connection_hooks_set_is_pre_read_initialized(atclient_connection *ctx, const bool initialized) { + if (initialized) { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_PRE_READ_INDEX] |= + ATCLIENT_CONNECTION_HOOKS_PRE_READ_INITIALIZED; + } else { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_PRE_READ_INDEX] &= + ~ATCLIENT_CONNECTION_HOOKS_PRE_READ_INITIALIZED; + } +} + +static int atclient_connection_hooks_set_pre_read(atclient_connection *ctx, atclient_connection_hook *hook) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return ret; + } + + if (!atclient_connection_hooks_is_enabled(ctx)) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Make sure to enable hooks struct before trying to set a hook\n"); + return ret; + } + + if (hook == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Hook is NULL\n"); + return ret; + } + + if (atclient_connection_hooks_is_pre_read_initialized(ctx)) { + atclient_connection_hooks_unset_pre_read(ctx); + } + + ctx->hooks->pre_read = hook; + atclient_connection_hooks_set_is_pre_read_initialized(ctx, true); + + ret = 0; + goto exit; +exit: { return ret; } +} + +static void atclient_connection_hooks_unset_pre_read(atclient_connection *ctx) { + ctx->hooks->pre_read = NULL; + atclient_connection_hooks_set_is_pre_read_initialized(ctx, false); +} + +static void atclient_connection_hooks_set_is_post_read_initialized(atclient_connection *ctx, const bool initialized) { + if (initialized) { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_POST_READ_INDEX] |= + ATCLIENT_CONNECTION_HOOKS_POST_READ_INITIALIZED; + } else { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_POST_READ_INDEX] &= + ~ATCLIENT_CONNECTION_HOOKS_POST_READ_INITIALIZED; + } +} + +static int atclient_connection_hooks_set_post_read(atclient_connection *ctx, atclient_connection_hook *hook) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return ret; + } + + if (!atclient_connection_hooks_is_enabled(ctx)) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Make sure to enable hooks struct before trying to set a hook\n"); + return ret; + } + + if (hook == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Hook is NULL\n"); + return ret; + } + + if (atclient_connection_hooks_is_post_read_initialized(ctx)) { + atclient_connection_hooks_unset_post_read(ctx); + } + + ctx->hooks->post_read = hook; + atclient_connection_hooks_set_is_post_read_initialized(ctx, true); + + ret = 0; + goto exit; +exit: { return ret; } +} + +static void atclient_connection_hooks_unset_post_read(atclient_connection *ctx) { + ctx->hooks->post_read = NULL; + atclient_connection_hooks_set_is_post_read_initialized(ctx, false); +} + +static void atclient_connection_hooks_set_is_pre_write_initialized(atclient_connection *ctx, const bool initialized) { + if (initialized) { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INDEX] |= + ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INITIALIZED; + } else { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INDEX] &= + ~ATCLIENT_CONNECTION_HOOKS_PRE_WRITE_INITIALIZED; + } +} + +static int atclient_connection_hooks_set_pre_write(atclient_connection *ctx, atclient_connection_hook *hook) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return ret; + } + + if (!atclient_connection_hooks_is_enabled(ctx)) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Make sure to enable hooks struct before trying to set a hook\n"); + return ret; + } + + if (hook == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Hook is NULL\n"); + return ret; + } + + if (atclient_connection_hooks_is_pre_write_initialized(ctx)) { + atclient_connection_hooks_unset_pre_write(ctx); + } + + ctx->hooks->pre_write = hook; + atclient_connection_hooks_set_is_pre_write_initialized(ctx, true); + + ret = 0; + goto exit; +exit: { return ret; } +} + +static void atclient_connection_hooks_unset_pre_write(atclient_connection *ctx) { + ctx->hooks->pre_write = NULL; + atclient_connection_hooks_set_is_pre_write_initialized(ctx, false); +} + +static void atclient_connection_hooks_set_is_post_write_initialized(atclient_connection *ctx, const bool initialized) { + if (initialized) { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INDEX] |= + ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INITIALIZED; + } else { + ctx->hooks->_initialized_fields[ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INDEX] &= + ~ATCLIENT_CONNECTION_HOOKS_POST_WRITE_INITIALIZED; + } +} + +static int atclient_connection_hooks_set_post_write(atclient_connection *ctx, atclient_connection_hook *hook) { + int ret = 1; + + /* + * 1. Validate arguments + */ + if (ctx == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Connection is NULL\n"); + return ret; + } + + if (!atclient_connection_hooks_is_enabled(ctx)) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Make sure to enable hooks struct before trying to set a hook\n"); + return ret; + } + + if (hook == NULL) { + ret = 1; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Hook is NULL\n"); + return ret; + } + + if (atclient_connection_hooks_is_post_write_initialized(ctx)) { + atclient_connection_hooks_unset_post_write(ctx); + } + + ctx->hooks->post_write = hook; + atclient_connection_hooks_set_is_post_write_initialized(ctx, true); + + ret = 0; + goto exit; +exit: { return ret; } +} + +static void atclient_connection_hooks_unset_post_write(atclient_connection *ctx) { + ctx->hooks->post_write = NULL; + atclient_connection_hooks_set_is_post_write_initialized(ctx, false); +} \ No newline at end of file diff --git a/tests/functional_tests/tests/test_atclient_connection.c b/tests/functional_tests/tests/test_atclient_connection.c index e72581df..fde2f855 100644 --- a/tests/functional_tests/tests/test_atclient_connection.c +++ b/tests/functional_tests/tests/test_atclient_connection.c @@ -142,7 +142,7 @@ static int test_1_initialize(atclient_connection *conn) { atclient_connection_init(conn, ATCLIENT_CONNECTION_TYPE_ATDIRECTORY); - if ((ret = assert_equals(conn->_should_be_connected, false)) != 0) { + if ((ret = assert_equals(conn->_is_connection_enabled, false)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "conn->_should_be_connected should be false, but is true\n"); goto exit; } @@ -166,7 +166,7 @@ static int test_2_connect(atclient_connection *conn) { goto exit; } - if ((ret = assert_equals(conn->_should_be_connected, true)) != 0) { + if ((ret = assert_equals(conn->_is_connection_enabled, true)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "root_conn._should_be_connected should be true, but is false\n"); goto exit; } @@ -234,7 +234,7 @@ static int test_5_disconnect(atclient_connection *conn) { goto exit; } - if ((ret = assert_equals(conn->_should_be_connected, false)) != 0) { + if ((ret = assert_equals(conn->_is_connection_enabled, false)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "root_conn._should_be_connected should be false, but is true\n"); goto exit; } @@ -308,7 +308,7 @@ static int test_8_reconnect(atclient_connection *conn) { goto exit; } - if (!conn->_should_be_connected) { + if (!conn->_is_connection_enabled) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "ctx->_should_be_connected should be true, but is false\n"); ret = 1; goto exit; @@ -345,7 +345,7 @@ static int test_10_free(atclient_connection *conn) { atclient_connection_free(conn); - if ((ret = assert_equals(conn->_should_be_connected, false)) != 0) { + if ((ret = assert_equals(conn->_is_connection_enabled, false)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "conn->_should_be_connected should be false, but is true\n"); goto exit; } @@ -363,7 +363,7 @@ static int test_11_initialize(atclient_connection *conn) { atclient_connection_init(conn, ATCLIENT_CONNECTION_TYPE_ATDIRECTORY); - if ((ret = assert_equals(conn->_should_be_connected, false)) != 0) { + if ((ret = assert_equals(conn->_is_connection_enabled, false)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "conn->_should_be_connected should be false, but is true\n"); goto exit; } @@ -491,7 +491,7 @@ static int test_17_should_be_connected_should_be_true(atclient_connection *conn) int ret = 1; - if ((ret = assert_equals(conn->_should_be_connected, true)) != 0) { + if ((ret = assert_equals(conn->_is_connection_enabled, true)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "conn->_should_be_connected should be true, but is false\n"); goto exit; }