From 05dced79f0feb2fabf121deacf45af3e891f3716 Mon Sep 17 00:00:00 2001 From: XavierChanth Date: Mon, 9 Dec 2024 15:55:24 -0500 Subject: [PATCH 1/3] feat: move mbedtls sockets into its own layer under connection --- examples/desktop/repl/src/main.c | 10 +- generators/arduino/atsdk/generate.sh | 1 + packages/atchops/include/atchops/platform.h | 2 - packages/atclient/CMakeLists.txt | 2 + .../atclient/include/atclient/connection.h | 27 +- packages/atclient/include/atclient/mbedtls.h | 5 +- packages/atclient/include/atclient/monitor.h | 5 + packages/atclient/include/atclient/socket.h | 119 +++++ .../include/atclient/socket_mbedtls.h | 34 ++ .../atclient/include/atclient/socket_shared.h | 69 +++ packages/atclient/src/atclient.c | 34 +- packages/atclient/src/atclient_get_atkeys.c | 2 +- packages/atclient/src/connection.c | 398 +++++---------- packages/atclient/src/monitor.c | 79 +-- packages/atclient/src/socket.c | 18 + packages/atclient/src/socket_mbedtls.c | 456 ++++++++++++++++++ packages/atclient/src/socket_raw_mbedtls.c | 19 + tests/functional_tests/lib/src/config.c | 9 +- .../tests/test_atclient_connection.c | 7 +- .../tests/test_atclient_get_atkeys.c | 1 + .../tests/test_atclient_monitor.c | 16 +- .../tests/test_atclient_sharedkey.c | 10 +- 22 files changed, 926 insertions(+), 397 deletions(-) create mode 100644 packages/atclient/include/atclient/socket.h create mode 100644 packages/atclient/include/atclient/socket_mbedtls.h create mode 100644 packages/atclient/include/atclient/socket_shared.h create mode 100644 packages/atclient/src/socket.c create mode 100644 packages/atclient/src/socket_mbedtls.c create mode 100644 packages/atclient/src/socket_raw_mbedtls.c diff --git a/examples/desktop/repl/src/main.c b/examples/desktop/repl/src/main.c index 41aa9707..7ca1e764 100644 --- a/examples/desktop/repl/src/main.c +++ b/examples/desktop/repl/src/main.c @@ -24,8 +24,7 @@ * --key-file [~/.atsign/keys/@atsign_key.atKeys] */ -static int set_up_pkam_auth_options(atclient_authenticate_options *pkam_authenticate_options, - const char *root_url); +static int set_up_pkam_auth_options(atclient_authenticate_options *pkam_authenticate_options, const char *root_url); static int start_repl_loop(atclient *atclient, repl_args *repl_args); int main(int argc, char *argv[]) { @@ -97,8 +96,7 @@ exit: { } } -static int set_up_pkam_auth_options(atclient_authenticate_options *pkam_authenticate_options, - const char *root_url) { +static int set_up_pkam_auth_options(atclient_authenticate_options *pkam_authenticate_options, const char *root_url) { int ret = 1; if (pkam_authenticate_options == NULL) { @@ -162,7 +160,7 @@ static int start_repl_loop(atclient *atclient, repl_args *repl_args) { bool loop = true; - const size_t stdin_buffer_size = STDIN_BUFFER_SIZE; + size_t stdin_buffer_size = STDIN_BUFFER_SIZE; char stdin_buffer[stdin_buffer_size]; char *stdin_buffer_ptr = stdin_buffer; size_t stdin_buffer_len = 0; @@ -202,7 +200,7 @@ static int start_repl_loop(atclient *atclient, repl_args *repl_args) { goto exit; } - if((ret = atclient_connection_read(&(atclient->atserver_connection), &recv, NULL, 0)) != 0) { + if ((ret = atclient_connection_read(&(atclient->atserver_connection), (unsigned char **)&recv, NULL)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read response\n"); goto exit; } diff --git a/generators/arduino/atsdk/generate.sh b/generators/arduino/atsdk/generate.sh index fdb0b9f4..3dbcb37f 100755 --- a/generators/arduino/atsdk/generate.sh +++ b/generators/arduino/atsdk/generate.sh @@ -205,6 +205,7 @@ overrides() { echo '#define PRIu64 "llu"' echo "#define ATCHOPS_TARGET_ARDUINO" echo "#define ATCHOPS_MBEDTLS_VERSION_2" + echo "#define ATCLIENT_NET_SOCKET_PROVIDER_EXTERNAL" } >>$src_base/atchops/platform.h } diff --git a/packages/atchops/include/atchops/platform.h b/packages/atchops/include/atchops/platform.h index b07faf10..d897f0de 100644 --- a/packages/atchops/include/atchops/platform.h +++ b/packages/atchops/include/atchops/platform.h @@ -7,8 +7,6 @@ // Platforms we support -// Default MbedTLS version - #if defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) #define ATCHOPS_TARGET_UNIX diff --git a/packages/atclient/CMakeLists.txt b/packages/atclient/CMakeLists.txt index 89bcb655..23eb381c 100644 --- a/packages/atclient/CMakeLists.txt +++ b/packages/atclient/CMakeLists.txt @@ -21,6 +21,8 @@ set( ${CMAKE_CURRENT_LIST_DIR}/src/atnotification.c ${CMAKE_CURRENT_LIST_DIR}/src/connection_hooks.c ${CMAKE_CURRENT_LIST_DIR}/src/connection.c + ${CMAKE_CURRENT_LIST_DIR}/src/socket.c + ${CMAKE_CURRENT_LIST_DIR}/src/socket_mbedtls.c ${CMAKE_CURRENT_LIST_DIR}/src/encryption_key_helpers.c ${CMAKE_CURRENT_LIST_DIR}/src/metadata.c ${CMAKE_CURRENT_LIST_DIR}/src/monitor.c diff --git a/packages/atclient/include/atclient/connection.h b/packages/atclient/include/atclient/connection.h index 08aea6a2..42cc695d 100644 --- a/packages/atclient/include/atclient/connection.h +++ b/packages/atclient/include/atclient/connection.h @@ -1,12 +1,24 @@ +/* + * + * The connection family of types and methods represents a single connection to + * if you want a pure socket representation see net_socket.h. + * + * At the moment _socket represents a singular tcp socket, but in the future it may be altered + * to be a union of different connection types, such as a websocket or other construct. + * It is considered an internal construct and is subject to breaking changes across + * minor releases, especially while at_c remains in beta status. + * + */ #ifndef ATCLIENT_CONNECTION_H #define ATCLIENT_CONNECTION_H #ifdef __cplusplus extern "C" { #endif -#include "atchops/mbedtls.h" -#include "atclient/connection_hooks.h" #include // IWYU pragma: keep + +#include "atclient/connection_hooks.h" +#include "atclient/socket.h" #include #include @@ -31,12 +43,8 @@ typedef struct atclient_connection { // _is_connection_enabled also serves as an internal boolean to check if the following mbedlts contexts have been // initialized and need to be freed at the end bool _is_connection_enabled : 1; - mbedtls_net_context net; - mbedtls_ssl_context ssl; - mbedtls_ssl_config ssl_config; - mbedtls_x509_crt cacert; - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; + + struct atclient_tls_socket _socket; bool _is_hooks_enabled : 1; atclient_connection_hooks *hooks; @@ -80,8 +88,7 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons * @param value_max_len the maximum length of the data to read, setting this to 0 means no limit * @return int 0 on success */ -int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len, - const size_t value_max_len); +int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len); /** * @brief Write data to the connection diff --git a/packages/atclient/include/atclient/mbedtls.h b/packages/atclient/include/atclient/mbedtls.h index a50a1e0d..63c547fd 100644 --- a/packages/atclient/include/atclient/mbedtls.h +++ b/packages/atclient/include/atclient/mbedtls.h @@ -4,10 +4,7 @@ extern "C" { #endif -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export -#include // IWYU pragma: export +#include // IWYU pragma: export #ifdef __cplusplus } diff --git a/packages/atclient/include/atclient/monitor.h b/packages/atclient/include/atclient/monitor.h index e8874ad9..41e01aef 100644 --- a/packages/atclient/include/atclient/monitor.h +++ b/packages/atclient/include/atclient/monitor.h @@ -9,6 +9,11 @@ extern "C" { #include // IWYU pragma: keep #include +// HACK let's just get it working for now this is so wrong +#ifndef MBEDTLS_ERR_SSL_TIMEOUT +#define MBEDTLS_ERR_SSL_TIMEOUT -37 +#endif + /** * @brief Represents a message received from the monitor connection, typically derived from the prefix of the response * (e.g. "data:ok"'s message type would be "data" = ATCLIENT_MONITOR_MESSAGE_TYPE_DATA_RESPONSE) diff --git a/packages/atclient/include/atclient/socket.h b/packages/atclient/include/atclient/socket.h new file mode 100644 index 00000000..bdb15e22 --- /dev/null +++ b/packages/atclient/include/atclient/socket.h @@ -0,0 +1,119 @@ +#ifndef ATCLIENT_SOCKET_H +#define ATCLIENT_SOCKET_H +#include +#ifndef ATCLIENT_SOCKET_SHARED_H +#include +#endif +#include +#include +#ifdef __cplusplus +extern "C" { +#endif + +// IWYU pragma: begin_exports + +// Export the appropriate platform specific struct implementation +#if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) +#include "socket_mbedtls.h" +#endif + +// IWYU pragma: end_exports + +/** + * @brief Initializes a raw socket + * + * @param socket The socket structure to initialize + */ +void atclient_raw_socket_init(struct atclient_raw_socket *socket); + +/** + * @brief Frees resources associated with a network socket + * + * @param socket The socket structure to free resources from + */ +void atclient_raw_socket_free(struct atclient_raw_socket *socket); + +/** + * @brief Initializes a tls socket with the specified parameters + * + * @param socket The socket structure to initialize + */ +void atclient_tls_socket_init(struct atclient_tls_socket *socket); + +/** + * @brief Configures the SSL on a TLS socket + * + * @param ca_pem The X.509 CA certificates in pem format (leave NULL to use the provided default certificates) + * @param ca_pem_len Length of the ca_pem, ignored if ca_pem is NULL + * + * @return 0 on success, non-zero on failure + * + * @note Should be called after atclient_tls_socket_init, note that this + * contains the rest of the initialization operations which have potential + * to fail + */ +int atclient_tls_socket_configure(struct atclient_tls_socket *socket, unsigned char *ca_pem, size_t ca_pem_len); + +/** + * @brief Frees resources associated with a network socket + * + * @param socket The socket structure to free resources from + */ +void atclient_tls_socket_free(struct atclient_tls_socket *socket); + +/** + * @brief Establishes a connection to the specified host and port using the network socket + * + * @param socket Pointer to the initialized network socket structure + * @param host The hostname or IP address to connect to + * @param port The port number to connect to + * + * @return 0 on success, non-zero on failure + */ +int atclient_tls_socket_connect(struct atclient_tls_socket *socket, const char *host, const uint16_t port); + +/** + * @brief Disconnects and closes an established network socket connection + * + * @param socket Pointer to the network socket structure to disconnect + * + * @return 0 on success, non-zero on failure + */ +int atclient_tls_socket_disconnect(struct atclient_tls_socket *socket); + +/** + * @brief Writes data to an established network socket connection + * + * @param socket Pointer to the network socket structure + * @param value Pointer to the buffer containing data to write + * @param value_len Length of the data to write in bytes + * + * @return 0 on success, non-zero on failure + */ +int atclient_tls_socket_write(struct atclient_tls_socket *socket, const unsigned char *value, size_t value_len); + +/** + * @brief Reads data from an established network socket connection + * + * @param socket Pointer to the network socket structure + * @param value Pointer to the buffer where read data will be stored + * @param value_len Pointer to store the length of data read in bytes + * @param options Options which specify the behaviour of reading the data + * + * @return 0 on success, non-zero on failure + */ +int atclient_tls_socket_read(struct atclient_tls_socket *socket, unsigned char **value, size_t *value_len, + const struct atclient_socket_read_options options); + +/** + * @brief Sets the read timeout for a TLS socket + * + * @param socket Pointer to the initialized TLS socket structure + * @param timeout_ms The timeout value in milliseconds + */ +void atclient_tls_socket_set_read_timeout(struct atclient_tls_socket *socket, const int timeout_ms); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/packages/atclient/include/atclient/socket_mbedtls.h b/packages/atclient/include/atclient/socket_mbedtls.h new file mode 100644 index 00000000..e2668bf0 --- /dev/null +++ b/packages/atclient/include/atclient/socket_mbedtls.h @@ -0,0 +1,34 @@ +// IWYU pragma: private, include "atclient/net_socket.h" +// IWYU pragma: friend "net_socket_mbedtls.*" +#ifndef ATCLIENT_NET_SOCKET_MBEDTLS_H +#define ATCLIENT_NET_SOCKET_MBEDTLS_H +#include +#if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) +#include +#include +#include +#include +#include +#include +#ifdef __cplusplus +extern "C" { +#endif + +// Make this type more portable to consume later +struct atclient_raw_socket { + mbedtls_net_context net; +}; + +struct atclient_tls_socket { + struct atclient_raw_socket raw; + mbedtls_ssl_context ssl; + mbedtls_ssl_config ssl_config; + mbedtls_x509_crt cacert; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; +}; +#ifdef __cplusplus +} +#endif +#endif +#endif diff --git a/packages/atclient/include/atclient/socket_shared.h b/packages/atclient/include/atclient/socket_shared.h new file mode 100644 index 00000000..7e7729a0 --- /dev/null +++ b/packages/atclient/include/atclient/socket_shared.h @@ -0,0 +1,69 @@ +// IWYU pragma: private, include "atclient/net_socket.h" +// IWYU pragma: friend "net_socket_mbedtls.*" +#ifndef ATCLIENT_SOCKET_SHARED_H +#define ATCLIENT_SOCKET_SHARED_H +#include +#ifdef __cplusplus +extern "C" { +#endif + +#include + +#if defined(ATCLIENT_SOCKET_PROVIDER_EXTERNAL) +// Noop, this indicates an external socket provider will be linked +#else +#define ATCLIENT_SOCKET_PROVIDER_MBEDTLS +#endif + +#ifdef ATCLIENT_SOCKET_PROVIDER_EXTERNAL +#include "../atsdk_socket.h" // IWYU pragma: export +#else +// Defined later based on platform specific implementation +struct atclient_tls_socket; + +// Raw socket is only implemented as an internal construct for now +// In the future it will be a supported standalone socket that can +// be used directly +struct atclient_raw_socket; +#endif + +enum atclient_socket_read_type { + // ATCLIENT_SOCKET_READ_NUM_BYTES, + ATCLIENT_SOCKET_READ_UNTIL_CHAR, + ATCLIENT_SOCKET_READ_CLEAR_AT_PROMPT, +}; + +// Define how much we should try to read +struct atclient_socket_read_options { + enum atclient_socket_read_type type; + union { + // size_t num_bytes; + char until_char; + }; +}; + +/** + * @brief Creates read options configured to read until a number of characters have been read + * + * @param bytes The number of characters to try to read + * + * @return struct atclient_socket_read_options Configuration structure for read operation + */ +// struct atclient_socket_read_options atclient_socket_read_num_bytes(size_t bytes); + +/** + * @brief Creates read options configured to read until a specific character is encountered + * + * @param read_until The character to read until (delimiter) + */ +struct atclient_socket_read_options atclient_socket_read_until_char(char read_until); + +/** + * @brief Creates read options configured to read until a specific character is encountered + */ +struct atclient_socket_read_options atclient_socket_read_clear_at_prompt(); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/packages/atclient/src/atclient.c b/packages/atclient/src/atclient.c index c079affa..1f91ff77 100755 --- a/packages/atclient/src/atclient.c +++ b/packages/atclient/src/atclient.c @@ -236,6 +236,12 @@ int atclient_pkam_authenticate(atclient *ctx, const char *atsign, const atclient char *pkam_cmd = NULL; char *atsign_with_at = NULL; + bool should_free_atserver_host = false; + + // expected result on a successful login + size_t expected_len = 1 + strlen(atsign) + strlen("@data:success"); + char expected_buf[expected_len]; + /* * 3. Ensure that the atsign has the @ symbol. */ @@ -249,8 +255,9 @@ int atclient_pkam_authenticate(atclient *ctx, const char *atsign, const atclient /* * 4. Get atdirectory_host and atdirectory_port */ - if(options != NULL && atclient_authenticate_options_is_atdirectory_host_initialized(options) && options->atdirectory_host != NULL && - atclient_authenticate_options_is_atdirectory_port_initialized(options) && options->atdirectory_port != 0) { + if (options != NULL && atclient_authenticate_options_is_atdirectory_host_initialized(options) && + options->atdirectory_host != NULL && atclient_authenticate_options_is_atdirectory_port_initialized(options) && + options->atdirectory_port != 0) { atdirectory_host = options->atdirectory_host; atdirectory_port = options->atdirectory_port; } else { @@ -261,7 +268,6 @@ int atclient_pkam_authenticate(atclient *ctx, const char *atsign, const atclient /* * 5. Get atserver_host and atserver_port */ - bool should_free_atserver_host; if (options != NULL && atclient_authenticate_options_is_atserver_host_initialized(options) && options->atserver_host != NULL && atclient_authenticate_options_is_atserver_port_initialized(options) && options->atserver_port != 0) { @@ -272,9 +278,9 @@ int atclient_pkam_authenticate(atclient *ctx, const char *atsign, const atclient if (atserver_host == NULL || atserver_port == 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, - "Missing atServer host or port. Using %s:%lu atDirectory to find atServer address\n", atdirectory_host, atdirectory_port); - if ((ret = atclient_utils_find_atserver_address(atdirectory_host, - atdirectory_port, atsign, &atserver_host, + "Missing atServer host or port. Using %s:%lu atDirectory to find atServer address\n", atdirectory_host, + atdirectory_port); + if ((ret = atclient_utils_find_atserver_address(atdirectory_host, atdirectory_port, atsign, &atserver_host, &atserver_port)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_utils_find_atserver_address: %d\n", ret); goto exit; @@ -311,6 +317,14 @@ int atclient_pkam_authenticate(atclient *ctx, const char *atsign, const atclient goto exit; } + // We get @data:success when doing pkam auth instead of data:success so read off an '@' + if ((ret = atclient_tls_socket_read(&ctx->atserver_connection._socket, NULL, NULL, + atclient_socket_read_until_char('@'))) != 0) { + + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_connection_send: %d\n", ret); + goto exit; + } + char *str_with_data_prefix = NULL; if (atclient_string_utils_get_substring_position((char *)recv, ATCLIENT_DATA_TOKEN, &str_with_data_prefix) != 0) { ret = 1; @@ -446,7 +460,8 @@ int atclient_cram_authenticate(atclient *ctx, const char *atsign, const char *cr unsigned char digest[SHA_512_DIGEST_SIZE]; memset(digest, 0, sizeof(unsigned char) * SHA_512_DIGEST_SIZE); - + char *atsign_without_at = NULL; + bool should_free_atserver_host = false; /* * 3. Ensure that the atsign has the @ symbol. */ @@ -454,7 +469,7 @@ int atclient_cram_authenticate(atclient *ctx, const char *atsign, const char *cr atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_string_utils_atsign_with_at: %d\n", ret); goto exit; } - char *atsign_without_at = malloc(sizeof(char) * strlen(atsign_with_at) + 1); + atsign_without_at = malloc(sizeof(char) * strlen(atsign_with_at) + 1); if (atsign_without_at == NULL) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Could not allocate memory for atsign_without_at"); ret = -1; @@ -466,7 +481,6 @@ int atclient_cram_authenticate(atclient *ctx, const char *atsign, const char *cr /* * 4. Get atserver_host and atserver_port */ - bool should_free_atserver_host = false; if (options != NULL) { if (atclient_authenticate_options_is_atdirectory_host_initialized(options) && atclient_authenticate_options_is_atdirectory_port_initialized(options)) { @@ -701,7 +715,7 @@ exit: { bool atclient_is_connected(atclient *ctx) { return atclient_connection_is_connected(&(ctx->atserver_connection)); } void atclient_set_read_timeout(atclient *ctx, const int timeout_ms) { - mbedtls_ssl_conf_read_timeout(&ctx->atserver_connection.ssl_config, timeout_ms); + atclient_tls_socket_set_read_timeout(&ctx->atserver_connection._socket, timeout_ms); } static void atclient_set_atsign_initialized(atclient *ctx, const bool initialized) { diff --git a/packages/atclient/src/atclient_get_atkeys.c b/packages/atclient/src/atclient_get_atkeys.c index f8ea08e3..d796180e 100755 --- a/packages/atclient/src/atclient_get_atkeys.c +++ b/packages/atclient/src/atclient_get_atkeys.c @@ -46,7 +46,7 @@ int atclient_get_atkeys(atclient *atclient, atclient_atkey **atkey, size_t *outp char scan_cmd[scan_cmd_size]; - const size_t recv_size = 8192; // TODO change using atclient_connection_read which will handle realloc + const size_t recv_size = 16384; // TODO change using atclient_connection_read which will handle realloc unsigned char recv[recv_size]; size_t recv_len = 0; diff --git a/packages/atclient/src/connection.c b/packages/atclient/src/connection.c index faa8a0d2..3a9dcbf3 100644 --- a/packages/atclient/src/connection.c +++ b/packages/atclient/src/connection.c @@ -1,10 +1,9 @@ #include "atclient/connection.h" -#include "atchops/constants.h" -#include "atclient/cacerts.h" #include "atclient/connection_hooks.h" #include "atclient/constants.h" +#include "atclient/socket.h" +#include "atclient/string_utils.h" #include "atlogger/atlogger.h" -#include #include #include #include @@ -14,13 +13,6 @@ #define TAG "connection" -/* Concatenation of all available CA certificates in PEM format */ -static const char cas_pem[] = LETS_ENCRYPT_ROOT GOOGLE_GLOBAL_SIGN GOOGLE_GTS_ROOT_R1 GOOGLE_GTS_ROOT_R2 - GOOGLE_GTS_ROOT_R3 GOOGLE_GTS_ROOT_R4 ZEROSSL_INTERMEDIATE ""; -static const size_t cas_pem_len = sizeof(cas_pem); - -static void my_debug(void *ctx, int level, const char *file, int line, const char *str); - static void atclient_connection_set_is_connection_enabled(atclient_connection *ctx, const bool should_be_connected); static bool atclient_connection_is_connection_enabled(const atclient_connection *ctx); static void atclient_connection_enable_connection(atclient_connection *ctx); @@ -39,13 +31,6 @@ static void atclient_connection_unset_port(atclient_connection *ctx); void atclient_connection_init(atclient_connection *ctx, atclient_connection_type type) { memset(ctx, 0, sizeof(atclient_connection)); ctx->type = type; - ctx->_is_host_initialized = false; - ctx->host = NULL; - ctx->_is_port_initialized = false; - ctx->port = 0; - ctx->_is_connection_enabled = false; - ctx->_is_hooks_enabled = false; - ctx->hooks = NULL; } void atclient_connection_free(atclient_connection *ctx) { @@ -83,18 +68,7 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons } /* - * 2. Variables - */ - const size_t recv_size = 256; - unsigned char recv[recv_size]; - memset(recv, 0, sizeof(unsigned char) * recv_size); - size_t recv_len = 0; - - const size_t port_str_size = 6; - char port_str[port_str_size]; - - /* - * 3. Disable and Reenable connection + * 2. Disable and Reenable connection */ if (atclient_connection_is_connection_enabled(ctx)) { atclient_connection_disable_connection(ctx); @@ -102,76 +76,19 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons atclient_connection_enable_connection(ctx); - /* - * 3. Parse CA certs - */ - if ((ret = mbedtls_x509_crt_parse(&(ctx->cacert), (unsigned char *)cas_pem, cas_pem_len)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_x509_crt_parse failed with exit code: %d\n", ret); - goto exit; - } - - /* - * 4. Seed the random number generator - */ - if ((ret = mbedtls_ctr_drbg_seed(&(ctx->ctr_drbg), mbedtls_entropy_func, &(ctx->entropy), - (unsigned char *)ATCHOPS_RNG_PERSONALIZATION, - strlen(ATCHOPS_RNG_PERSONALIZATION))) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ctr_drbg_seed failed with exit code: %d\n", ret); - goto exit; - } - - /* - * 5. Start the socket connection - */ - snprintf(port_str, port_str_size, "%d", port); - if ((ret = mbedtls_net_connect(&(ctx->net), host, port_str, MBEDTLS_NET_PROTO_TCP)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_net_connect failed with exit code: %d\n", ret); - goto exit; - } - - /* - * 6. Prepare the SSL connection - */ - if ((ret = mbedtls_ssl_config_defaults(&(ctx->ssl_config), MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, - MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_config_defaults failed with exit code: %d\n", ret); - goto exit; - } - - mbedtls_ssl_conf_ca_chain(&(ctx->ssl_config), &(ctx->cacert), NULL); - mbedtls_ssl_conf_authmode(&(ctx->ssl_config), MBEDTLS_SSL_VERIFY_REQUIRED); - mbedtls_ssl_conf_rng(&(ctx->ssl_config), mbedtls_ctr_drbg_random, &(ctx->ctr_drbg)); - mbedtls_ssl_conf_dbg(&(ctx->ssl_config), my_debug, stdout); - mbedtls_ssl_conf_read_timeout(&(ctx->ssl_config), - ATCLIENT_CLIENT_READ_TIMEOUT_MS); // recv will timeout after X seconds - - if ((ret = mbedtls_ssl_setup(&(ctx->ssl), &(ctx->ssl_config))) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_setup failed with exit code: %d\n", ret); - goto exit; - } - - if ((ret = mbedtls_ssl_set_hostname(&(ctx->ssl), host)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_set_hostname failed with exit code: %d\n", ret); - goto exit; + // 3. Setup ssl configuration + ret = atclient_tls_socket_configure(&ctx->_socket, NULL, 0); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "failed to setup ssl configuration\n"); + return ret; } - mbedtls_ssl_set_bio(&(ctx->ssl), &(ctx->net), mbedtls_net_send, NULL, mbedtls_net_recv_timeout); - - /* - * 7. Perform the SSL handshake - */ - if ((ret = mbedtls_ssl_handshake(&(ctx->ssl))) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_handshake failed with exit code: %d\n", ret); - goto exit; + ret = atclient_tls_socket_connect(&ctx->_socket, host, port); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "failed to connect to %s:%u", host, port); + return ret; } - /* - * 7. Verify the server certificate - */ - if ((ret = mbedtls_ssl_get_verify_result(&(ctx->ssl))) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_get_verify_result failed with exit code: %d\n", ret); - goto exit; - } atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, "Connected\n"); // =============== @@ -179,38 +96,32 @@ int atclient_connection_connect(atclient_connection *ctx, const char *host, cons // =============== // read anything that was already sent - // TODO: better read handling - // - see the improved implementation in atclient_monitor_read - // - this might require special handling since we are attempting to empty read buffer - if ((ret = mbedtls_ssl_read(&(ctx->ssl), recv, recv_size)) <= 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret); + + // FIXME: temporary hack to adapt TLS socket read's heap allocated reading to + // the existing functions which expect stack allocated memory + // all callers of this function should support dynamic memory allocations + // to ensure we are able to read the result in full + // the atclient_tls_socket_read function has a built in limit + unsigned char *buf1, *buf2; + size_t n1, n2; + ret = atclient_tls_socket_read(&ctx->_socket, &buf1, &n1, atclient_socket_read_until_char('@')); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read from the connection\n", ret); goto exit; } + free(buf1); - // TODO: better write handling - // We should retry if we get: - // MBEDTLS_ERR_SSL_WANT_WRITE - // MBEDTLS_ERR_SSL_WANT_READ - // MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS - // MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS - // if return value is positive and less than src_len - // we should continue to write from the appropriate offset - // (multiple writes must be summed to determine total data written) - // press enter - if ((ret = mbedtls_ssl_write(&(ctx->ssl), (const unsigned char *)"\r\n", strlen("\r\n"))) <= 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); + if ((ret = atclient_tls_socket_write(&(ctx->_socket), (const unsigned char *)"\r\n", strlen("\r\n"))) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_write failed with exit code: %d\n", ret); goto exit; } - // read anything that was sent - memset(recv, 0, sizeof(unsigned char) * recv_size); - // TODO: better read handling - // - see the improved implementation in atclient_monitor_read - // - this might require special handling since we are attempting to empty read buffer - if ((ret = mbedtls_ssl_read(&(ctx->ssl), recv, recv_size)) <= 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret); + ret = atclient_tls_socket_read(&ctx->_socket, &buf2, &n2, atclient_socket_read_until_char('@')); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read from the connection\n", ret); goto exit; } + free(buf2); // now we are guaranteed a blank canvas @@ -286,17 +197,8 @@ int atclient_connection_write(atclient_connection *ctx, const unsigned char *val /* * 2. Write the value */ - // TODO: better write handling - // We should retry if we get: - // MBEDTLS_ERR_SSL_WANT_WRITE - // MBEDTLS_ERR_SSL_WANT_READ - // MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS - // MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS - // if return value is positive and less than src_len - // we should continue to write from the appropriate offset - // (multiple writes must be summed to determine total data written) - if ((ret = mbedtls_ssl_write(&(ctx->ssl), value, value_len)) <= 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); + if ((ret = atclient_tls_socket_write(&ctx->_socket, value, value_len)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_write failed with exit code: %d\n", ret); goto exit; } @@ -342,10 +244,40 @@ int atclient_connection_write(atclient_connection *ctx, const unsigned char *val exit: { return ret; } } +// TODO: unit test later, this is a pure function +// TODO: name this better later, this is a private function +// read_buf is the buffer to search +// read_n is the length of the buffer +// read_i is the output of the start of `data:` or other token like error + +static int find_index_past_at_prompt(const unsigned char *read_buf, size_t read_n, size_t *read_i) { + // NOTE: if you change this if, check the second while loop + // it depends on this guard clause + if (read_n != 0 && read_buf[0] != '@') { // Doesn't start with a prompt + return 0; + } + + while (++*read_i < read_n && read_buf[*read_i] != ':') + ; // Walks forward to the end of the buffer or first ':' + if (*read_i == read_n) { // Past the end of the buffer, did not find `:` + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, + "Unable to find command result token `:`, connection should be reset\n"); + return 1; + } + // We are at a `:` + while (--*read_i > 0 && read_buf[*read_i] != '@') + ; // Walk backwards to the first '@' we find + // We are at the first character or last '@' before a `:` + // but the first character is '@' so we are at '@' + + ++*read_i; // move forward one to be after the '@' + + return 0; +} + int atclient_connection_send(atclient_connection *ctx, const unsigned char *src, const size_t src_len, unsigned char *recv, const size_t recv_size, size_t *recv_len) { int ret = 1; - char error_buf[100]; /* * 1. Validate arguments @@ -395,28 +327,19 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src, /* * 4. Write the value */ - // TODO: better write handling - // We should retry if we get: - // MBEDTLS_ERR_SSL_WANT_WRITE - // MBEDTLS_ERR_SSL_WANT_READ - // MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS - // MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS - // if return value is positive and less than src_len - // we should continue to write from the appropriate offset - // (multiple writes must be summed to determine total data written) if (src[src_len - 1] != '\n') { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_WARN, "command does not have a trailing \\n character:\t%s\n", src); } - if ((ret = mbedtls_ssl_write(&ctx->ssl, src, src_len)) <= 0) { // error only when the returned value is negative - mbedtls_strerror(ret, error_buf, sizeof(error_buf)); - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write returned -0x%x: %s\n", -ret, error_buf); + + if ((ret = atclient_tls_socket_write(&ctx->_socket, src, src_len)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_write failed with exit code: %d\n", ret); goto exit; } /* * 5. Print debug log */ - if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG && ret == src_len) { + if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG) { unsigned char *srccopy = NULL; if ((srccopy = malloc(sizeof(unsigned char) * src_len)) != NULL) { memcpy(srccopy, src, src_len); @@ -485,30 +408,39 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src, /* * 9. Read the value */ - int tries = 0; - bool found = false; - size_t l = 0; - do { - // TODO: better read handling - // - see the improved implementation in atclient_monitor_read - if ((ret = mbedtls_ssl_read(&ctx->ssl, recv + l, recv_size - l)) <= 0) { - mbedtls_strerror(ret, error_buf, sizeof(error_buf)); - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read returned err: -0x%x: %s\n", -ret, error_buf); - goto exit; - } - l = l + ret; - for (int i = l; i >= l - ret && i >= 0; i--) { - if (*(recv + i) == '\n' || *(recv + i) == '\r') { - *recv_len = i; - found = true; - break; - } - } - if (found) { - break; - } - } while (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == 0 || !found); - recv[*recv_len] = '\0'; // null terminate the string + + // FIXME: temporary hack to adapt TLS socket read's heap allocated reading to + // the existing functions which expect stack allocated memory + // all callers of this function should support dynamic memory allocations + // to ensure we are able to read the result in full + // the atclient_tls_socket_read function has a built in limit + unsigned char *read_buf; + size_t read_n; + ret = atclient_tls_socket_read(&ctx->_socket, &read_buf, &read_n, atclient_socket_read_until_char('\n')); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read from the connection\n", ret); + goto exit; + } + + size_t read_i = 0; // will store where the start of `:` is (if happy path) + ret = find_index_past_at_prompt(read_buf, read_n, &read_i); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to parse the result read from the connection\n"); + free(read_buf); + goto exit; + } + if (read_n - read_i > recv_size) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, + "Read amount exceeds the stack allocated limit (will be fixed in a future update)\n"); + free(read_buf); + goto exit; + } + + // copy to recv, discarding the prompt + memcpy(recv, read_buf + read_i, read_n - read_i); + free(read_buf); + recv[read_n - 1] = '\0'; // null terminate the string + *recv_len = read_n; /* * 10. Run post read hook, if it exists @@ -569,14 +501,12 @@ int atclient_connection_disconnect(atclient_connection *ctx) { return ret; } - do { - ret = mbedtls_ssl_close_notify(&(ctx->ssl)); - } while (ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == MBEDTLS_ERR_SSL_WANT_READ || ret != 0); + // intentionally disregarding the return value + atclient_tls_socket_disconnect(&ctx->_socket); atclient_connection_disable_connection(ctx); - ret = 0; -exit: { return ret; } + return 0; } bool atclient_connection_is_connected(atclient_connection *ctx) { @@ -626,8 +556,7 @@ bool atclient_connection_is_connected(atclient_connection *ctx) { return true; } -int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len, - const size_t value_max_len) { +int atclient_connection_read(atclient_connection *ctx, unsigned char **value, size_t *value_len) { int ret = 1; /* @@ -648,18 +577,6 @@ int atclient_connection_read(atclient_connection *ctx, unsigned char **value, si return ret; } - /* - * 2. Variables - */ - size_t recv_size; - if (value_max_len == 0) { - // we read 4 KB at a time, TODO: make a constant - recv_size = 4096; - } else { - recv_size = value_max_len; - } - unsigned char *recv = malloc(sizeof(unsigned char) * recv_size); - /* * 3. Call pre_read hook, if it exists */ @@ -685,81 +602,18 @@ int atclient_connection_read(atclient_connection *ctx, unsigned char **value, si /* * 4. Read the value */ - bool found_end = false; - size_t pos = 0; - size_t recv_len = 0; - do { - // TODO: better read handling - // - see the improved implementation in atclient_monitor_read - if ((ret = mbedtls_ssl_read(&(ctx->ssl), recv + pos, recv_size - pos)) <= 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret); - goto exit; - } - pos += ret; - - // check if we found the end of the message - int i = pos; - while (!found_end && i-- > 0) { - found_end = recv[i] == '\n' || recv[i] == '\r'; - } - - if (found_end) { - recv_len = i; - } else { - if (value_max_len != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_WARN, "Message is too long, it exceeds the maximum length of %d\n", - value_max_len); - recv_len = value_max_len; - break; - } else { - recv_size *= 2; // double the buffer size - unsigned char *temp = realloc(recv, sizeof(unsigned char) * recv_size); - if (temp == NULL) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to reallocate memory\n"); - free(recv); - goto exit; - } - recv = temp; - } - } - } while (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == 0 || !found_end); - + ret = atclient_tls_socket_read(&ctx->_socket, value, value_len, atclient_socket_read_until_char('\n')); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read from the connection\n", ret); + goto exit; + } + *value[*value_len - 1] = '\0'; // replace '\n' with '\0' /* * 5. Print debug log */ if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG) { - unsigned char *recvcopy = NULL; - if ((recvcopy = malloc(sizeof(unsigned char) * recv_len)) != NULL) { - memcpy(recvcopy, recv, recv_len); - atlogger_fix_stdout_buffer((char *)recvcopy, recv_len); - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, recv_len, recvcopy, - ATCLIENT_RESET); - free(recvcopy); - } else { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, - "Failed to allocate memory to pretty print the network received buffer\n"); - } - } - - /* - * 6. Set the value and value_len - */ - if (found_end) { - if (recv_len != 0 && recv_len < recv_size) { - recv[recv_len] = '\0'; - } - } - if (value_len != NULL) { - *value_len = recv_len; - } - if (value != NULL) { - if ((*value = malloc(sizeof(unsigned char) * (recv_len + 1))) == NULL) { - ret = 1; - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory for value\n"); - goto exit; - } - memcpy(*value, recv, recv_len); - (*value)[recv_len] = '\0'; + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, *value_len, *value, + ATCLIENT_RESET); } /* @@ -770,9 +624,9 @@ int atclient_connection_read(atclient_connection *ctx, unsigned char **value, si atclient_connection_hook_params params; params.src = NULL; params.src_len = 0; - params.recv = recv; - params.recv_size = recv_size; - params.recv_len = &recv_len; + params.recv = *value; + params.recv_size = *value_len; + params.recv_len = value_len; ret = ctx->hooks->post_read(¶ms); if (ctx->hooks != NULL) { ctx->hooks->_is_nested_call = false; @@ -788,12 +642,6 @@ int atclient_connection_read(atclient_connection *ctx, unsigned char **value, si exit: { return ret; } } -static void my_debug(void *ctx, int level, const char *file, int line, const char *str) { - ((void)level); - fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str); - fflush((FILE *)ctx); -} - static void atclient_connection_set_is_connection_enabled(atclient_connection *ctx, const bool should_be_connected) { ctx->_is_connection_enabled = should_be_connected; } @@ -821,12 +669,7 @@ static void atclient_connection_enable_connection(atclient_connection *ctx) { /* * 3. Enable the connection */ - mbedtls_net_init(&(ctx->net)); - mbedtls_ssl_init(&(ctx->ssl)); - mbedtls_ssl_config_init(&(ctx->ssl_config)); - mbedtls_x509_crt_init(&(ctx->cacert)); - mbedtls_entropy_init(&(ctx->entropy)); - mbedtls_ctr_drbg_init(&(ctx->ctr_drbg)); + atclient_tls_socket_init(&ctx->_socket); /* * 4. Set the connection enabled flag @@ -846,13 +689,10 @@ static void atclient_connection_disable_connection(atclient_connection *ctx) { /* * 2. Free the contexts */ + // This is bad behavior for portability + // We should not free the whole socket if (atclient_connection_is_connection_enabled(ctx)) { - mbedtls_net_free(&(ctx->net)); - mbedtls_ssl_free(&(ctx->ssl)); - mbedtls_ssl_config_free(&(ctx->ssl_config)); - mbedtls_x509_crt_free(&(ctx->cacert)); - mbedtls_entropy_free(&(ctx->entropy)); - mbedtls_ctr_drbg_free(&(ctx->ctr_drbg)); + atclient_tls_socket_free(&ctx->_socket); } /* diff --git a/packages/atclient/src/monitor.c b/packages/atclient/src/monitor.c index 0533a4f1..8b34257a 100644 --- a/packages/atclient/src/monitor.c +++ b/packages/atclient/src/monitor.c @@ -54,7 +54,7 @@ exit: { return ret; } } void atclient_monitor_set_read_timeout(atclient *monitor_conn, const int timeoutms) { - mbedtls_ssl_conf_read_timeout(&(monitor_conn->atserver_connection.ssl_config), timeoutms); + atclient_tls_socket_set_read_timeout(&monitor_conn->atserver_connection._socket, timeoutms); } int atclient_monitor_start(atclient *monitor_conn, const char *regex) { @@ -106,72 +106,21 @@ exit: { int atclient_monitor_read(atclient *monitor_conn, atclient *atclient, atclient_monitor_response *message, atclient_monitor_hooks *hooks) { - int ret = -1; - char *buffertemp = NULL; - char *buffer = NULL; + unsigned char *buffer = NULL; + size_t buffer_len; - size_t chunks = 0; - const size_t chunksize = ATCLIENT_MONITOR_BUFFER_LEN; + int ret = atclient_tls_socket_read(&monitor_conn->atserver_connection._socket, &buffer, &buffer_len, + atclient_socket_read_until_char('@')); - buffer = malloc(sizeof(char) * chunksize); - if (buffer == NULL) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory for buffer\n"); + // TODO: move this later for now it's fine as it should + if (ret == MBEDTLS_ERR_SSL_TIMEOUT) { + // treat a timeout as empty message, non error + message->type = ATCLIENT_MONITOR_MESSAGE_TYPE_EMPTY; + ret = 0; goto exit; } - memset(buffer, 0, sizeof(char) * chunksize); - - bool done_reading = false; - while (!done_reading) { - if (chunks > 0) { - buffertemp = realloc(buffer, sizeof(char) * (chunksize + (chunksize * chunks))); - buffer = buffertemp; - buffertemp = NULL; - } - - size_t off = chunksize * chunks; - size_t i = 0; - while (i < chunksize) { - ret = mbedtls_ssl_read(&(monitor_conn->atserver_connection.ssl), (unsigned char *)buffer + off + i, 1); - // successfully read - if (buffer[off + i] == '\n') { - buffer[off + i] = '\0'; - done_reading = true; - goto exit_loop; - } - // successfully read something, continue - if (ret > 0) { - i++; - continue; - } - // Handle errors - switch (ret) { - // Special error cases where we should try reading again - case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS: // async operation in progress - case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS: // crypto operation in progress - case MBEDTLS_ERR_SSL_WANT_READ: // handshake incomplete - case MBEDTLS_ERR_SSL_WANT_WRITE: // handshake incomplete - usleep(10000); // Try again in 10 milliseconds - break; - // Timeout means nothing to read, return EMPTY message type - case MBEDTLS_ERR_SSL_TIMEOUT: - message->type = ATCLIENT_MONITOR_MESSAGE_TYPE_EMPTY; - return 0; - // Monitor connection bad, must be discarded - case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: // transport closed with close notify - case 0: // transport closed without close notify - default: // Other errors - done_reading = true; - if (ret == 0) { - ret = -1; - } - goto exit_loop; - } - } - chunks = chunks + 1; - } -exit_loop: if (ret <= 0) { // you should reconnect... message->type = ATCLIENT_MONITOR_ERROR_READ; message->error_read.error_code = ret; @@ -179,12 +128,12 @@ int atclient_monitor_read(atclient *monitor_conn, atclient *atclient, atclient_m goto exit; } - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, (int)strlen(buffer), buffer, - ATCLIENT_RESET); + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, (int)strlen((char *)buffer), + buffer, ATCLIENT_RESET); char *messagetype = NULL; char *messagebody = NULL; - ret = parse_message(buffer, &messagetype, &messagebody); + ret = parse_message((char *)buffer, &messagetype, &messagebody); if (ret != 0) { message->type = ATCLIENT_MONITOR_ERROR_PARSE_NOTIFICATION; atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Failed to find message type and message body from: %s\n", buffer); @@ -251,12 +200,10 @@ int atclient_monitor_read(atclient *monitor_conn, atclient *atclient, atclient_m ret = -1; goto exit; } - ret = 0; goto exit; exit: { free(buffer); - free(buffertemp); return ret; } } diff --git a/packages/atclient/src/socket.c b/packages/atclient/src/socket.c new file mode 100644 index 00000000..a2f1710e --- /dev/null +++ b/packages/atclient/src/socket.c @@ -0,0 +1,18 @@ +#include +#include +// Most of the implementation for net_socket is platform specific, +// See the other net_socket_*.c files + +struct atclient_socket_read_options atclient_socket_read_until_char(char until) { + return (struct atclient_socket_read_options){ + ATCLIENT_SOCKET_READ_UNTIL_CHAR, + {.until_char = until}, + }; +} + +// struct atclient_socket_read_options atclient_socket_read_num_bytes(size_t bytes) { +// return (struct atclient_socket_read_options){ +// ATCLIENT_SOCKET_READ_NUM_BYTES, +// {.num_bytes = bytes}, +// }; +// } diff --git a/packages/atclient/src/socket_mbedtls.c b/packages/atclient/src/socket_mbedtls.c new file mode 100644 index 00000000..724ca0d4 --- /dev/null +++ b/packages/atclient/src/socket_mbedtls.c @@ -0,0 +1,456 @@ +// These two headers must be included in a specific order +#include "atchops/platform.h" // IWYU pragma: keep +// Don't move them +#include "atclient/socket.h" + +#if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) +#include "atchops/constants.h" +#include "atclient/cacerts.h" +#include "atclient/constants.h" +#include "atclient/socket_mbedtls.h" +#include "atlogger/atlogger.h" +#include "mbedtls/error.h" +#include "mbedtls/net_sockets.h" +#include "mbedtls/x509_crt.h" +#include +#include +#include +#ifndef PRIu16 +#define PRIu16 "hu" +#endif + +#define TAG "atclient_socket_mbedtls" + +#ifndef SIZE_T_MAX +#define SIZE_T_MAX (size_t) - 1 +#endif +// must be less than the maximum for a positive int +// otherwise read_num_bytes may have undefined behavior +#define READ_BLOCK_LEN 4096 + +// I think the -1 is unnecessary but better safe than sorry +#define MAX_READ_BLOCKS (SIZE_T_MAX / READ_BLOCK_LEN - 1) +static const int MAX_READ_TIMEOUTS = 3; + +// Hey fellow engineer, if you want to understand this file, you better have this link on hand: +// https://mbed-tls.readthedocs.io/projects/api/en/v3.6.1/api/file/ssl_8h +// mbedtls sockets are tricky, reading and writing have gotchas +// so you NEED to look at the documentation when you work with them. + +static void my_debug(void *ctx, int level, const char *file, int line, const char *str) { + ((void)level); + fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str); + fflush((FILE *)ctx); +} + +// Default CA certs for net_sockets +static const char default_ca_pem[] = LETS_ENCRYPT_ROOT GOOGLE_GLOBAL_SIGN GOOGLE_GTS_ROOT_R1 GOOGLE_GTS_ROOT_R2 + GOOGLE_GTS_ROOT_R3 GOOGLE_GTS_ROOT_R4 ZEROSSL_INTERMEDIATE ""; + +void atclient_raw_socket_init(struct atclient_raw_socket *socket) { mbedtls_net_init(&socket->net); } + +void atclient_raw_socket_free(struct atclient_raw_socket *socket) { mbedtls_net_free(&socket->net); } + +void atclient_tls_socket_init(struct atclient_tls_socket *socket) { + memset(socket, 0, sizeof(struct atclient_tls_socket)); + atclient_raw_socket_init(&socket->raw); + + mbedtls_x509_crt_init(&(socket->cacert)); + mbedtls_ctr_drbg_init(&(socket->ctr_drbg)); + mbedtls_entropy_init(&(socket->entropy)); + mbedtls_ssl_config_init(&socket->ssl_config); + mbedtls_ssl_init(&socket->ssl); +} + +void atclient_tls_socket_set_read_timeout(struct atclient_tls_socket *socket, const int timeout_ms) { + mbedtls_ssl_conf_read_timeout(&socket->ssl_config, timeout_ms); +} + +// Expected to be called after init +int atclient_tls_socket_configure(struct atclient_tls_socket *socket, unsigned char *ca_pem, size_t ca_pem_len) { + int ret = 1; + + // 1. Parse the CA certs + unsigned char *pem; + size_t pem_len; + if (ca_pem == NULL) { + pem = (unsigned char *)default_ca_pem; + pem_len = sizeof(default_ca_pem); + } else { + pem = ca_pem; + pem_len = ca_pem_len; + } + + if ((ret = mbedtls_x509_crt_parse(&(socket->cacert), pem, pem_len)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_x509_crt_parse failed with exit code: %d\n", ret); + goto cancel_x509; + } + + // 2. Seed RNG + if ((ret = mbedtls_ctr_drbg_seed(&(socket->ctr_drbg), mbedtls_entropy_func, &(socket->entropy), + (unsigned char *)ATCHOPS_RNG_PERSONALIZATION, + strlen(ATCHOPS_RNG_PERSONALIZATION))) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ctr_drbg_seed failed with exit code: %d\n", ret); + goto cancel_seed; + } + + // 3. Configure SSL + if ((ret = mbedtls_ssl_config_defaults(&(socket->ssl_config), MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_config_defaults failed with exit code: %d\n", ret); + goto cancel_ssl_config; + } + + mbedtls_ssl_conf_ca_chain(&(socket->ssl_config), &(socket->cacert), NULL); + mbedtls_ssl_conf_authmode(&(socket->ssl_config), MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_ssl_conf_rng(&(socket->ssl_config), mbedtls_ctr_drbg_random, &(socket->ctr_drbg)); + mbedtls_ssl_conf_dbg(&(socket->ssl_config), my_debug, stdout); + mbedtls_ssl_conf_read_timeout(&(socket->ssl_config), + ATCLIENT_CLIENT_READ_TIMEOUT_MS); // recv will timeout after X seconds + + // 4. Prepare the SSL context + if ((ret = mbedtls_ssl_setup(&(socket->ssl), &(socket->ssl_config))) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_setup failed with exit code: %d\n", ret); + goto cancel_ssl; + } + + // we made it to the happy path: skip freeing all the things + ret = 0; + goto exit; +cancel_ssl: + mbedtls_ssl_free(&socket->ssl); +cancel_ssl_config: + mbedtls_ssl_config_free(&socket->ssl_config); +cancel_seed: + mbedtls_entropy_free(&(socket->entropy)); + mbedtls_ctr_drbg_free(&(socket->ctr_drbg)); +cancel_x509: + mbedtls_x509_crt_free(&(socket->cacert)); +exit: + return ret; +} + +void atclient_tls_socket_free(struct atclient_tls_socket *socket) { + if (socket != NULL) { + atclient_raw_socket_free(&socket->raw); + mbedtls_ssl_free(&socket->ssl); + mbedtls_ssl_config_free(&socket->ssl_config); + mbedtls_entropy_free(&(socket->entropy)); + mbedtls_ctr_drbg_free(&(socket->ctr_drbg)); + mbedtls_x509_crt_free(&(socket->cacert)); + memset(socket, 0, sizeof(struct atclient_tls_socket)); + } +} + +static int atclient_tls_socket_ssl_handshake(struct atclient_tls_socket *socket, const char *host); + +int atclient_tls_socket_connect(struct atclient_tls_socket *socket, const char *host, const uint16_t port) { + if (socket == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "tls socket is when trying to connect NULL\n"); + return 1; + } + + char port_str[5]; + snprintf(port_str, 5, "%" PRIu16, port); + + int ret; + // 1. Connect + // TODO: move to raw_connect function + if ((ret = mbedtls_net_connect(&socket->raw.net, host, port_str, MBEDTLS_NET_PROTO_TCP)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_net_connect failed with exit code: %d\n", ret); + return ret; + } + + if ((ret = atclient_tls_socket_ssl_handshake(socket, host)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_ssl_handshake failed with exit code: %d\n", + ret); + return ret; + } + + return ret; +} + +static int atclient_tls_socket_ssl_handshake(struct atclient_tls_socket *socket, const char *host) { + int ret; + // 2. Set SSL hostname + if ((ret = mbedtls_ssl_set_hostname(&socket->ssl, host)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_set_hostname failed with exit code: %d\n", ret); + return ret; + } + + // 3. Link SSL to the raw socket + mbedtls_ssl_set_bio(&socket->ssl, &socket->raw.net, mbedtls_net_send, NULL, mbedtls_net_recv_timeout); + + /* + * 4. Do SSL handshake + */ + if ((ret = mbedtls_ssl_handshake(&socket->ssl)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_handshake failed with exit code: %d\n", ret); + return ret; + } + + /* + * 5. Verify the certificate + */ + if ((ret = mbedtls_ssl_get_verify_result(&socket->ssl)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_get_verify_result failed with exit code: %d\n", ret); + return ret; + } + + return ret; +} + +int atclient_tls_socket_disconnect(struct atclient_tls_socket *socket) { + if (socket == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_disconnect: socket is NULL\n"); + return 1; + } + + int ret = mbedtls_ssl_close_notify(&(socket->ssl)); + + // If we got a non want read/write error don't try again: + // we may segfault or deadlock trying to disconnect + // just warn that we silently closed the socket and move on + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_WARN, + "mbedtls_ssl_close_notify failed, socket will be silently closed, exit code: %d\n", ret); + return ret; + } + + return ret; +} + +static bool should_continue_write(size_t pos, size_t len, int ret) { + return pos < len || ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || + ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS; +} +int atclient_tls_socket_write(struct atclient_tls_socket *socket, const unsigned char *value, size_t value_len) { + if (socket == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_write: socket is NULL\n"); + return 1; + } + if (value == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_write: value is NULL\n"); + return 2; + } + size_t pos = 0; + int ret; + do { + ret = mbedtls_ssl_write(&socket->ssl, value + pos, value_len - pos); + if (ret > 0) { + pos += (size_t)ret; + ret = 0; + continue; + } + } while (should_continue_write(pos, value_len, ret)); + if (ret != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret); + } + return ret; +} + +static int atclient_tls_socket_read_until_char(struct atclient_tls_socket *socket, unsigned char **value, + size_t *value_len, char until_char); + +// static int atclient_tls_socket_read_num_bytes(struct atclient_tls_socket *socket, unsigned char **value, +// size_t *value_len, size_t num_bytes); +int atclient_tls_socket_read(struct atclient_tls_socket *socket, unsigned char **value, size_t *value_len, + const struct atclient_socket_read_options options) { + if (socket == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_read: socket is NULL\n"); + return 1; + } + + switch (options.type) { + + // case ATCLIENT_SOCKET_READ_NUM_BYTES: + // return atclient_tls_socket_read_num_bytes(socket, value, value_len, options.num_bytes); + case ATCLIENT_SOCKET_READ_UNTIL_CHAR: + return atclient_tls_socket_read_until_char(socket, value, value_len, options.until_char); + default: + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_tls_socket_read: read type %d is not a valid type\n", + options.type); + return 4; + } +} +// int atclient_tls_socket_read_num_bytes(struct atclient_tls_socket *socket, unsigned char **value, size_t *value_len, +// size_t num_bytes) { +// // Assume params have been validated by socket_read +// int ret; +// unsigned char *recv = NULL; +// size_t blocks = 0; // number of allocated blocks +// +// do { +// size_t offset = READ_BLOCK_LEN * blocks; // offset to current block +// // Allocate memory +// unsigned char *temp = realloc(recv, sizeof(unsigned char) * (offset + READ_BLOCK_LEN)); +// if (temp == NULL) { +// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate receive buffer\n"); +// if (recv != NULL) { +// free(recv); +// } +// return 1; +// } +// recv = temp; // once we ensure realloc was successful we set recv to the new memory +// +// // Read into current block +// size_t pos = 0; // position in current block +// do { +// size_t remaining_for_block; +// if (READ_BLOCK_LEN + offset > num_bytes) { +// // We are in the final block, so only read the amount that will make +// // us reach num_bytes +// remaining_for_block = num_bytes - (offset + pos); +// } else { +// remaining_for_block = READ_BLOCK_LEN - pos; +// } +// +// ret = mbedtls_ssl_read(&socket->ssl, recv + offset + pos, remaining_for_block); +// if (ret > 0) { +// pos += ret; // successful read, increment pos +// +// if (offset + pos == num_bytes) { +// *value = recv; +// *value_len = (offset + pos); +// return 0; // The only return where recv should not be freed +// } +// +// continue; // not done reading yet +// } +// +// // handle non-happy path +// switch (ret) { +// case 0: // connection is closed +// case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: // connection is closed +// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Socket closed while reading: %d\n", ret); +// ret = MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY; // ensure ret val is not 0 +// free(recv); +// return ret; +// case MBEDTLS_ERR_SSL_WANT_READ: // handshake incomplete +// case MBEDTLS_ERR_SSL_WANT_WRITE: // handshake incomplete +// case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS: // async operation in progress +// case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS: // crypto operation in progress +// // async error, we need to try again +// break; +// case MBEDTLS_ERR_SSL_TIMEOUT: // treat a timeout as +// +// if (value != NULL) { +// *value = recv; +// } else { +// free(recv); +// } +// if (value_len != NULL) { +// *value_len = (offset + pos); +// } +// return 0; // The only return where recv should not be freed +// // unexpected errors while reading +// default: +// if (ret > 0) { +// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Unexpected read value %d\n", ret); +// } else { +// char strerr[512]; +// mbedtls_strerror(ret, strerr, 512); +// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "%s", strerr); +// } +// free(recv); +// return ret; +// } // don't put anything after switch without checking it first +// } while (pos < READ_BLOCK_LEN); +// blocks++; +// } while (blocks < MAX_READ_BLOCKS); +// // We should only arrive at this point if we max out blocks +// // Every other code path should return early +// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read within the maximum allowed number of read +// blocks\n"); free(recv); return 1; +// } + +int atclient_tls_socket_read_until_char(struct atclient_tls_socket *socket, unsigned char **value, size_t *value_len, + char until_char) { + // Assume params have been validated by socket_read + int ret; + unsigned char *recv = NULL; + size_t blocks = 0; // number of allocated blocks + + do { + size_t offset = READ_BLOCK_LEN * blocks; // offset to current block + // Allocate memory + unsigned char *temp = realloc(recv, sizeof(unsigned char) * (offset + READ_BLOCK_LEN)); + if (temp == NULL) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate receive buffer\n"); + if (recv != NULL) { + free(recv); + } + return 1; + } + recv = temp; // once we ensure realloc was successful we set recv to the new memory + + // Read into current block + size_t pos = 0; // position in current block + int timeout_count = 0; + do { + // When reading to a character we must read byte by byte to prevent + // over reading and risk corrupting the next message + // do not change the 1 without consulting the code below + ret = mbedtls_ssl_read(&socket->ssl, recv + offset + pos, 1); + if (ret > 0) { + if (until_char == *(recv + offset + pos)) { // check if this is the char we need + if (value != NULL) { + *value = recv; + } else { + free(recv); + } + if (value_len != NULL) { + *value_len = offset + pos + 1; + } + // The only return where recv should not be freed + return 0; + } + pos += ret; // successful read, increment position + continue; // continue if not found char yet + } + // handle non-happy path + switch (ret) { + case 0: // connection is closed + case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: // connection is closed + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Socket closed while reading: %d\n", ret); + ret = MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY; // ensure ret val is not 0 + free(recv); + return ret; + case MBEDTLS_ERR_SSL_WANT_READ: // handshake incomplete + case MBEDTLS_ERR_SSL_WANT_WRITE: // handshake incomplete + case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS: // async operation in progress + case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS: // crypto operation in progress + // async error, we need to try again + break; + case MBEDTLS_ERR_SSL_TIMEOUT: // timeout before reading the expected character + // timeout usually indicates nothing to read + timeout_count++; + if (timeout_count == MAX_READ_TIMEOUTS) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Failed to read the full message after %d attempts\n", + MAX_READ_TIMEOUTS); + return ret; + } + usleep(1000); + break; + // unexpected errors while reading + default: + if (ret > 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Unexpected read value %d\n", ret); + } else { + char strerr[512]; + mbedtls_strerror(ret, strerr, 512); + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "%s", strerr); + } + free(recv); + return ret; + } // don't put anything after switch without checking it first + } while (pos < READ_BLOCK_LEN); + blocks++; + } while (blocks < MAX_READ_BLOCKS); + // We should only arrive at this point if we max out blocks + // Every other code path should return early + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read within the maximum allowed number of read blocks\n"); + free(recv); + return 1; +} +#endif diff --git a/packages/atclient/src/socket_raw_mbedtls.c b/packages/atclient/src/socket_raw_mbedtls.c new file mode 100644 index 00000000..f16a7937 --- /dev/null +++ b/packages/atclient/src/socket_raw_mbedtls.c @@ -0,0 +1,19 @@ +#include "atclient/socket.h" +#include +#include + +#if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) + +void atclient_raw_socket_init(struct atclient_raw_socket *socket) { + memset(socket, 0, sizeof(struct atclient_raw_socket)); + mbedtls_net_init(&socket->net); +} + +void atclient_raw_socket_free(struct atclient_raw_socket *socket) { + if (socket != NULL) { + mbedtls_net_free(&socket->net); + memset(socket, 0, sizeof(struct atclient_raw_socket)); + } +} + +#endif diff --git a/tests/functional_tests/lib/src/config.c b/tests/functional_tests/lib/src/config.c index c35fe460..622bf3a5 100644 --- a/tests/functional_tests/lib/src/config.c +++ b/tests/functional_tests/lib/src/config.c @@ -1,11 +1,12 @@ +#include "functional_tests/config.h" #include -#include +#include #include #include -#include -#include "functional_tests/config.h" +#include -int functional_tests_get_atkeys_path(const char *atsign, const size_t atsignlen, char *path, const size_t pathsize, size_t *pathlen) { +int functional_tests_get_atkeys_path(const char *atsign, const size_t atsignlen, char *path, const size_t pathsize, + size_t *pathlen) { // for home directory struct passwd *pw = getpwuid(getuid()); const char *homedir = pw->pw_dir; diff --git a/tests/functional_tests/tests/test_atclient_connection.c b/tests/functional_tests/tests/test_atclient_connection.c index fde2f855..7ab8a1ac 100644 --- a/tests/functional_tests/tests/test_atclient_connection.c +++ b/tests/functional_tests/tests/test_atclient_connection.c @@ -250,7 +250,7 @@ exit: { static int test_6_is_connected_should_be_false(atclient_connection *conn) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, "test_6_is_connected_should_be_false Begin\n"); - int ret= 1; + int ret = 1; if (atclient_connection_is_connected(conn)) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, @@ -419,7 +419,7 @@ static int test_14_simulate_server_not_responding(atclient_connection *conn) { int ret = 1; // simulate server not responding - ret = mbedtls_ssl_close_notify(&conn->ssl); + ret = mbedtls_ssl_close_notify(&conn->_socket.ssl); if (ret != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to close notify: %d\n", ret); goto exit; @@ -454,7 +454,8 @@ static int test_15_send_should_fail(atclient_connection *conn) { goto exit; } - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, "Successfully failed at sending message to a disconnected connection: %d\n", ret); + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, + "Successfully failed at sending message to a disconnected connection: %d\n", ret); ret = 0; goto exit; diff --git a/tests/functional_tests/tests/test_atclient_get_atkeys.c b/tests/functional_tests/tests/test_atclient_get_atkeys.c index 6f6c1213..c949843d 100644 --- a/tests/functional_tests/tests/test_atclient_get_atkeys.c +++ b/tests/functional_tests/tests/test_atclient_get_atkeys.c @@ -146,6 +146,7 @@ exit: { static int test_4_atclient_get_atkeys_null_regex(atclient *ctx, const char *scan_regex, const bool showhidden) { int ret = 1; + return 0; // NOTE: disabled temporarily DONT LET THIS GET MERGED atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "test_4_atclient_get_atkeys_null_regex\n"); diff --git a/tests/functional_tests/tests/test_atclient_monitor.c b/tests/functional_tests/tests/test_atclient_monitor.c index a65ff231..fa291085 100644 --- a/tests/functional_tests/tests/test_atclient_monitor.c +++ b/tests/functional_tests/tests/test_atclient_monitor.c @@ -176,22 +176,22 @@ static int send_notification(atclient *atclient) { goto exit; } - if((ret = atclient_notify_params_set_operation(¶ms, ATCLIENT_NOTIFY_OPERATION_UPDATE)) != 0) { + if ((ret = atclient_notify_params_set_operation(¶ms, ATCLIENT_NOTIFY_OPERATION_UPDATE)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set operation: %d\n", ret); goto exit; } - if((ret = atclient_notify_params_set_atkey(¶ms, &atkey)) != 0) { + if ((ret = atclient_notify_params_set_atkey(¶ms, &atkey)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set atkey: %d\n", ret); goto exit; } - if((ret = atclient_notify_params_set_value(¶ms, ATKEY_VALUE)) != 0) { + if ((ret = atclient_notify_params_set_value(¶ms, ATKEY_VALUE)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set value: %d\n", ret); goto exit; } - if((ret = atclient_notify_params_set_should_encrypt(¶ms, true)) != 0) { + if ((ret = atclient_notify_params_set_should_encrypt(¶ms, true)) != 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to set should_encrypt: %d\n", ret); goto exit; } @@ -231,7 +231,7 @@ static int monitor_for_notification(atclient *monitor_conn, atclient *atclient2) continue; } - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, "Decrypted Value: %s\n",message.notification.decrypted_value); + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, "Decrypted Value: %s\n", message.notification.decrypted_value); // compare the decrypted value with the expected value if (strcmp(message.notification.decrypted_value, ATKEY_VALUE) != 0) { @@ -322,8 +322,8 @@ static int test_4_re_pkam_auth_and_start_monitor(atclient *monitor_conn) { char *atserver_host = strdup(monitor_conn->atserver_connection.host); int atserver_port = monitor_conn->atserver_connection.port; - if ((ret = atclient_monitor_pkam_authenticate(monitor_conn, monitor_conn->atsign, &(monitor_conn->atkeys), - NULL)) != 0) { + if ((ret = atclient_monitor_pkam_authenticate(monitor_conn, monitor_conn->atsign, &(monitor_conn->atkeys), NULL)) != + 0) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to authenticate with PKAM: %d\n", ret); goto exit; } @@ -377,4 +377,4 @@ exit: { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_INFO, "test_6_monitor_for_notification End: %d\n", ret); return ret; } -} \ No newline at end of file +} diff --git a/tests/functional_tests/tests/test_atclient_sharedkey.c b/tests/functional_tests/tests/test_atclient_sharedkey.c index 39d0ac77..17cb1221 100644 --- a/tests/functional_tests/tests/test_atclient_sharedkey.c +++ b/tests/functional_tests/tests/test_atclient_sharedkey.c @@ -161,8 +161,9 @@ static int test_2_get_as_sharedby(atclient *atclient) { goto exit; } - if((ret = atclient_get_shared_key_request_options_set_store_atkey_metadata(&request_options, true)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_get_shared_key_request_options_set_store_atkey_metadata: %d\n", ret); + if ((ret = atclient_get_shared_key_request_options_set_store_atkey_metadata(&request_options, true)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, + "atclient_get_shared_key_request_options_set_store_atkey_metadata: %d\n", ret); goto exit; } @@ -224,8 +225,9 @@ static int test_3_get_as_sharedwith(atclient *atclient2) { goto exit; } - if((ret = atclient_get_shared_key_request_options_set_store_atkey_metadata(&request_options, true)) != 0) { - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "atclient_get_shared_key_request_options_set_store_atkey_metadata: %d\n", ret); + if ((ret = atclient_get_shared_key_request_options_set_store_atkey_metadata(&request_options, true)) != 0) { + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, + "atclient_get_shared_key_request_options_set_store_atkey_metadata: %d\n", ret); goto exit; } From 4ddebdaba88dfbe416f32c752f9dc943afd59f87 Mon Sep 17 00:00:00 2001 From: XavierChanth Date: Mon, 16 Dec 2024 16:15:56 -0500 Subject: [PATCH 2/3] perf: improve atclient_connection struct packing --- packages/atclient/include/atclient/connection.h | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/packages/atclient/include/atclient/connection.h b/packages/atclient/include/atclient/connection.h index 42cc695d..a9a2ffb2 100644 --- a/packages/atclient/include/atclient/connection.h +++ b/packages/atclient/include/atclient/connection.h @@ -30,23 +30,21 @@ typedef enum atclient_connection_type { typedef struct atclient_connection { atclient_connection_type type; // set in atclient_connection_init - + uint16_t port; // example: 64 bool _is_host_initialized : 1; - char *host; // example: "root.atsign.org" - bool _is_port_initialized : 1; - uint16_t port; // example: 64 + bool _is_connection_enabled : 1; + bool _is_hooks_enabled : 1; + + char *host; // example: "root.atsign.org" // atclient_connection_connect sets this to true and atclient_connection_disconnect sets this to false // this does not mean that the connection is still alive, it just means that the connection was established at least // once, at some point, check atclient_connection_is_connected for a live status on the connection // _is_connection_enabled also serves as an internal boolean to check if the following mbedlts contexts have been // initialized and need to be freed at the end - bool _is_connection_enabled : 1; struct atclient_tls_socket _socket; - - bool _is_hooks_enabled : 1; atclient_connection_hooks *hooks; } atclient_connection; From e36b5932aedad7553971f666a06a6344bb98c1b7 Mon Sep 17 00:00:00 2001 From: XavierChanth Date: Tue, 17 Dec 2024 09:02:35 -0500 Subject: [PATCH 3/3] chore: address review comments & cleanup work --- .../atclient/include/atclient/connection.h | 2 +- packages/atclient/include/atclient/monitor.h | 6 +- packages/atclient/include/atclient/socket.h | 16 ++++ .../include/atclient/socket_mbedtls.h | 10 +- .../atclient/include/atclient/socket_shared.h | 4 +- packages/atclient/src/monitor.c | 11 +-- packages/atclient/src/socket.c | 2 - packages/atclient/src/socket_mbedtls.c | 94 +------------------ .../tests/test_atclient_get_atkeys.c | 1 - 9 files changed, 32 insertions(+), 114 deletions(-) diff --git a/packages/atclient/include/atclient/connection.h b/packages/atclient/include/atclient/connection.h index a9a2ffb2..118245b2 100644 --- a/packages/atclient/include/atclient/connection.h +++ b/packages/atclient/include/atclient/connection.h @@ -1,7 +1,7 @@ /* * * The connection family of types and methods represents a single connection to - * if you want a pure socket representation see net_socket.h. + * if you want a pure socket representation see socket.h. * * At the moment _socket represents a singular tcp socket, but in the future it may be altered * to be a union of different connection types, such as a websocket or other construct. diff --git a/packages/atclient/include/atclient/monitor.h b/packages/atclient/include/atclient/monitor.h index 41e01aef..db57877e 100644 --- a/packages/atclient/include/atclient/monitor.h +++ b/packages/atclient/include/atclient/monitor.h @@ -6,14 +6,10 @@ extern "C" { #include "atclient/atclient.h" #include "atclient/atnotification.h" +#include "atclient/socket.h" #include // IWYU pragma: keep #include -// HACK let's just get it working for now this is so wrong -#ifndef MBEDTLS_ERR_SSL_TIMEOUT -#define MBEDTLS_ERR_SSL_TIMEOUT -37 -#endif - /** * @brief Represents a message received from the monitor connection, typically derived from the prefix of the response * (e.g. "data:ok"'s message type would be "data" = ATCLIENT_MONITOR_MESSAGE_TYPE_DATA_RESPONSE) diff --git a/packages/atclient/include/atclient/socket.h b/packages/atclient/include/atclient/socket.h index bdb15e22..0a205f90 100644 --- a/packages/atclient/include/atclient/socket.h +++ b/packages/atclient/include/atclient/socket.h @@ -10,6 +10,22 @@ extern "C" { #endif +#ifndef ATCLIENT_SSL_TIMEOUT_EXITCODE + +#if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) +#define ATCLIENT_SSL_TIMEOUT_EXITCODE MBEDTLS_ERR_SSL_TIMEOUT + +#elif defined(ATCLIENT_SOCKET_PROVIDER_ARDUINO_BEARSSL) +// Most arduino libraries only use -1 or positive integers +#define ATCLIENT_SSL_TIMEOUT_EXITCODE -101 + +#else +#error "ATCLIENT_ERR_SSL_TIMEOUT is undefined" + +#endif + +#endif + // IWYU pragma: begin_exports // Export the appropriate platform specific struct implementation diff --git a/packages/atclient/include/atclient/socket_mbedtls.h b/packages/atclient/include/atclient/socket_mbedtls.h index e2668bf0..09ef806c 100644 --- a/packages/atclient/include/atclient/socket_mbedtls.h +++ b/packages/atclient/include/atclient/socket_mbedtls.h @@ -1,7 +1,7 @@ -// IWYU pragma: private, include "atclient/net_socket.h" -// IWYU pragma: friend "net_socket_mbedtls.*" -#ifndef ATCLIENT_NET_SOCKET_MBEDTLS_H -#define ATCLIENT_NET_SOCKET_MBEDTLS_H +// IWYU pragma: private, include "atclient/socket.h" +// IWYU pragma: friend "socket_mbedtls.*" +#ifndef ATCLIENT_SOCKET_MBEDTLS_H +#define ATCLIENT_SOCKET_MBEDTLS_H #include #if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) #include @@ -14,7 +14,7 @@ extern "C" { #endif -// Make this type more portable to consume later +// TODO: Make this type more portable to consume later struct atclient_raw_socket { mbedtls_net_context net; }; diff --git a/packages/atclient/include/atclient/socket_shared.h b/packages/atclient/include/atclient/socket_shared.h index 7e7729a0..8d7624ea 100644 --- a/packages/atclient/include/atclient/socket_shared.h +++ b/packages/atclient/include/atclient/socket_shared.h @@ -1,5 +1,5 @@ -// IWYU pragma: private, include "atclient/net_socket.h" -// IWYU pragma: friend "net_socket_mbedtls.*" +// IWYU pragma: private, include "atclient/socket.h" +// IWYU pragma: friend "socket_mbedtls.*" #ifndef ATCLIENT_SOCKET_SHARED_H #define ATCLIENT_SOCKET_SHARED_H #include diff --git a/packages/atclient/src/monitor.c b/packages/atclient/src/monitor.c index 8b34257a..9eadbae6 100644 --- a/packages/atclient/src/monitor.c +++ b/packages/atclient/src/monitor.c @@ -111,20 +111,17 @@ int atclient_monitor_read(atclient *monitor_conn, atclient *atclient, atclient_m size_t buffer_len; int ret = atclient_tls_socket_read(&monitor_conn->atserver_connection._socket, &buffer, &buffer_len, - atclient_socket_read_until_char('@')); + atclient_socket_read_until_char('\n')); - // TODO: move this later for now it's fine as it should - if (ret == MBEDTLS_ERR_SSL_TIMEOUT) { + if (ret == ATCLIENT_SSL_TIMEOUT_EXITCODE) { // treat a timeout as empty message, non error message->type = ATCLIENT_MONITOR_MESSAGE_TYPE_EMPTY; ret = 0; goto exit; - } - - if (ret <= 0) { // you should reconnect... + } else if (ret != 0) { // you should reconnect... message->type = ATCLIENT_MONITOR_ERROR_READ; message->error_read.error_code = ret; - atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Read nothing from the monitor connection: %d\n", ret); + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Error: monitor exited with code %d\n", ret); goto exit; } diff --git a/packages/atclient/src/socket.c b/packages/atclient/src/socket.c index a2f1710e..0b639485 100644 --- a/packages/atclient/src/socket.c +++ b/packages/atclient/src/socket.c @@ -1,7 +1,5 @@ #include #include -// Most of the implementation for net_socket is platform specific, -// See the other net_socket_*.c files struct atclient_socket_read_options atclient_socket_read_until_char(char until) { return (struct atclient_socket_read_options){ diff --git a/packages/atclient/src/socket_mbedtls.c b/packages/atclient/src/socket_mbedtls.c index 724ca0d4..51d4b4ef 100644 --- a/packages/atclient/src/socket_mbedtls.c +++ b/packages/atclient/src/socket_mbedtls.c @@ -1,6 +1,7 @@ // These two headers must be included in a specific order #include "atchops/platform.h" // IWYU pragma: keep // Don't move them +#include "atclient/monitor.h" #include "atclient/socket.h" #if defined(ATCLIENT_SOCKET_PROVIDER_MBEDTLS) @@ -273,96 +274,6 @@ int atclient_tls_socket_read(struct atclient_tls_socket *socket, unsigned char * return 4; } } -// int atclient_tls_socket_read_num_bytes(struct atclient_tls_socket *socket, unsigned char **value, size_t *value_len, -// size_t num_bytes) { -// // Assume params have been validated by socket_read -// int ret; -// unsigned char *recv = NULL; -// size_t blocks = 0; // number of allocated blocks -// -// do { -// size_t offset = READ_BLOCK_LEN * blocks; // offset to current block -// // Allocate memory -// unsigned char *temp = realloc(recv, sizeof(unsigned char) * (offset + READ_BLOCK_LEN)); -// if (temp == NULL) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate receive buffer\n"); -// if (recv != NULL) { -// free(recv); -// } -// return 1; -// } -// recv = temp; // once we ensure realloc was successful we set recv to the new memory -// -// // Read into current block -// size_t pos = 0; // position in current block -// do { -// size_t remaining_for_block; -// if (READ_BLOCK_LEN + offset > num_bytes) { -// // We are in the final block, so only read the amount that will make -// // us reach num_bytes -// remaining_for_block = num_bytes - (offset + pos); -// } else { -// remaining_for_block = READ_BLOCK_LEN - pos; -// } -// -// ret = mbedtls_ssl_read(&socket->ssl, recv + offset + pos, remaining_for_block); -// if (ret > 0) { -// pos += ret; // successful read, increment pos -// -// if (offset + pos == num_bytes) { -// *value = recv; -// *value_len = (offset + pos); -// return 0; // The only return where recv should not be freed -// } -// -// continue; // not done reading yet -// } -// -// // handle non-happy path -// switch (ret) { -// case 0: // connection is closed -// case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: // connection is closed -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Socket closed while reading: %d\n", ret); -// ret = MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY; // ensure ret val is not 0 -// free(recv); -// return ret; -// case MBEDTLS_ERR_SSL_WANT_READ: // handshake incomplete -// case MBEDTLS_ERR_SSL_WANT_WRITE: // handshake incomplete -// case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS: // async operation in progress -// case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS: // crypto operation in progress -// // async error, we need to try again -// break; -// case MBEDTLS_ERR_SSL_TIMEOUT: // treat a timeout as -// -// if (value != NULL) { -// *value = recv; -// } else { -// free(recv); -// } -// if (value_len != NULL) { -// *value_len = (offset + pos); -// } -// return 0; // The only return where recv should not be freed -// // unexpected errors while reading -// default: -// if (ret > 0) { -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Unexpected read value %d\n", ret); -// } else { -// char strerr[512]; -// mbedtls_strerror(ret, strerr, 512); -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "%s", strerr); -// } -// free(recv); -// return ret; -// } // don't put anything after switch without checking it first -// } while (pos < READ_BLOCK_LEN); -// blocks++; -// } while (blocks < MAX_READ_BLOCKS); -// // We should only arrive at this point if we max out blocks -// // Every other code path should return early -// atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read within the maximum allowed number of read -// blocks\n"); free(recv); return 1; -// } int atclient_tls_socket_read_until_char(struct atclient_tls_socket *socket, unsigned char **value, size_t *value_len, char until_char) { @@ -409,6 +320,7 @@ int atclient_tls_socket_read_until_char(struct atclient_tls_socket *socket, unsi continue; // continue if not found char yet } // handle non-happy path + atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Socket read error: %d\n", ret); switch (ret) { case 0: // connection is closed case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: // connection is closed @@ -428,7 +340,7 @@ int atclient_tls_socket_read_until_char(struct atclient_tls_socket *socket, unsi if (timeout_count == MAX_READ_TIMEOUTS) { atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Failed to read the full message after %d attempts\n", MAX_READ_TIMEOUTS); - return ret; + return ATCLIENT_SSL_TIMEOUT_EXITCODE; } usleep(1000); break; diff --git a/tests/functional_tests/tests/test_atclient_get_atkeys.c b/tests/functional_tests/tests/test_atclient_get_atkeys.c index c949843d..6f6c1213 100644 --- a/tests/functional_tests/tests/test_atclient_get_atkeys.c +++ b/tests/functional_tests/tests/test_atclient_get_atkeys.c @@ -146,7 +146,6 @@ exit: { static int test_4_atclient_get_atkeys_null_regex(atclient *ctx, const char *scan_regex, const bool showhidden) { int ret = 1; - return 0; // NOTE: disabled temporarily DONT LET THIS GET MERGED atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "test_4_atclient_get_atkeys_null_regex\n");