From b6186f460c946f20b979f91e3d9181fd1bba0d30 Mon Sep 17 00:00:00 2001 From: Andrey Semashev Date: Tue, 28 Feb 2023 21:06:30 +0300 Subject: [PATCH] Changed stream list to arrays and optimized stream lookup with SSE2. The stream list in the SRTP context is now implemented with two arrays: an array of SSRCs and an array of pointers to the streams corresponding to the SSRCs. The streams no longer form a linked list. Stream lookup by SSRC is now performed over the array of SSRCs, which is considerably faster because it is more cache-friendly. Additionally, the lookup is optimized for SSE2, which provides an additional massive speedup with many streams in the list. Although the lookup still has linear complexity, its absolute times are reduced and with tens to hundreds elements are lower or comparable with a typical rb-tree equivalent. Expected speedup of SSE2 version over the previous implementation: SSRCs speedup (scalar) speedup (SSE2) 1 0.39x 0.22x 3 0.57x 0.23x 5 0.69x 0.62x 10 0.77x 1.43x 20 0.86x 2.38x 30 0.87x 3.44x 50 1.13x 6.21x 100 1.25x 8.51x 200 1.30x 9.83x These numbers were obtained on a Core i7 2600K. At small numbers of SSRCs the new algorithm is somewhat slower, but given that the absolute and relative times of the lookup are very small, that slowdown is not very significant. --- srtp/srtp.c | 231 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 197 insertions(+), 34 deletions(-) diff --git a/srtp/srtp.c b/srtp/srtp.c index 75b71fb6e..14e41d5cc 100644 --- a/srtp/srtp.c +++ b/srtp/srtp.c @@ -60,6 +60,8 @@ #include "aes_icm_ext.h" #endif +#include +#include #include #ifdef HAVE_NETINET_IN_H #include @@ -67,6 +69,13 @@ #include #endif +#if defined(__SSE2__) +#include +#if defined(_MSC_VER) +#include +#endif +#endif + /* the debug module for srtp */ srtp_debug_module_t mod_srtp = { 0, /* debugging is off by default */ @@ -79,6 +88,16 @@ srtp_debug_module_t mod_srtp = { #define uint32s_in_rtcp_header 2 #define octets_in_rtp_extn_hdr 4 +#ifndef SRTP_NO_STREAM_LIST +static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list); +static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list, + uint32_t new_capacity); +static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc); +static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list, + uint32_t pos); +static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos); +#endif // SRTP_NO_STREAM_LIST + static srtp_err_status_t srtp_validate_rtp_header(void *rtp_hdr, int *pkt_octet_len) { @@ -3030,18 +3049,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc) { srtp_stream_ctx_t *stream; srtp_err_status_t status; +#if !defined(SRTP_NO_STREAM_LIST) + uint32_t pos; +#endif /* sanity check arguments */ - if (session == NULL) + if (session == NULL) { return srtp_err_status_bad_param; + } /* find and remove stream from the list */ +#if !defined(SRTP_NO_STREAM_LIST) + pos = srtp_stream_list_find(session->stream_list, ssrc); + if (pos >= srtp_stream_list_size(session->stream_list)) + return srtp_err_status_no_ctx; + + stream = srtp_stream_list_get_at(session->stream_list, pos); + srtp_stream_list_remove_at(session->stream_list, pos); +#else stream = srtp_stream_list_get(session->stream_list, ssrc); if (stream == NULL) { return srtp_err_status_no_ctx; } srtp_stream_list_remove(session->stream_list, stream); +#endif /* deallocate the stream */ status = srtp_stream_dealloc(stream, session->stream_template); @@ -4840,11 +4872,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session, #ifndef SRTP_NO_STREAM_LIST -/* in the default implementation, we have an intrusive doubly-linked list */ typedef struct srtp_stream_list_ctx_t_ { - /* a stub stream that just holds pointers to the beginning and end of the - * list */ - srtp_stream_ctx_t data; + uint32_t *ssrcs; + srtp_stream_ctx_t **streams; + uint32_t size; + uint32_t capacity; } srtp_stream_list_ctx_t_; srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr) @@ -4855,9 +4887,6 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr) return srtp_err_status_alloc_fail; } - list->data.next = NULL; - list->data.prev = NULL; - *list_ptr = list; return srtp_err_status_ok; } @@ -4865,63 +4894,197 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr) srtp_err_status_t srtp_stream_list_dealloc(srtp_stream_list_t list) { /* list must be empty */ - if (list->data.next) { + if (list->size != 0u) { return srtp_err_status_fail; } + srtp_crypto_free(list->streams); + srtp_crypto_free(list->ssrcs); srtp_crypto_free(list); return srtp_err_status_ok; } +static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list) +{ + return list->size; +} + +static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list, + uint32_t new_capacity) +{ + if (new_capacity > list->capacity) { + uint32_t *ssrcs; + srtp_stream_ctx_t **stream_ptrs; + + if (new_capacity > (UINT32_MAX - 15u)) + return srtp_err_status_alloc_fail; + + new_capacity = (new_capacity + 15u) & ~((uint32_t)15u); + + ssrcs = (uint32_t *)srtp_crypto_alloc((size_t)new_capacity * + sizeof(uint32_t)); + if (!ssrcs) + return srtp_err_status_alloc_fail; + stream_ptrs = (srtp_stream_ctx_t **)srtp_crypto_alloc( + (size_t)new_capacity * sizeof(srtp_stream_ctx_t *)); + if (!stream_ptrs) { + srtp_crypto_free(ssrcs); + return srtp_err_status_alloc_fail; + } + + if (list->size > 0u) { + memcpy(ssrcs, list->ssrcs, (size_t)list->size * sizeof(uint32_t)); + memcpy(stream_ptrs, list->streams, + (size_t)list->size * sizeof(srtp_stream_ctx_t *)); + } + + srtp_crypto_free(list->ssrcs); + srtp_crypto_free(list->streams); + list->streams = stream_ptrs; + list->ssrcs = ssrcs; + + list->capacity = new_capacity; + } + + return srtp_err_status_ok; +} + srtp_err_status_t srtp_stream_list_insert(srtp_stream_list_t list, srtp_stream_t stream) { - /* insert at the head of the list */ - stream->next = list->data.next; - if (stream->next != NULL) { - stream->next->prev = stream; - } - list->data.next = stream; - stream->prev = &(list->data); + uint32_t pos; + srtp_err_status_t status = srtp_stream_list_reserve(list, list->size + 1u); + if (status) + return status; + pos = list->size++; + list->ssrcs[pos] = stream->ssrc; + list->streams[pos] = stream; return srtp_err_status_ok; } -srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc) +static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc) { - /* walk down list until ssrc is found */ - srtp_stream_t stream = list->data.next; - while (stream != NULL) { - if (stream->ssrc == ssrc) { - return stream; +#if defined(__SSE2__) + const uint32_t *const ssrcs = list->ssrcs; + const __m128i mm_ssrc = _mm_set1_epi32(ssrc); + uint32_t pos = 0u, n = (list->size + 7u) & ~(uint32_t)(7u); + for (uint32_t m = n & ~(uint32_t)(15u); pos < m; pos += 16u) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + __m128i mm3 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 8u)); + __m128i mm4 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 12u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm3 = _mm_cmpeq_epi32(mm3, mm_ssrc); + mm4 = _mm_cmpeq_epi32(mm4, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + mm3 = _mm_packs_epi32(mm3, mm4); + mm1 = _mm_packs_epi16(mm1, mm3); + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos; +#else + pos += __builtin_ctz(mask); +#endif + + goto done; + } + } + + if (pos < n) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos / 2u; +#else + pos += __builtin_ctz(mask) / 2u; +#endif + goto done; } - stream = stream->next; + + pos += 8u; + } + +done: + return pos; +#else + /* walk down list until ssrc is found */ + uint32_t pos = 0u, n = list->size; + for (; pos < n; ++pos) { + if (list->ssrcs[pos] == ssrc) + break; } + return pos; +#endif +} + +static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list, + uint32_t pos) +{ + return list->streams[pos]; +} + +srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc) +{ + uint32_t pos = srtp_stream_list_find(list, ssrc); + if (pos < list->size) + return list->streams[pos]; + /* we haven't found our ssrc, so return a null */ return NULL; } -void srtp_stream_list_remove(srtp_stream_list_t list, - srtp_stream_t stream_to_remove) +static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos) { - (void)list; + uint32_t tail_size, last_pos; - stream_to_remove->prev->next = stream_to_remove->next; - if (stream_to_remove->next != NULL) { - stream_to_remove->next->prev = stream_to_remove->prev; + last_pos = --list->size; + tail_size = last_pos - pos; + if (tail_size > 0u) { + memmove(list->streams + pos, list->streams + pos + 1, + (size_t)tail_size * sizeof(*list->streams)); + memmove(list->ssrcs + pos, list->ssrcs + pos + 1, + (size_t)tail_size * sizeof(*list->ssrcs)); } + + list->streams[last_pos] = NULL; + list->ssrcs[last_pos] = 0u; +} + +void srtp_stream_list_remove(srtp_stream_list_t list, + srtp_stream_t stream_to_remove) +{ + uint32_t pos = srtp_stream_list_find(list, stream_to_remove->ssrc); + if (pos < list->size) + srtp_stream_list_remove_at(list, pos); } void srtp_stream_list_for_each(srtp_stream_list_t list, int (*callback)(srtp_stream_t, void *), void *data) { - srtp_stream_t stream = list->data.next; - while (stream != NULL) { - srtp_stream_t tmp = stream; - stream = stream->next; - if (callback(tmp, data)) + uint32_t size = list->size; + for (uint32_t i = 0u; i < size;) { + if (callback(list->streams[i], data)) break; + + /* check if the callback removed the current element */ + if (size == list->size) + ++i; + else + size = list->size; } }