diff --git a/pkg/network/ebpf/c/protocols/postgres/decoding.h b/pkg/network/ebpf/c/protocols/postgres/decoding.h index c96cce7479803..ab7ef0348fe56 100644 --- a/pkg/network/ebpf/c/protocols/postgres/decoding.h +++ b/pkg/network/ebpf/c/protocols/postgres/decoding.h @@ -108,25 +108,24 @@ static int __always_inline skip_string(pktbuf_t pkt, int message_len) { return SKIP_STRING_FAILED; } -// postgres_handle_message reads the first message header and decides what to do based on the +// Reads the first message header and decides what to do based on the // message tag. If the message is a new query, it stores the query in the in-flight map. // If the message is a parse message, we tail call to the dedicated process_parse_message program. // If the message is a command complete, it calls the handle_command_complete program. static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *conn_tuple, struct pg_message_header *header, __u8 tags) { - const __u32 zero = 0; // If the message is a parse message, we tail call to the dedicated function to handle it as it is too large to be // inlined in the main function. if (header->message_tag == POSTGRES_PARSE_MAGIC_BYTE) { pktbuf_tail_call_option_t process_parse_tail_call_array[] = { - [PKTBUF_SKB] = { - .prog_array_map = &protocols_progs, - .index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE, - }, - [PKTBUF_TLS] = { - .prog_array_map = &tls_process_progs, - .index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE, - }, - }; + [PKTBUF_SKB] = { + .prog_array_map = &protocols_progs, + .index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE, + }, + [PKTBUF_TLS] = { + .prog_array_map = &tls_process_progs, + .index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE, + }, + }; pktbuf_tail_call_compact(pkt, process_parse_tail_call_array); return; } @@ -143,6 +142,7 @@ static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t * return; } + const __u32 zero = 0; postgres_tail_call_state_t *iteration_value = bpf_map_lookup_elem(&postgres_iterations, &zero); if (iteration_value == NULL) { return; @@ -151,15 +151,15 @@ static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t * iteration_value->iteration = 0; iteration_value->data_off = 0; pktbuf_tail_call_option_t handle_response_tail_call_array[] = { - [PKTBUF_SKB] = { - .prog_array_map = &protocols_progs, - .index = PROG_POSTGRES_HANDLE_RESPONSE, - }, - [PKTBUF_TLS] = { - .prog_array_map = &tls_process_progs, - .index = PROG_POSTGRES_HANDLE_RESPONSE, - }, - }; + [PKTBUF_SKB] = { + .prog_array_map = &protocols_progs, + .index = PROG_POSTGRES_HANDLE_RESPONSE, + }, + [PKTBUF_TLS] = { + .prog_array_map = &tls_process_progs, + .index = PROG_POSTGRES_HANDLE_RESPONSE, + }, + }; pktbuf_tail_call_compact(pkt, handle_response_tail_call_array); return; } @@ -201,16 +201,12 @@ static __always_inline void postgres_handle_parse_message(pktbuf_t pkt, conn_tup // POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES) to continue processing. static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tuple) { const __u32 zero = 0; + bool found_command_complete = false; struct pg_message_header header; - // We didn't find a new query, thus we assume we're in the middle of a transaction. - // We look up the transaction in the in-flight map, and if it doesn't exist, we ignore the message. - postgres_transaction_t *transaction = bpf_map_lookup_elem(&postgres_in_flight, &conn_tuple); - if (!transaction) { - return 0; - } postgres_tail_call_state_t *iteration_value = bpf_map_lookup_elem(&postgres_iterations, &zero); if (iteration_value == NULL) { + bpf_map_delete_elem(&postgres_in_flight, &conn_tuple); return 0; } @@ -222,6 +218,13 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl pktbuf_set_offset(pkt, iteration_value->data_off); } + // We didn't find a new query, thus we assume we're in the middle of a transaction. + // We look up the transaction in the in-flight map, and if it doesn't exist, we ignore the message. + postgres_transaction_t *transaction = bpf_map_lookup_elem(&postgres_in_flight, &conn_tuple); + if (!transaction) { + return 0; + } + #pragma unroll(POSTGRES_MAX_MESSAGES_PER_TAIL_CALL) for (__u32 iteration = 0; iteration < POSTGRES_MAX_MESSAGES_PER_TAIL_CALL; ++iteration) { if (!read_message_header(pkt, &header)) { @@ -229,6 +232,7 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl } if (header.message_tag == POSTGRES_COMMAND_COMPLETE_MAGIC_BYTE) { handle_command_complete(&conn_tuple, transaction); + found_command_complete = true; break; } // We didn't find a command complete message, so we advance the data offset to the end of the message. @@ -237,19 +241,23 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl pktbuf_advance(pkt, header.message_len + 1); } - iteration_value->iteration += 1; - iteration_value->data_off = pktbuf_data_offset(pkt); - pktbuf_tail_call_option_t handle_response_tail_call_array[] = { - [PKTBUF_SKB] = { - .prog_array_map = &protocols_progs, - .index = PROG_POSTGRES_HANDLE_RESPONSE, - }, - [PKTBUF_TLS] = { - .prog_array_map = &tls_process_progs, - .index = PROG_POSTGRES_HANDLE_RESPONSE, - }, - }; - pktbuf_tail_call_compact(pkt, handle_response_tail_call_array); + if (!found_command_complete) { + // We didn't find a command complete message, so we need to continue processing the packet. + // We save the current data offset and increment the iteration counter. + iteration_value->iteration += 1; + iteration_value->data_off = pktbuf_data_offset(pkt); + pktbuf_tail_call_option_t handle_response_tail_call_array[] = { + [PKTBUF_SKB] = { + .prog_array_map = &protocols_progs, + .index = PROG_POSTGRES_HANDLE_RESPONSE, + }, + [PKTBUF_TLS] = { + .prog_array_map = &tls_process_progs, + .index = PROG_POSTGRES_HANDLE_RESPONSE, + }, + }; + pktbuf_tail_call_compact(pkt, handle_response_tail_call_array); + } return 0; } diff --git a/pkg/network/protocols/ebpf_types.go b/pkg/network/protocols/ebpf_types.go index 87d1617c1f80d..57bdb4539bf5e 100644 --- a/pkg/network/protocols/ebpf_types.go +++ b/pkg/network/protocols/ebpf_types.go @@ -18,10 +18,11 @@ const ( layerEncryptionBit = C.LAYER_ENCRYPTION_BIT ) -// Represents the maximum number of messages that can be processed in our Postgres decoding solution. const ( + // PostgresMaxMessagesPerTailCall is the maximum number of messages that can be processed in a single tail call in our Postgres decoding solution PostgresMaxMessagesPerTailCall = C.POSTGRES_MAX_MESSAGES_PER_TAIL_CALL - PostgresMaxTailCalls = C.POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES + // PostgresMaxTailCalls is the maximum number of tail calls that can be made in our Postgres decoding solution + PostgresMaxTailCalls = C.POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES ) // DispatcherProgramType is a C type to represent the eBPF programs used for tail calls. diff --git a/pkg/network/protocols/ebpf_types_linux.go b/pkg/network/protocols/ebpf_types_linux.go index 1386e5c97a2a2..ae855cfe48b77 100644 --- a/pkg/network/protocols/ebpf_types_linux.go +++ b/pkg/network/protocols/ebpf_types_linux.go @@ -11,7 +11,8 @@ const ( const ( PostgresMaxMessagesPerTailCall = 0x50 - PostgresMaxTailCalls = 0x1 + + PostgresMaxTailCalls = 0x1 ) type DispatcherProgramType uint32 diff --git a/pkg/network/usm/postgres_monitor_test.go b/pkg/network/usm/postgres_monitor_test.go index 37e8dc6447aa4..ccd8620940963 100644 --- a/pkg/network/usm/postgres_monitor_test.go +++ b/pkg/network/usm/postgres_monitor_test.go @@ -398,11 +398,7 @@ func testDecoding(t *testing.T, isTLS bool) { ctx.extras["pg"] = pg }, postMonitorSetup: func(t *testing.T, ctx pgTestContext) { - pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{ - ServerAddress: ctx.serverAddress, - EnableTLS: isTLS, - }) - require.NoError(t, err) + pg := ctx.extras["pg"].(*postgres.PGXClient) require.NoError(t, pg.Ping()) ctx.extras["pg"] = pg require.NoError(t, pg.RunQuery(createTableQuery)) @@ -589,7 +585,7 @@ func testDecoding(t *testing.T, isTLS bool) { }, // The purpose of this test is to validate the POSTGRES_MAX_MESSAGES_PER_TAIL_CALL * POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES limit. { - name: "validate supporting POSTGRES_MAX_MESSAGES_PER_TAIL_CALL limit", + name: "validate supporting max supported messages limit", preMonitorSetup: func(t *testing.T, ctx pgTestContext) { pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{ ServerAddress: ctx.serverAddress, @@ -600,11 +596,7 @@ func testDecoding(t *testing.T, isTLS bool) { ctx.extras["pg"] = pg }, postMonitorSetup: func(t *testing.T, ctx pgTestContext) { - pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{ - ServerAddress: ctx.serverAddress, - EnableTLS: isTLS, - }) - require.NoError(t, err) + pg := ctx.extras["pg"].(*postgres.PGXClient) require.NoError(t, pg.Ping()) ctx.extras["pg"] = pg require.NoError(t, pg.RunQuery(createTableQuery)) @@ -636,11 +628,7 @@ func testDecoding(t *testing.T, isTLS bool) { ctx.extras["pg"] = pg }, postMonitorSetup: func(t *testing.T, ctx pgTestContext) { - pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{ - ServerAddress: ctx.serverAddress, - EnableTLS: isTLS, - }) - require.NoError(t, err) + pg := ctx.extras["pg"].(*postgres.PGXClient) require.NoError(t, pg.Ping()) ctx.extras["pg"] = pg require.NoError(t, pg.RunQuery(createTableQuery))