Skip to content

Commit

Permalink
fixed cr notes
Browse files Browse the repository at this point in the history
  • Loading branch information
amitslavin committed Nov 3, 2024
1 parent 3b24f9c commit 89b5b95
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 58 deletions.
86 changes: 47 additions & 39 deletions pkg/network/ebpf/c/protocols/postgres/decoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,24 @@ static int __always_inline skip_string(pktbuf_t pkt, int message_len) {
return SKIP_STRING_FAILED;
}

// postgres_handle_message reads the first message header and decides what to do based on the
// Reads the first message header and decides what to do based on the
// message tag. If the message is a new query, it stores the query in the in-flight map.
// If the message is a parse message, we tail call to the dedicated process_parse_message program.
// If the message is a command complete, it calls the handle_command_complete program.
static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *conn_tuple, struct pg_message_header *header, __u8 tags) {
const __u32 zero = 0;
// If the message is a parse message, we tail call to the dedicated function to handle it as it is too large to be
// inlined in the main function.
if (header->message_tag == POSTGRES_PARSE_MAGIC_BYTE) {
pktbuf_tail_call_option_t process_parse_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
};
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_PROCESS_PARSE_MESSAGE,
},
};
pktbuf_tail_call_compact(pkt, process_parse_tail_call_array);
return;
}
Expand All @@ -143,6 +142,7 @@ static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *
return;
}

const __u32 zero = 0;
postgres_tail_call_state_t *iteration_value = bpf_map_lookup_elem(&postgres_iterations, &zero);
if (iteration_value == NULL) {
return;
Expand All @@ -151,15 +151,15 @@ static __always_inline void postgres_handle_message(pktbuf_t pkt, conn_tuple_t *
iteration_value->iteration = 0;
iteration_value->data_off = 0;
pktbuf_tail_call_option_t handle_response_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
pktbuf_tail_call_compact(pkt, handle_response_tail_call_array);
return;
}
Expand Down Expand Up @@ -201,16 +201,12 @@ static __always_inline void postgres_handle_parse_message(pktbuf_t pkt, conn_tup
// POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES) to continue processing.
static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tuple) {
const __u32 zero = 0;
bool found_command_complete = false;
struct pg_message_header header;
// We didn't find a new query, thus we assume we're in the middle of a transaction.
// We look up the transaction in the in-flight map, and if it doesn't exist, we ignore the message.
postgres_transaction_t *transaction = bpf_map_lookup_elem(&postgres_in_flight, &conn_tuple);
if (!transaction) {
return 0;
}

postgres_tail_call_state_t *iteration_value = bpf_map_lookup_elem(&postgres_iterations, &zero);
if (iteration_value == NULL) {
bpf_map_delete_elem(&postgres_in_flight, &conn_tuple);
return 0;
}

Expand All @@ -222,13 +218,21 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl
pktbuf_set_offset(pkt, iteration_value->data_off);
}

// We didn't find a new query, thus we assume we're in the middle of a transaction.
// We look up the transaction in the in-flight map, and if it doesn't exist, we ignore the message.
postgres_transaction_t *transaction = bpf_map_lookup_elem(&postgres_in_flight, &conn_tuple);
if (!transaction) {
return 0;
}

#pragma unroll(POSTGRES_MAX_MESSAGES_PER_TAIL_CALL)
for (__u32 iteration = 0; iteration < POSTGRES_MAX_MESSAGES_PER_TAIL_CALL; ++iteration) {
if (!read_message_header(pkt, &header)) {
break;
}
if (header.message_tag == POSTGRES_COMMAND_COMPLETE_MAGIC_BYTE) {
handle_command_complete(&conn_tuple, transaction);
found_command_complete = true;
break;
}
// We didn't find a command complete message, so we advance the data offset to the end of the message.
Expand All @@ -237,19 +241,23 @@ static __always_inline bool handle_response(pktbuf_t pkt, conn_tuple_t conn_tupl
pktbuf_advance(pkt, header.message_len + 1);
}

iteration_value->iteration += 1;
iteration_value->data_off = pktbuf_data_offset(pkt);
pktbuf_tail_call_option_t handle_response_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
pktbuf_tail_call_compact(pkt, handle_response_tail_call_array);
if (!found_command_complete) {
// We didn't find a command complete message, so we need to continue processing the packet.
// We save the current data offset and increment the iteration counter.
iteration_value->iteration += 1;
iteration_value->data_off = pktbuf_data_offset(pkt);
pktbuf_tail_call_option_t handle_response_tail_call_array[] = {
[PKTBUF_SKB] = {
.prog_array_map = &protocols_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
[PKTBUF_TLS] = {
.prog_array_map = &tls_process_progs,
.index = PROG_POSTGRES_HANDLE_RESPONSE,
},
};
pktbuf_tail_call_compact(pkt, handle_response_tail_call_array);
}
return 0;
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/network/protocols/ebpf_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ const (
layerEncryptionBit = C.LAYER_ENCRYPTION_BIT
)

// Represents the maximum number of messages that can be processed in our Postgres decoding solution.
const (
// PostgresMaxMessagesPerTailCall is the maximum number of messages that can be processed in a single tail call in our Postgres decoding solution
PostgresMaxMessagesPerTailCall = C.POSTGRES_MAX_MESSAGES_PER_TAIL_CALL
PostgresMaxTailCalls = C.POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES
// PostgresMaxTailCalls is the maximum number of tail calls that can be made in our Postgres decoding solution
PostgresMaxTailCalls = C.POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES
)

// DispatcherProgramType is a C type to represent the eBPF programs used for tail calls.
Expand Down
3 changes: 2 additions & 1 deletion pkg/network/protocols/ebpf_types_linux.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 4 additions & 16 deletions pkg/network/usm/postgres_monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,7 @@ func testDecoding(t *testing.T, isTLS bool) {
ctx.extras["pg"] = pg
},
postMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
EnableTLS: isTLS,
})
require.NoError(t, err)
pg := ctx.extras["pg"].(*postgres.PGXClient)
require.NoError(t, pg.Ping())
ctx.extras["pg"] = pg
require.NoError(t, pg.RunQuery(createTableQuery))
Expand Down Expand Up @@ -589,7 +585,7 @@ func testDecoding(t *testing.T, isTLS bool) {
},
// The purpose of this test is to validate the POSTGRES_MAX_MESSAGES_PER_TAIL_CALL * POSTGRES_MAX_TAIL_CALLS_FOR_MAX_MESSAGES limit.
{
name: "validate supporting POSTGRES_MAX_MESSAGES_PER_TAIL_CALL limit",
name: "validate supporting max supported messages limit",
preMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
Expand All @@ -600,11 +596,7 @@ func testDecoding(t *testing.T, isTLS bool) {
ctx.extras["pg"] = pg
},
postMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
EnableTLS: isTLS,
})
require.NoError(t, err)
pg := ctx.extras["pg"].(*postgres.PGXClient)
require.NoError(t, pg.Ping())
ctx.extras["pg"] = pg
require.NoError(t, pg.RunQuery(createTableQuery))
Expand Down Expand Up @@ -636,11 +628,7 @@ func testDecoding(t *testing.T, isTLS bool) {
ctx.extras["pg"] = pg
},
postMonitorSetup: func(t *testing.T, ctx pgTestContext) {
pg, err := postgres.NewPGXClient(postgres.ConnectionOptions{
ServerAddress: ctx.serverAddress,
EnableTLS: isTLS,
})
require.NoError(t, err)
pg := ctx.extras["pg"].(*postgres.PGXClient)
require.NoError(t, pg.Ping())
ctx.extras["pg"] = pg
require.NoError(t, pg.RunQuery(createTableQuery))
Expand Down

0 comments on commit 89b5b95

Please sign in to comment.