diff --git a/changelog.d/21286-fix-logstash-frame-decode.fix.md b/changelog.d/21286-fix-logstash-frame-decode.fix.md new file mode 100644 index 00000000000000..51349b5c8d2764 --- /dev/null +++ b/changelog.d/21286-fix-logstash-frame-decode.fix.md @@ -0,0 +1 @@ +Fixed the `logstash` source decoder to remove a panic that could trigger in the event of certain protocol errors. diff --git a/src/sources/logstash.rs b/src/sources/logstash.rs index 15ab21fc7d92b0..e049ee21eef064 100644 --- a/src/sources/logstash.rs +++ b/src/sources/logstash.rs @@ -520,148 +520,174 @@ impl Decoder for LogstashDecoder { } // https://github.com/logstash-plugins/logstash-input-beats/blob/master/PROTOCOL.md#data-frame-type LogstashDecoderReadState::ReadFrame(protocol, LogstashFrameType::Data) => { - let mut rest = src.as_ref(); - - if rest.remaining() < 8 { + let Some(frame) = decode_data_frame(protocol, src) else { return Ok(None); - } - let sequence_number = rest.get_u32(); - let pair_count = rest.get_u32(); - - let mut fields = BTreeMap::::new(); - for _ in 0..pair_count { - if src.remaining() < 4 { - return Ok(None); - } - let key_length = rest.get_u32() as usize; - - if rest.remaining() < key_length { - return Ok(None); - } - let (key, right) = rest.split_at(key_length); - rest = right; - - if src.remaining() < 4 { - return Ok(None); - } - let value_length = rest.get_u32() as usize; - if rest.remaining() < value_length { - return Ok(None); - } - let (value, right) = rest.split_at(value_length); - rest = right; - - fields.insert( - String::from_utf8_lossy(key).into(), - String::from_utf8_lossy(value).into(), - ); - } - - let remaining = rest.remaining(); - let byte_size = src.remaining() - remaining; - - src.advance(byte_size); + }; - let frames = vec![( - LogstashEventFrame { - protocol, - sequence_number, - fields, - }, - byte_size, - )] - .into(); - - LogstashDecoderReadState::PendingFrames(frames) + LogstashDecoderReadState::PendingFrames([frame].into()) } // 
https://github.com/logstash-plugins/logstash-input-beats/blob/master/PROTOCOL.md#json-frame-type LogstashDecoderReadState::ReadFrame(protocol, LogstashFrameType::Json) => { - let mut rest = src.as_ref(); - - if rest.remaining() < 8 { + let Some(frame) = decode_json_frame(protocol, src)? else { return Ok(None); - } - let sequence_number = rest.get_u32(); - let payload_size = rest.get_u32() as usize; + }; - if rest.remaining() < payload_size { + LogstashDecoderReadState::PendingFrames([frame].into()) + } + // https://github.com/logstash-plugins/logstash-input-beats/blob/master/PROTOCOL.md#compressed-frame-type + LogstashDecoderReadState::ReadFrame(_protocol, LogstashFrameType::Compressed) => { + let Some(frames) = decode_compressed_frame(src)? else { return Ok(None); - } + }; - let (slice, right) = rest.split_at(payload_size); - rest = right; + LogstashDecoderReadState::PendingFrames(frames) + } + }; + } + } +} - let fields_result: Result, _> = - serde_json::from_slice(slice).context(JsonFrameFailedDecodeSnafu {}); +/// Decode the Lumberjack version 1 protocol, which uses the Key:Value format. 
+fn decode_data_frame( + protocol: LogstashProtocolVersion, + src: &mut BytesMut, +) -> Option<(LogstashEventFrame, usize)> { + let mut rest = src.as_ref(); - let remaining = rest.remaining(); - let byte_size = src.remaining() - remaining; + if rest.remaining() < 8 { + return None; + } + let sequence_number = rest.get_u32(); + let pair_count = rest.get_u32(); + if pair_count == 0 { + return None; // Invalid number of fields + } - src.advance(byte_size); + let mut fields = BTreeMap::::new(); + for _ in 0..pair_count { + let (key, value, right) = decode_pair(rest)?; + rest = right; - match fields_result { - Ok(fields) => { - let frames = vec![( - LogstashEventFrame { - protocol, - sequence_number, - fields, - }, - byte_size, - )] - .into(); + fields.insert( + String::from_utf8_lossy(key).into(), + String::from_utf8_lossy(value).into(), + ); + } - LogstashDecoderReadState::PendingFrames(frames) - } - Err(err) => return Err(err), - } - } - // https://github.com/logstash-plugins/logstash-input-beats/blob/master/PROTOCOL.md#compressed-frame-type - LogstashDecoderReadState::ReadFrame(_protocol, LogstashFrameType::Compressed) => { - let mut rest = src.as_ref(); + let byte_size = bytes_remaining(src, rest); + src.advance(byte_size); - if rest.remaining() < 4 { - return Ok(None); - } - let payload_size = rest.get_u32() as usize; + Some(( + LogstashEventFrame { + protocol, + sequence_number, + fields, + }, + byte_size, + )) +} - if rest.remaining() < payload_size { - src.reserve(payload_size); - return Ok(None); - } +fn decode_pair(mut rest: &[u8]) -> Option<(&[u8], &[u8], &[u8])> { + if rest.remaining() < 4 { + return None; + } + let key_length = rest.get_u32() as usize; - let (slice, right) = rest.split_at(payload_size); - rest = right; + if rest.remaining() < key_length { + return None; + } + let (key, right) = rest.split_at(key_length); + rest = right; - let mut buf = { - let mut buf = Vec::new(); + if rest.remaining() < 4 { + return None; + } + let value_length = 
rest.get_u32() as usize; + if rest.remaining() < value_length { + return None; + } + let (value, right) = rest.split_at(value_length); + Some((key, value, right)) +} - let res = ZlibDecoder::new(io::Cursor::new(slice)) - .read_to_end(&mut buf) - .context(DecompressionFailedSnafu) - .map(|_| BytesMut::from(&buf[..])); +fn decode_json_frame( + protocol: LogstashProtocolVersion, + src: &mut BytesMut, +) -> Result, DecodeError> { + let mut rest = src.as_ref(); - let remaining = rest.remaining(); - let byte_size = src.remaining() - remaining; + if rest.remaining() < 8 { + return Ok(None); + } + let sequence_number = rest.get_u32(); + let payload_size = rest.get_u32() as usize; - src.advance(byte_size); + if rest.remaining() < payload_size { + return Ok(None); + } - res - }?; + let (slice, right) = rest.split_at(payload_size); + rest = right; - let mut decoder = LogstashDecoder::new(); + let fields: BTreeMap = + serde_json::from_slice(slice).context(JsonFrameFailedDecodeSnafu {})?; - let mut frames = VecDeque::new(); + let byte_size = bytes_remaining(src, rest); + src.advance(byte_size); - while let Some(s) = decoder.decode(&mut buf)? 
{ - frames.push_back(s); - } + Ok(Some(( + LogstashEventFrame { + protocol, + sequence_number, + fields, + }, + byte_size, + ))) +} - LogstashDecoderReadState::PendingFrames(frames) - } - }; - } +fn decode_compressed_frame( + src: &mut BytesMut, +) -> Result>, DecodeError> { + let mut rest = src.as_ref(); + + if rest.remaining() < 4 { + return Ok(None); } + let payload_size = rest.get_u32() as usize; + + if rest.remaining() < payload_size { + src.reserve(payload_size); + return Ok(None); + } + + let (slice, right) = rest.split_at(payload_size); + rest = right; + + let mut buf = Vec::new(); + + let res = ZlibDecoder::new(io::Cursor::new(slice)) + .read_to_end(&mut buf) + .context(DecompressionFailedSnafu) + .map(|_| BytesMut::from(&buf[..])); + + let byte_size = bytes_remaining(src, rest); + src.advance(byte_size); + + let mut buf = res?; + + let mut decoder = LogstashDecoder::new(); + + let mut frames = VecDeque::new(); + + while let Some(s) = decoder.decode(&mut buf)? { + frames.push_back(s); + } + Ok(Some(frames)) +} + +fn bytes_remaining(src: &BytesMut, rest: &[u8]) -> usize { + let remaining = rest.remaining(); + src.remaining() - remaining } impl From for Event { @@ -685,6 +711,7 @@ impl From for SmallVec<[Event; 1]> { #[cfg(test)] mod test { use bytes::BufMut; + use futures::Stream; use rand::{thread_rng, Rng}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use vector_lib::lookup::OwnedTargetPath; @@ -715,26 +742,32 @@ mod test { test_protocol(EventStatus::Rejected, false).await; } + async fn start_logstash( + status: EventStatus, + ) -> (SocketAddr, impl Stream + Unpin) { + let (sender, recv) = SourceSender::new_test_finalize(status); + let address = next_addr(); + let source = LogstashConfig { + address: address.into(), + tls: None, + permit_origin: None, + keepalive: None, + receive_buffer_bytes: None, + acknowledgements: true.into(), + connection_limit: None, + log_namespace: None, + } + .build(SourceContext::new_test(sender, None)) + .await + .unwrap(); + 
tokio::spawn(source); + wait_for_tcp(address).await; + (address, recv) + } + async fn test_protocol(status: EventStatus, sends_ack: bool) { let events = assert_source_compliance(&SOCKET_PUSH_SOURCE_TAGS, async { - let (sender, recv) = SourceSender::new_test_finalize(status); - let address = next_addr(); - let source = LogstashConfig { - address: address.into(), - tls: None, - permit_origin: None, - keepalive: None, - receive_buffer_bytes: None, - acknowledgements: true.into(), - connection_limit: None, - log_namespace: None, - } - .build(SourceContext::new_test(sender, None)) - .await - .unwrap(); - tokio::spawn(source); - wait_for_tcp(address).await; - + let (address, recv) = start_logstash(status).await; spawn_collect_n( send_req(address, &[("message", "Hello, world!")], sends_ack), recv, @@ -773,6 +806,18 @@ mod test { req.into() } + #[test] + fn v1_decoder_does_not_panic() { + let seq = thread_rng().gen_range(1..u32::MAX); + let req = encode_req(seq, &[("message", "Hello, World!")]); + for i in 0..req.len() - 1 { + assert!( + decode_data_frame(LogstashProtocolVersion::V1, &mut BytesMut::from(&req[..i])) + .is_none() + ); + } + } + async fn send_req(address: SocketAddr, pairs: &[(&str, &str)], sends_ack: bool) { let seq = thread_rng().gen_range(1..u32::MAX); let mut socket = tokio::net::TcpStream::connect(address).await.unwrap();