diff --git a/src/stirling/source_connectors/socket_tracer/protocols/tls/parse.cc b/src/stirling/source_connectors/socket_tracer/protocols/tls/parse.cc index 93d90cfa6ea..4e5b5e788cb 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/tls/parse.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/tls/parse.cc @@ -31,6 +31,8 @@ namespace stirling { namespace protocols { namespace tls { +using px::utils::JSONObjectBuilder; + constexpr size_t kTLSRecordHeaderLength = 5; constexpr size_t kExtensionMinimumLength = 4; constexpr size_t kSNIExtensionMinimumLength = 3; @@ -39,11 +41,9 @@ constexpr size_t kSNIExtensionMinimumLength = 3; // In TLS 1.2 and earlier, gmt_unix_time is 4 bytes and Random is 28 bytes. constexpr size_t kRandomStructLength = 32; -StatusOr ExtractSNIExtension(std::map* exts, - BinaryDecoder* decoder) { +StatusOr ExtractSNIExtension(ReqExtensions* exts, BinaryDecoder* decoder) { PX_ASSIGN_OR(auto server_name_list_length, decoder->ExtractBEInt(), return ParseState::kInvalid); - std::vector server_names; while (server_name_list_length > 0) { PX_ASSIGN_OR(auto server_name_type, decoder->ExtractBEInt(), return error::Internal("Failed to extract server name type")); @@ -56,10 +56,9 @@ StatusOr ExtractSNIExtension(std::map* ext PX_ASSIGN_OR(auto server_name, decoder->ExtractString(server_name_length), return error::Internal("Failed to extract server name")); - server_names.push_back(std::string(server_name)); + exts->server_names.push_back(std::string(server_name)); server_name_list_length -= kSNIExtensionMinimumLength + server_name_length; } - exts->insert({"server_name", ToJSONString(server_names)}); return ParseState::kSuccess; } @@ -162,6 +161,8 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) { return ParseState::kSuccess; } + ReqExtensions req_extensions; + RespExtensions resp_extensions; while (extensions_length > 0) { PX_ASSIGN_OR(auto extension_type, decoder->ExtractBEInt(), return ParseState::kInvalid); @@ -170,7 +171,7 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) { if (extension_length > 0) { if (extension_type == 0x00) { - if (!ExtractSNIExtension(&frame->extensions, decoder).ok()) { + if (!ExtractSNIExtension(&req_extensions, decoder).ok()) { return ParseState::kInvalid; } } else { @@ -182,6 +183,13 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) { extensions_length -= kExtensionMinimumLength + extension_length; } + JSONObjectBuilder req_body_builder; + req_body_builder.WriteKVRecursive("extensions", req_extensions); + frame->req_body = req_body_builder.GetString(); + + JSONObjectBuilder resp_body_builder; + resp_body_builder.WriteKVRecursive("extensions", resp_extensions); + frame->resp_body = resp_body_builder.GetString(); return ParseState::kSuccess; } diff --git a/src/stirling/source_connectors/socket_tracer/protocols/tls/parse_test.cc b/src/stirling/source_connectors/socket_tracer/protocols/tls/parse_test.cc index bbffb9618f7..e1847f0b2cb 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/tls/parse_test.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/tls/parse_test.cc @@ -315,8 +315,7 @@ TEST_F(TLSParserTest, ParseValidClientHello) { ASSERT_GT(frame.session_id.size(), 0); // Validate the SNI extension was parsed properly - ASSERT_EQ(frame.extensions.size(), 1); - ASSERT_EQ(frame.extensions["server_name"], "[\"argocd-cluster-repo-server\"]"); + ASSERT_EQ(frame.req_body, R"({"extensions":{"server_name":["argocd-cluster-repo-server"]}})"); ASSERT_EQ(state, ParseState::kSuccess); } diff --git a/src/stirling/source_connectors/socket_tracer/protocols/tls/types.h b/src/stirling/source_connectors/socket_tracer/protocols/tls/types.h index c7f8a785fb9..2ffa2bc3e96 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/tls/types.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/tls/types.h @@ -43,8 +43,6 @@ namespace stirling { namespace protocols { namespace tls { -using ::px::utils::ToJSONString; - enum class ContentType : uint8_t { kChangeCipherSpec = 0x14, kAlert = 0x15, @@ -186,6 +184,25 @@ enum class ExtensionType : uint16_t { kRenegotiationInfo = 65281, }; +// Extensions that are common to both the client and server side +// of a TLS handshake +struct SharedExtensions { + void ToJSON(::px::utils::JSONObjectBuilder* /*builder*/) const {} +}; + +struct ReqExtensions : public SharedExtensions { + std::vector server_names; + + void ToJSON(::px::utils::JSONObjectBuilder* builder) const { + SharedExtensions::ToJSON(builder); + builder->WriteKV("server_name", server_names); + } +}; + +struct RespExtensions : public SharedExtensions { + void ToJSON(::px::utils::JSONObjectBuilder* builder) const { SharedExtensions::ToJSON(builder); } +}; + struct Frame : public FrameBase { ContentType content_type; @@ -200,7 +217,8 @@ struct Frame : public FrameBase { LegacyVersion handshake_version; std::string session_id; - std::map extensions; + std::string req_body; + std::string resp_body; bool consumed = false; @@ -209,9 +227,9 @@ struct Frame : public FrameBase { std::string ToString() const override { return absl::Substitute( "TLS Frame [len=$0 content_type=$1 legacy_version=$2 handshake_version=$3 " - "handshake_type=$4 extensions=$5]", - length, content_type, legacy_version, handshake_version, handshake_type, - ToJSONString(extensions)); + "handshake_type=$4 req_body=$5 resp_body=$6]", + length, content_type, legacy_version, handshake_version, handshake_type, req_body, + resp_body); } }; diff --git a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc index 3fd774d09ea..62dbb535d5a 100644 --- a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc +++ b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc @@ -203,12 +203,12 @@ using px::utils::ToJSONString; // Most HTTP servers support 8K headers, so we truncate after that. // https://stackoverflow.com/questions/686217/maximum-on-http-header-values constexpr size_t kMaxHTTPHeadersBytes = 8192; -// TLS records have a maximum size of 16KiB. While there isn't a size limit -// for the extensions, we limit it to 1 KiB to avoid excessive memory usage. -// A typical ClientHello from curl is around 500 bytes. This assumes that +// TLS records have a maximum size of 16KiB. The bulk of the body columns are extensions +// and while there isn't a size limit for them, we limit it to 1 KiB to avoid excessive +// memory usage. A typical ClientHello from curl is around 500 bytes. This assumes that // all extensions are captured, but we won't support capturing all extensions and // will avoid large extensions like the padding extension, -constexpr size_t kMaxTLSExtensionsBytes = 1024; +constexpr size_t kMaxTLSBodyBytes = 1024; // Protobuf printer will limit strings to this length. constexpr size_t kMaxPBStringLen = 64; @@ -1721,9 +1721,10 @@ void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracke r.Append(conn_tracker.local_endpoint().AddrStr()); r.Append(conn_tracker.local_endpoint().port()); r.Append(conn_tracker.role()); - r.Append(static_cast(req_message.content_type)); + r.Append(static_cast(req_message.content_type)); r.Append(static_cast(req_message.legacy_version)); - r.Append(ToJSONString(req_message.extensions), kMaxTLSExtensionsBytes); + r.Append(req_message.req_body, kMaxTLSBodyBytes); + r.Append(resp_message.resp_body, kMaxTLSBodyBytes); r.Append( CalculateLatency(req_message.timestamp_ns, resp_message.timestamp_ns)); #ifndef NDEBUG diff --git a/src/stirling/source_connectors/socket_tracer/tls_table.h b/src/stirling/source_connectors/socket_tracer/tls_table.h index 865fee39944..9f762dca384 100644 --- a/src/stirling/source_connectors/socket_tracer/tls_table.h +++ b/src/stirling/source_connectors/socket_tracer/tls_table.h @@ -37,7 +37,7 @@ static constexpr DataElement kTLSElements[] = { canonical_data_elements::kLocalAddr, canonical_data_elements::kLocalPort, canonical_data_elements::kTraceRole, - {"req_type", "The type of request from the TLS record (Client/ServerHello, etc.)", + {"content_type", "The content type of the TLS record (e.g. handshake, alert, heartbeat, etc)", types::DataType::INT64, types::SemanticType::ST_NONE, types::PatternType::GENERAL_ENUM}, @@ -45,10 +45,14 @@ static constexpr DataElement kTLSElements[] = { types::DataType::INT64, types::SemanticType::ST_NONE, types::PatternType::GENERAL_ENUM}, - {"extensions", "Extensions in the TLS record", + {"req_body", "Request body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions)", types::DataType::STRING, types::SemanticType::ST_NONE, - types::PatternType::GENERAL}, + types::PatternType::STRUCTURED}, + {"resp_body", "Response body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions)", + types::DataType::STRING, + types::SemanticType::ST_NONE, + types::PatternType::STRUCTURED}, canonical_data_elements::kLatencyNS, #ifndef NDEBUG canonical_data_elements::kPXInfo, @@ -61,9 +65,9 @@ static constexpr auto kTLSTable = DEFINE_PRINT_TABLE(TLS) constexpr int kTLSUPIDIdx = kTLSTable.ColIndex("upid"); -constexpr int kTLSCmdIdx = kTLSTable.ColIndex("req_type"); +constexpr int kTLSCmdIdx = kTLSTable.ColIndex("content_type"); constexpr int kTLSVersionIdx = kTLSTable.ColIndex("version"); -constexpr int kTLSExtensionsIdx = kTLSTable.ColIndex("extensions"); +constexpr int kTLSReqBodyIdx = kTLSTable.ColIndex("req_body"); } // namespace stirling } // namespace px diff --git a/src/stirling/source_connectors/socket_tracer/tls_trace_bpf_test.cc b/src/stirling/source_connectors/socket_tracer/tls_trace_bpf_test.cc index e2cebdc892b..03511c06b48 100644 --- a/src/stirling/source_connectors/socket_tracer/tls_trace_bpf_test.cc +++ b/src/stirling/source_connectors/socket_tracer/tls_trace_bpf_test.cc @@ -50,7 +50,7 @@ using ::testing::UnorderedElementsAre; struct TraceRecords { std::vector tls_records; - std::vector tls_extensions; + std::vector req_body; }; class NginxOpenSSL_3_0_8_ContainerWrapper @@ -80,11 +80,11 @@ tls::Record GetExpectedTLSRecord() { return expected_record; } -inline std::vector GetExtensions(const types::ColumnWrapperRecordBatch& rb, - const std::vector& indices) { +inline std::vector GetRequestBody(const types::ColumnWrapperRecordBatch& rb, + const std::vector& indices) { std::vector exts; for (size_t idx : indices) { - exts.push_back(rb[kTLSExtensionsIdx]->Get(idx)); + exts.push_back(rb[kTLSReqBodyIdx]->Get(idx)); } return exts; } @@ -127,9 +127,9 @@ class TLSVersionParameterizedTest TraceRecords records = this->GetTraceRecords(this->server_.PID()); EXPECT_THAT(records.tls_records, SizeIs(1)); - EXPECT_THAT(records.tls_extensions, SizeIs(1)); - auto sni_str = R"({"server_name":"[\"test-host\"]"})"; - EXPECT_THAT(records.tls_extensions[0], StrEq(sni_str)); + EXPECT_THAT(records.req_body, SizeIs(1)); + auto sni_str = R"({"extensions":{"server_name":["test-host"]}})"; + EXPECT_THAT(records.req_body[0], StrEq(sni_str)); } // Returns the trace records of the process specified by the input pid. @@ -144,7 +144,7 @@ class TLSVersionParameterizedTest FindRecordIdxMatchesPID(record_batch, kTLSUPIDIdx, pid); std::vector tls_records = ToRecordVector(record_batch, server_record_indices); - std::vector extensions = GetExtensions(record_batch, server_record_indices); + std::vector extensions = GetRequestBody(record_batch, server_record_indices); return {std::move(tls_records), std::move(extensions)}; }