Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[USM] Introduce tail cals for the Postgres monitoring #30547

Merged
merged 16 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 47 additions & 39 deletions pkg/network/ebpf/c/protocols/postgres/decoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,24 @@ static int __always_inline skip_string(pktbuf_t pkt, int message_len) {
return SKIP_STRING_FAILED;
}

// postgres_handle_message reads the first message header and decides what to do based on the
// Reads the first message header and decides what to do based on the
// message tag. If the message is a new query, it stores the query in the in-flight map.
// If the message is a parse message, we tail call to the dedicated process_parse_message program.
// If the message is a command complete, it calls the handle_command_complete program.
static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *conn_tuple, struct pg_message_header *header, __u8 tags) {
const __u32 zero = 0;
// If the message is a parse message, we tail call to the dedicated function to handle it as it is too large to be
// inlined in the main function.
if (header->message_tag == POSTGRES_PARSE_MAGIC_BYTE) {
pktbuf_tail_call_option_t process_parse_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
};
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
};
pktbuf_tail_call_compact(pkt, process_parse_tail_call_array);
return;
}
Expand All @@ -143,6 +142,7 @@ static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *
return;
}

const __u32 zero = 0;
postgres_tail_call_state_t *iteration_value = bpf_map_lookup_elem(&postgres_iterations, &zero);
if (iteration_value == NULL) {
return;
Expand All @@ -151,15 +151,15 @@ static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *
iteration_value->iteration = 0;
iteration_value->data_off = 0;
pktbuf_tail_call_option_t handle_response_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
pktbuf_tail_call_compact(pkt, handle_response_tail_call_array);
return;
}
Expand Down Expand Up @@ -201,16 +201,12 @@ static __always_inline void postgres_handle_parse_message(pktbuf_t pkt, conn_tup
// POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES) to continue processing.
static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tuple) {
const __u32 zero = 0;
bool found_command_complete = false;
struct pg_message_header header;
// We didn't find a new query, thus we assume we're in the middle of a transaction.
// We look up the transaction in the in-flight map, and if it doesn't exist, we ignore the message.
postgres_transaction_t *transaction = bpf_map_lookup_elem(&postgres_in_flight, &conn_tuple);
if (!transaction) {
return 0;
}

postgres_tail_call_state_t *iteration_value = bpf_map_lookup_elem(&postgres_iterations, &zero);
if (iteration_value == NULL) {
bpf_map_delete_elem(&postgres_in_flight, &conn_tuple);
return 0;
}

Expand All @@ -222,13 +218,21 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl
pktbuf_set_offset(pkt, iteration_value->data_off);
}

// We didn't find a new query, thus we assume we're in the middle of a transaction.
// We look up the transaction in the in-flight map, and if it doesn't exist, we ignore the message.
postgres_transaction_t *transaction = bpf_map_lookup_elem(&postgres_in_flight, &conn_tuple);
if (!transaction) {
return 0;
}

#pragma unroll(POSTGRES_MAX_MESSAGES_PER_TAIL_CALL)
for (__u32 iteration = 0; iteration < POSTGRES_MAX_MESSAGES_PER_TAIL_CALL; ++iteration) {
if (!read_message_header(pkt, &header)) {
break;
}
if (header.message_tag == POSTGRES_COMMAND_COMPLETE_MAGIC_BYTE) {
handle_command_complete(&conn_tuple, transaction);
found_command_complete = true;
break;
}
// We didn't find a command complete message, so we advance the data offset to the end of the message.
Expand All @@ -237,19 +241,23 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl
pktbuf_advance(pkt, header.message_len + 1);
}

iteration_value->iteration += 1;
iteration_value->data_off = pktbuf_data_offset(pkt);
pktbuf_tail_call_option_t handle_response_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
pktbuf_tail_call_compact(pkt, handle_response_tail_call_array);
if (!found_command_complete) {
// We didn't find a command complete message, so we need to continue processing the packet.
// We save the current data offset and increment the iteration counter.
iteration_value->iteration += 1;
iteration_value->data_off = pktbuf_data_offset(pkt);
pktbuf_tail_call_option_t handle_response_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
pktbuf_tail_call_compact(pkt, handle_response_tail_call_array);
}
guyarb marked this conversation as resolved.
Show resolved Hide resolved
return 0;
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/network/protocols/ebpf_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ const (
layerEncryptionBit = C.LAYER_ENCRYPTION_BIT
)

// Represents the maximum number of messages that can be processed in our Postgres decoding solution.
const (
// PostgresMaxMessagesPerTailCall is the maximum number of messages that can be processed in a single tail call in our Postgres decoding solution
PostgresMaxMessagesPerTailCall = C.POSTGRES_MAX_MESSAGES_PER_TAIL_CALL
PostgresMaxTailCalls = C.POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES
// PostgresMaxTailCalls is the maximum number of tail calls that can be made in our Postgres decoding solution
PostgresMaxTailCalls = C.POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES
)

// DispatcherProgramType is a C type to represent the eBPF programs used for tail calls.
Expand Down
3 changes: 2 additions & 1 deletion pkg/network/protocols/ebpf_types_linux.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 4 additions & 16 deletions pkg/network/usm/postgres_monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,7 @@ func testDecoding(t *testing.T, isTLS bool) {
ctx.extras["pg"] = pg
},
postMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
EnableTLS: isTLS,
})
require.NoError(t, err)
pg := ctx.extras["pg"].(*postgres.PGXClient)
require.NoError(t, pg.Ping())
ctx.extras["pg"] = pg
require.NoError(t, pg.RunQuery(createTableQuery))
Expand Down Expand Up @@ -589,7 +585,7 @@ func testDecoding(t *testing.T, isTLS bool) {
},
// The purpose of this test is to validate the POSTGRES_MAX_MESSAGES_PER_TAIL_CALL * POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES limit.
{
name: "validate supporting POSTGRES_MAX_MESSAGES_PER_TAIL_CALL limit",
name: "validate supporting max supported messages limit",
preMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
Expand All @@ -600,11 +596,7 @@ func testDecoding(t *testing.T, isTLS bool) {
ctx.extras["pg"] = pg
},
postMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
EnableTLS: isTLS,
})
require.NoError(t, err)
pg := ctx.extras["pg"].(*postgres.PGXClient)
require.NoError(t, pg.Ping())
ctx.extras["pg"] = pg
require.NoError(t, pg.RunQuery(createTableQuery))
Expand Down Expand Up @@ -636,11 +628,7 @@ func testDecoding(t *testing.T, isTLS bool) {
ctx.extras["pg"] = pg
},
postMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
EnableTLS: isTLS,
})
require.NoError(t, err)
pg := ctx.extras["pg"].(*postgres.PGXClient)
require.NoError(t, pg.Ping())
ctx.extras["pg"] = pg
require.NoError(t, pg.RunQuery(createTableQuery))
Expand Down
Loading