Skip to content

Commit

Permalink
Merge pull request #297 from atsign-foundation/xc/conn-hooks
Browse files Browse the repository at this point in the history
feat: monitor hooks, connection send hooks
  • Loading branch information
XavierChanth authored Jun 10, 2024
2 parents 051c16c + b6be628 commit b757b28
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples/desktop/at_talk/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ static void *monitor_handler(void *xargs) {
atclient_monitor_message *message = NULL;
pthread_mutex_lock(&monitor_mutex);
pthread_mutex_lock(&client_mutex);
ret = atclient_monitor_read(monitor, ctx, &message);
ret = atclient_monitor_read(monitor, ctx, &message, NULL);
pthread_mutex_unlock(&monitor_mutex);
pthread_mutex_unlock(&client_mutex);

Expand Down
2 changes: 1 addition & 1 deletion examples/desktop/events/monitor.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ int main(int argc, char *argv[]) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "Starting main monitor loop...\n");
while (true) {

ret = atclient_monitor_read(&monitor_conn, &atclient2, &message);
ret = atclient_monitor_read(&monitor_conn, &atclient2, &message, NULL);
if (ret != 0) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to read monitor message: %d\n", ret);
continue;
Expand Down
2 changes: 1 addition & 1 deletion examples/desktop/events/resilient_monitor.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int main(int argc, char *argv[]) {
size_t max_tries = 10;
while (true) {

ret = atclient_monitor_read(&monitor_conn, &atclient2, &message);
ret = atclient_monitor_read(&monitor_conn, &atclient2, &message, NULL);

switch (message->type) {
case ATCLIENT_MONITOR_MESSAGE_TYPE_NONE: {
Expand Down
53 changes: 52 additions & 1 deletion packages/atclient/include/atclient/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,30 @@
// represents the type of connection
typedef enum atclient_connection_type {
ATCLIENT_CONNECTION_TYPE_ATDIRECTORY, // uses '\n' to check if it is connected
ATCLIENT_CONNECTION_TYPE_ATSERVER // uses 'noop:0\r\n' to check if it is connected
ATCLIENT_CONNECTION_TYPE_ATSERVER // uses 'noop:0\r\n' to check if it is connected
} atclient_connection_type;

typedef int(atclient_connection_send_hook)(const unsigned char *src, const size_t srclen, unsigned char *recv,
const size_t recvsize, size_t *recvlen);

typedef enum atclient_connection_hook_type {
ATCLIENT_CONNECTION_HOOK_TYPE_NONE = 0,
ATCLIENT_CONNECTION_HOOK_TYPE_PRE_SEND,
ATCLIENT_CONNECTION_HOOK_TYPE_POST_SEND,
ATCLIENT_CONNECTION_HOOK_TYPE_PRE_RECV,
ATCLIENT_CONNECTION_HOOK_TYPE_POST_RECV,
} atclient_connection_hook_type;

typedef struct atclient_connection_hooks {
bool _is_nested_call; // internal variable for preventing infinite recursion (hooks cannot trigger other hooks in
// their nested calls)
atclient_connection_send_hook *pre_send;
atclient_connection_send_hook *post_send;
atclient_connection_send_hook *pre_recv;
atclient_connection_send_hook *post_recv;
bool readonly_src;
} atclient_connection_hooks;

typedef struct atclient_connection {
char host[ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE];
int port; // example: 64
Expand All @@ -33,6 +54,8 @@ typedef struct atclient_connection {
// this does not mean that the connection is still alive, it just means that the connection was established or taken
// down at some point, check atclient_connection_is_connected for a live status on the connection
bool should_be_connected;

atclient_connection_hooks *hooks;
} atclient_connection;

/**
Expand All @@ -42,6 +65,7 @@ typedef struct atclient_connection {
* @param type the type of connection to initialize,
* if it is ATCLIENT_CONNECTION_TYPE_ROOT, then '\\n' will be used to check if it is connected.
* if it is ATCLIENT_CONNECTION_TYPE_ATSERVER, then 'noop:0\r\n' will be used to check if it is connected
*
*/
void atclient_connection_init(atclient_connection *ctx, atclient_connection_type type);

Expand Down Expand Up @@ -105,4 +129,31 @@ void atclient_connection_free(atclient_connection *ctx);
*/
int atclient_connection_get_host_and_port(atclient_atstr *host, int *port, const atclient_atstr url);

/**
* @brief Initialize the hooks memory allocation
*
* @param ctx the struct for the connection
*/
void atclient_connection_enable_hooks(atclient_connection *ctx);

/**
* @brief Add a hook to be called during the connection lifecycle
*
* @param ctx the struct for the connection
* @param type the hook type you want to add
* @param hook the hook function itself
*
* @return int 0 on success, otherwise error
*/
int atclient_connection_hooks_set(atclient_connection *ctx, atclient_connection_hook_type type, void *hook);

/**
* @brief Set whether the readonly_src status for all hooks
*
* @param ctx the struct for the connection
* @param readonly_src the new state for readonly_src
*
* @note For performance, keep readonly_src set to true if you don't need to write access to src
*/
void atclient_connection_hooks_set_readonly_src(atclient_connection *ctx, bool readonly_src);
#endif
8 changes: 7 additions & 1 deletion packages/atclient/include/atclient/monitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ int atclient_monitor_pkam_authenticate(atclient *monitor_conn, const char *atser
*/
void atclient_monitor_set_read_timeout(atclient *monitor_conn, const int timeout_ms);

typedef struct atclient_monitor_hooks {
int (*pre_decrypt_notification)(void);
int (*post_decrypt_notification)(void);
} atclient_monitor_hooks;

/**
* @brief Sends the monitor command to the atserver to start monitoring notifications, assumed that the monitor atclient
* context is already pkam authenticated
Expand All @@ -271,7 +276,8 @@ int atclient_monitor_start(atclient *monitor_conn, const char *regex, const size
* @note Message may be a notification, a data response, or an error response, check the type field to determine which
* data field to use
*/
int atclient_monitor_read(atclient *monitor_conn, atclient *atclient, atclient_monitor_message **message);
int atclient_monitor_read(atclient *monitor_conn, atclient *atclient, atclient_monitor_message **message,
atclient_monitor_hooks *hooks);

/**
* @brief Check if the monitor connection is still established (client is listening for notifications, and the server
Expand Down
3 changes: 1 addition & 2 deletions packages/atclient/src/atclient.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "atclient/atclient_utils.h"
#include "atchops/base64.h"
#include "atchops/rsa.h"
#include "atclient/atbytes.h"
#include "atclient/atclient.h"
#include "atclient/atkeys.h"
#include "atclient/atsign.h"
Expand Down Expand Up @@ -42,7 +41,7 @@ void atclient_free(atclient *ctx) {
atclient_atsign_free(&(ctx->atsign));
}

if(!ctx->atkeys_is_allocated_by_caller) {
if (!ctx->atkeys_is_allocated_by_caller) {
atclient_atkeys_free(&(ctx->atkeys));
}

Expand Down
166 changes: 152 additions & 14 deletions packages/atclient/src/connection.c
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

#include "atclient/connection.h"
#include "atchops/constants.h"
#include "atclient/atstr.h"
#include "atclient/cacerts.h"
#include "atclient/constants.h"
#include "atclient/stringutils.h"
#include "atlogger/atlogger.h"
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
Expand Down Expand Up @@ -204,29 +202,77 @@ exit: {
}
}

int atclient_connection_send(atclient_connection *ctx, const unsigned char *src, const size_t srclen,
unsigned char *recv, const size_t recvsize, size_t *recvlen) {
int atclient_connection_send(atclient_connection *ctx, const unsigned char *src_r, const size_t srclen_r,
unsigned char *recv, const size_t recvsize_r, size_t *recvlen) {
int ret = 1;

// Clone readonly inputs so it is editable by the hooks
size_t srclen = srclen_r;
size_t recvsize = recvsize_r;

bool try_hooks = ctx->hooks != NULL && !ctx->hooks->_is_nested_call;
bool allocate_src = try_hooks && ctx->hooks->readonly_src == false;

unsigned char *src;

if (allocate_src) {
src = malloc(sizeof(unsigned char) * srclen);
if (src == NULL) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Failed to allocate memory for src\n");
allocate_src = false; // don't try to free since the memory failed to be allocated
goto exit;
}
memcpy(src, src_r, srclen);
} else {
src = (unsigned char *)src_r;
}

if (!ctx->should_be_connected) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR,
"ctx->should_be_connected should be true, but is false. You are trying to send messages to a "
"non-connected connection.\n");
goto exit;
}

if (try_hooks && ctx->hooks->pre_send != NULL) {
ctx->hooks->_is_nested_call = true;
ret = ctx->hooks->pre_send(src, srclen, recv, recvsize, recvlen);
ctx->hooks->_is_nested_call = false;
if (ret != 0) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "pre_send hook failed with exit code: %d\n", ret);
goto exit;
}
}

ret = mbedtls_ssl_write(&(ctx->ssl), src, srclen);
if (ret <= 0) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_write failed with exit code: %d\n", ret);
goto exit;
}

if (try_hooks && ctx->hooks->post_send != NULL) {
ctx->hooks->_is_nested_call = true;
ret = ctx->hooks->post_send(src, srclen, recv, recvsize, recvlen);
ctx->hooks->_is_nested_call = false;
if (ret != 0) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "post_send hook failed with exit code: %d\n", ret);
goto exit;
}
}

unsigned char *srccopy;
if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG && ret == srclen) {
unsigned char srccopy[srclen];
memcpy(srccopy, src, srclen);
atlogger_fix_stdout_buffer((char *)srccopy, srclen);
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sSENT: %s\"%.*s\"%s\n", BBLU, HCYN, strlen((char *)srccopy),
srccopy, reset);
srccopy = malloc(sizeof(unsigned char) * srclen);
if (srccopy != NULL) {
memcpy(srccopy, src, srclen);
atlogger_fix_stdout_buffer((char *)srccopy, srclen);
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sSENT: %s\"%.*s\"%s\n", BBLU, HCYN, strlen((char *)srccopy),
srccopy, reset);
free(srccopy);
} else {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR,
"Failed to allocate memory to pretty print the network sent transmission\n");
}
}

if (recv == NULL) {
Expand All @@ -236,6 +282,16 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src,

memset(recv, 0, sizeof(unsigned char) * recvsize);

if (try_hooks && ctx->hooks->pre_recv != NULL) {
ctx->hooks->_is_nested_call = true;
ret = ctx->hooks->pre_recv(src, srclen, recv, recvsize, recvlen);
ctx->hooks->_is_nested_call = false;
if (ret != 0) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "pre_recv hook failed with exit code: %d\n", ret);
goto exit;
}
}

int tries = 0;
bool found = false;
size_t l = 0;
Expand All @@ -245,6 +301,15 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src,
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "mbedtls_ssl_read failed with exit code: %d\n", ret);
goto exit;
}
if (ret == 0) {
tries++;
if (tries >= ATCLIENT_CONNECTION_MAX_READ_TRIES) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR,
"mbedtls_ssl_read tried to read %d times and found nothing: %d\n", tries, ret);
ret = 1;
goto exit;
}
}
l = l + ret;

for (int i = l; i >= l - ret && i >= 0; i--) {
Expand All @@ -264,16 +329,39 @@ int atclient_connection_send(atclient_connection *ctx, const unsigned char *src,
// atlogger_fix_stdout_buffer((char *)recv, *recvlen);
recv[*recvlen] = '\0'; // null terminate the string

unsigned char *recvcopy;
if (atlogger_get_logging_level() >= ATLOGGER_LOGGING_LEVEL_DEBUG) {
unsigned char recvcopy[*recvlen];
memcpy(recvcopy, recv, *recvlen);
atlogger_fix_stdout_buffer((char *)recvcopy, *recvlen);
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, *recvlen, recvcopy, reset);
recvcopy = malloc(sizeof(unsigned char) * (*recvlen));
if (recvcopy != NULL) {
memcpy(recvcopy, recv, *recvlen);
atlogger_fix_stdout_buffer((char *)recvcopy, *recvlen);
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_DEBUG, "\t%sRECV: %s\"%.*s\"%s\n", BMAG, HMAG, *recvlen, recvcopy,
reset);
free(recvcopy);
} else {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR,
"Failed to allocate memory to pretty print the network received buffer\n");
}
}

if (try_hooks && ctx->hooks->post_recv != NULL) {
ctx->hooks->_is_nested_call = true;
ret = ctx->hooks->post_recv(src, srclen, recv, recvsize, recvlen);
ctx->hooks->_is_nested_call = false;
if (ret != 0) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "post_recv hook failed with exit code: %d\n", ret);
goto exit;
}
}

ret = 0;
goto exit;
exit: { return ret; }
exit: {
if (allocate_src) {
free(src);
}
return ret;
}
}

int atclient_connection_disconnect(atclient_connection *ctx) {
Expand Down Expand Up @@ -344,6 +432,10 @@ void atclient_connection_free(atclient_connection *ctx) {
memset(ctx->host, 0, ATCLIENT_CONSTANTS_HOST_BUFFER_SIZE);
ctx->port = -1;
ctx->should_be_connected = false;

if (ctx->hooks != NULL) {
free(ctx->hooks);
}
}

int atclient_connection_get_host_and_port(atclient_atstr *host, int *port, const atclient_atstr url) {
Expand Down Expand Up @@ -379,3 +471,49 @@ int atclient_connection_get_host_and_port(atclient_atstr *host, int *port, const

exit: { return ret; }
}

void atclient_connection_enable_hooks(atclient_connection *ctx) {
ctx->hooks = malloc(sizeof(atclient_connection_hooks));
memset(ctx->hooks, 0, sizeof(atclient_connection_hooks));
ctx->hooks->readonly_src = true;
}

// Q. Why is hook a void pointer?
// A. In case we want to add future hook types which use a different function signature
int atclient_connection_hooks_set(atclient_connection *ctx, atclient_connection_hook_type type, void *hook) {
atclient_connection_hooks *hooks = ctx->hooks;
if (hooks == NULL) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR,
"Make sure to initialize hooks struct before trying to set a hook\n");
return -1;
}

switch (type) {
case ATCLIENT_CONNECTION_HOOK_TYPE_NONE:
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR, "Received 'NONE' hook as hook set input type\n");
return 1;
case ATCLIENT_CONNECTION_HOOK_TYPE_PRE_SEND:
hooks->pre_send = (atclient_connection_send_hook *)hook;
break;
case ATCLIENT_CONNECTION_HOOK_TYPE_POST_SEND:
hooks->post_send = (atclient_connection_send_hook *)hook;
break;
case ATCLIENT_CONNECTION_HOOK_TYPE_PRE_RECV:
hooks->pre_recv = (atclient_connection_send_hook *)hook;
break;
case ATCLIENT_CONNECTION_HOOK_TYPE_POST_RECV:
hooks->post_recv = (atclient_connection_send_hook *)hook;
break;
}

return 0;
}

void atclient_connection_hooks_set_readonly_src(atclient_connection *ctx, bool readonly_src) {
if (ctx->hooks == NULL) {
atlogger_log(TAG, ATLOGGER_LOGGING_LEVEL_ERROR,
"Make sure to initialize hooks struct before trying to set readonly_src\n");
return;
}
ctx->hooks->readonly_src = readonly_src;
}
Loading

0 comments on commit b757b28

Please sign in to comment.