diff --git a/Cargo.toml b/Cargo.toml index acacfda..0351240 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ mc-varint = "0.1" rand = { version = "0.8", optional = true } serde = { version = "1.0", optional = true, features = ["serde_derive"] } serde_json = { version = "1.0", optional = true } -snafu = { version = "0.7", features = ["backtraces-impl-backtrace-crate"] } +snafu = { version = "0.8.1", features = ["backtraces-impl-backtrace-crate"] } tokio = { version = "1.21", features = [ "net", "io-util", @@ -35,7 +35,6 @@ tokio = { version = "1.21", features = [ ], optional = true } tracing = "0.1" trust-dns-resolver = { version = "0.23", optional = true } -#void = { version = "1.0", optional = true } [dev-dependencies] ctor = "0.2.4" diff --git a/src/lib.rs b/src/lib.rs index f90af07..ac799b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,12 @@ +use snafu::{Backtrace, Snafu}; use std::time::Duration; -#[cfg(feature = "java_connect")] -use tokio::net::{lookup_host, TcpStream}; -use tracing::{debug, info, instrument}; #[cfg(feature = "java_connect")] pub mod mc_string; #[cfg(feature = "java_connect")] pub mod protocol; #[cfg(feature = "java_connect")] -use crate::protocol::{ping_error, protocol_error, ProtocolError}; -#[cfg(feature = "simple")] -pub use protocol::PingError; +pub use crate::protocol::connect; #[cfg(feature = "java_connect")] pub use protocol::SlpProtocol; @@ -22,46 +18,16 @@ pub use parse::JavaServerInfo; #[cfg(feature = "bedrock")] pub mod bedrock; -#[cfg(feature = "java_connect")] -#[instrument] -pub async fn connect(mut addrs: (String, u16)) -> Result { - use tracing::debug; - use trust_dns_resolver::TokioAsyncResolver; - - let resolver = TokioAsyncResolver::tokio_from_system_conf()?; - if let Ok(records) = resolver - .srv_lookup(format!("_minecraft._tcp.{}", addrs.0)) - .await - { - if let Some(record) = records.iter().next() { - let record = record.target().to_utf8(); - debug!("Found SRV record: {} -> {}", addrs.0, record); - addrs.0 = record; - } - } - - // lookup_host can return multiple but we just need one so we discard the rest - let socket_addrs = match lookup_host(addrs.clone()).await?.next() { - Some(socket_addrs) => socket_addrs, - None => { - info!("DNS lookup failed for address"); - return Err(protocol_error::DNSLookupFailedSnafu { - address: format!("{:?}", addrs), - } - .build()); - } - }; - - match TcpStream::connect(socket_addrs).await { - Ok(stream) => { - info!("Connected to SLP server"); - Ok(SlpProtocol::new(addrs.0, addrs.1, stream)) - } - Err(error) => { - info!("Failed to connect to SLP server: {}", error); - Err(error.into()) - } - } +#[cfg(feature = "simple")] +#[derive(Snafu, Debug)] +pub enum PingError { + #[snafu(display("connection failed: {source}"), context(false))] + Protocol { + #[snafu(backtrace)] + source: crate::protocol::ProtocolError, + }, + #[snafu(display("connection did not respond in time"))] + Timeout { backtrace: Backtrace }, } #[cfg(feature = "simple")] @@ -86,7 +52,7 @@ pub async fn ping_or_timeout( select! { biased; info = ping(addrs) => info, - _ = sleep => Err(ping_error::TimeoutSnafu.build()), + _ = sleep => TimeoutSnafu.fail(), } } diff --git a/src/mc_string.rs b/src/mc_string.rs index b624589..b412cce 100644 --- a/src/mc_string.rs +++ b/src/mc_string.rs @@ -1,54 +1,53 @@ use bytes::Buf; use mc_varint::{VarInt, VarIntRead, VarIntWrite}; +use snafu::{OptionExt, Snafu}; use std::io::Cursor; -mod error { - use super::*; - use snafu::Snafu; - - #[derive(Snafu, Debug)] - pub enum McStringError { - #[snafu(display("io error: {source}"))] - Io { source: std::io::Error }, - #[snafu(display( - "string is too long (is {length} bytes, but expected less than {} bytes)", - MAX_LEN - ))] - TooLong { length: usize }, - #[snafu(display("invalid string format"))] - InvalidFormat, - } +#[derive(Snafu, Debug)] +pub enum McStringError { + #[snafu(display("io error: {source}"), context(false))] + Io { + source: std::io::Error, + backtrace: snafu::Backtrace, + }, + #[snafu(display( + "string is too long (is {length} bytes, but expected less than {} bytes)", + MAX_LEN + ))] + TooLong { + length: usize, + backtrace: snafu::Backtrace, + }, + #[snafu(display("invalid string format"))] + InvalidFormat { backtrace: snafu::Backtrace }, } -pub use error::McStringError; - pub const MAX_LEN: i32 = i32::MAX; pub fn encode_mc_string(string: &str) -> Result, McStringError> { let len = string.len(); // VarInt max length is 5 bytes let mut bytes = Vec::with_capacity(len + 5); - bytes - .write_var_int(VarInt::from( - i32::try_from(len) - .ok() - .ok_or(McStringError::TooLong { length: len })?, - )) - .map_err(|io| McStringError::Io { source: io })?; + bytes.write_var_int(VarInt::from( + i32::try_from(len) + .ok() + .context(TooLongSnafu { length: len })?, + ))?; bytes.extend_from_slice(string.as_bytes()); Ok(bytes) } pub fn decode_mc_string(cursor: &mut Cursor<&[u8]>) -> Result { - let len: i32 = cursor - .read_var_int() - .map_err(|io| McStringError::Io { source: io })? - .into(); - let len = usize::try_from(len).map_err(|_| McStringError::InvalidFormat)?; + let len: i32 = cursor.read_var_int()?.into(); + let len = usize::try_from(len).ok().context(InvalidFormatSnafu)?; let bytes = cursor.chunk(); - let string = std::str::from_utf8(&bytes[..len]) - .map_err(|_| McStringError::InvalidFormat)? + if len > bytes.len() { + return InvalidFormatSnafu.fail(); + } + let string = std::str::from_utf8(bytes.get(..len).context(InvalidFormatSnafu)?) + .ok() + .context(InvalidFormatSnafu)? .to_string(); cursor.advance(len); Ok(string) diff --git a/src/protocol.rs b/src/protocol.rs index b141675..ec5be4d 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,11 +1,13 @@ +pub use self::frame::{Frame, FrameError, ServerState}; +use crate::mc_string::encode_mc_string; use crate::mc_string::McStringError; -use crate::mc_string::{decode_mc_string, encode_mc_string}; #[cfg(feature = "java_parse")] use crate::parse::JavaServerInfo; use bytes::{Buf, BytesMut}; -use mc_varint::{VarInt, VarIntRead, VarIntWrite}; +use mc_varint::{VarInt, VarIntWrite}; +use snafu::OptionExt; use snafu::{Backtrace, GenerateImplicitData, Snafu}; -use std::array::TryFromSliceError; +use std::str::FromStr; use std::{ fmt::Debug, io::{Cursor, Write}, @@ -17,95 +19,49 @@ use tokio::{ }; use tracing::{debug, event, instrument, trace, Level}; -// module adapted from tokio's mini-redis -// which is licensed here: https://github.com/tokio-rs/mini-redis/blob/cefca5377af54520904c55764d16fc7c0a291902/LICENSE +mod frame; #[derive(Snafu, Debug)] -#[snafu(visibility(pub(crate)), module)] pub enum ProtocolError { #[snafu(display("io error: {source}"), context(false))] Io { source: std::io::Error, backtrace: Backtrace, }, - #[snafu(display("dns lookup failed for address `{address}`"))] - DNSLookupFailed { - address: String, - backtrace: Backtrace, - }, #[snafu(display("failed to encode string as bytes: {source}"), context(false))] StringEncodeFailed { + #[snafu(backtrace)] source: McStringError, - backtrace: Backtrace, }, #[snafu(display( "failed to send packet because it is too long (more than {} bytes)", i32::MAX ))] PacketTooLong { backtrace: Backtrace }, - #[snafu(display("connection closed before packet finished being read"))] + #[snafu(display("connection closed unexpectedly"))] ConnectionClosed { backtrace: Backtrace }, #[snafu(display("failed to parse packet: {source}"), context(false))] ParseFailed { + #[snafu(backtrace)] source: FrameError, - backtrace: Backtrace, }, #[snafu(display("srv resolver creation failed: {source}"), context(false))] SrvResolveError { source: trust_dns_resolver::error::ResolveError, backtrace: Backtrace, }, -} - -#[derive(Snafu, Debug)] -#[snafu(visibility(pub(crate)), module)] -pub enum FrameError { - #[snafu(display("frame is missing data"))] - Incomplete { backtrace: Backtrace }, - #[snafu(display("io error: {source}"), context(false))] - Io { - source: std::io::Error, - backtrace: Backtrace, - }, - #[snafu(display("frame declares it has negative length"))] - InvalidLength { backtrace: Backtrace }, - #[snafu(display("cannot parse frame with id {id}"))] - InvalidFrame { id: i32, backtrace: Backtrace }, - #[snafu(display("failed to decode string: {source}"), context(false))] - StringDecodeFailed { - source: McStringError, - backtrace: Backtrace, - }, - #[snafu( - display("failed to decode ping response payload: {source}"), - context(false) - )] - PingResponseDecodeFailed { - source: TryFromSliceError, - backtrace: Backtrace, - }, -} - -#[cfg(feature = "simple")] -#[derive(Snafu, Debug)] -#[snafu(visibility(pub(crate)), module)] -pub enum PingError { - #[snafu(display("connection failed"), context(false))] - Protocol { - source: ProtocolError, + #[snafu(display("packet received out of order"))] + FrameOutOfOrder { backtrace: Backtrace }, + #[snafu(display("failed to parse server response: {source}"), context(false))] + JsonParse { + source: serde_json::Error, backtrace: Backtrace, }, - #[snafu(display("connection closed"))] - ConnectionClosed { backtrace: Backtrace }, - #[snafu(display("invalid response from server"))] - InvalidResponse { backtrace: Backtrace }, - #[snafu(display("failed to parse server response"), context(false))] - Parse { - source: serde_json::Error, + #[snafu(display("dns lookup failed for address `{address}`"))] + DNSLookupFailed { + address: String, backtrace: Backtrace, }, - #[snafu(display("server did not respond in time"))] - Timeout { backtrace: Backtrace }, } #[derive(Debug)] @@ -116,138 +72,6 @@ pub struct SlpProtocol { buffer: BytesMut, } -#[derive(Debug)] -#[non_exhaustive] -pub enum Frame { - Handshake { - protocol: VarInt, - address: String, - port: u16, - // should be 1 for status - state: VarInt, - }, - StatusRequest, - StatusResponse { - json: String, - }, - PingRequest { - payload: i64, - }, - PingResponse { - payload: i64, - }, -} - -/// Controls what packets a server can recieve -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ServerState { - /// Waiting for the Handshake packet - Handshake, - /// Ready to respond to status and ping requests - Status, -} - -impl Frame { - pub const PROTOCOL_VERSION: i32 = 754; - pub const HANDSHAKE_ID: i32 = 0x00; - pub const STATUS_REQUEST_ID: i32 = 0x00; - pub const STATUS_RESPONSE_ID: i32 = 0x00; - pub const PING_REQUEST_ID: i32 = 0x01; - pub const PING_RESPONSE_ID: i32 = 0x01; - - /// Checks if an entire message can be decoded from `buf`, advancing the cursor past the header - pub fn check(buf: &mut Cursor<&[u8]>) -> Result<(), FrameError> { - let available_data = buf.get_ref().len(); - - // the varint at the beginning contains the size of the rest of the frame - let remaining_data_len: usize = - i32::from(buf.read_var_int().map_err(|_| FrameError::Incomplete { - backtrace: Backtrace::generate(), - })?) - .try_into() - .map_err(|_| FrameError::InvalidLength { - backtrace: Backtrace::generate(), - })?; - let header_len = buf.position() as usize; - let total_len = header_len + remaining_data_len; - - // if we don't have enough data the frame isn't valid yet - let is_valid = available_data >= total_len; - - if is_valid { - trace!("Valid frame, packet size: {total_len}, header size: {header_len}, body size: {remaining_data_len}, downloaded: {available_data}"); - Ok(()) - } else { - trace!("Invalid frame, packet size: {total_len}, downloaded: {available_data}"); - Err(FrameError::Incomplete { - backtrace: Backtrace::generate(), - }) - } - } - - /// Parse the body of a frame, after the message has already been validated with `check`. - /// - /// # Arguments - /// - /// * `src` - The buffer containing the message - /// * `server_state` - Switches between which type of frame to accept. Set to None to accept frames for the client. - pub fn parse( - cursor: &mut Cursor<&[u8]>, - server_state: Option, - ) -> Result { - let id = i32::from(cursor.read_var_int()?); - - match server_state { - Some(ServerState::Handshake) => { - if id == Self::HANDSHAKE_ID { - let protocol = cursor.read_var_int()?; - let address = decode_mc_string(cursor)?; - let port = cursor.get_u16(); - let state = cursor.read_var_int()?; - return Ok(Frame::Handshake { - protocol, - address, - port, - state, - }); - } - } - Some(ServerState::Status) => { - match id { - Self::STATUS_REQUEST_ID => { - return Ok(Frame::StatusRequest); - } - Self::PING_REQUEST_ID => { - // ping request a contains (usually) meaningless Java long - let payload = cursor.get_i64(); - return Ok(Frame::PingRequest { payload }); - } - _ => {} - } - } - None => { - match id { - Self::STATUS_RESPONSE_ID => { - let json = decode_mc_string(cursor)?; - return Ok(Frame::StatusResponse { json }); - } - Self::PING_RESPONSE_ID => { - // ping response contains the same Java long as the request - let payload = cursor.get_i64(); - return Ok(Frame::PingResponse { payload }); - } - _ => {} - } - } - } - - Err(FrameError::InvalidFrame { - id, - backtrace: Backtrace::generate(), - }) - } -} - #[repr(i32)] pub enum ProtocolState { Status = 1, @@ -410,23 +234,21 @@ impl SlpProtocol { } #[cfg(feature = "simple")] - pub async fn get_status(&mut self) -> Result { - use std::str::FromStr; - + pub async fn get_status(&mut self) -> Result { self.write_frame(Frame::StatusRequest).await?; let frame = self .read_frame(None) .await? - .ok_or_else(|| ping_error::ConnectionClosedSnafu.build())?; + .context(ConnectionClosedSnafu)?; let frame_data = match frame { Frame::StatusResponse { json } => json, - _ => return Err(ping_error::InvalidResponseSnafu.build()), + _ => return FrameOutOfOrderSnafu.fail(), }; Ok(JavaServerInfo::from_str(&frame_data)?) } #[cfg(feature = "simple")] - pub async fn get_latency(&mut self) -> Result { + pub async fn get_latency(&mut self) -> Result { use std::time::Instant; const PING_PAYLOAD: i64 = 54321; @@ -439,10 +261,50 @@ impl SlpProtocol { let frame = self .read_frame(None) .await? - .ok_or_else(|| ping_error::ConnectionClosedSnafu.build())?; + .context(ConnectionClosedSnafu)?; match frame { Frame::PingResponse { payload: _ } => Ok(ping_time.elapsed()), - _ => Err(ping_error::InvalidResponseSnafu.build()), + _ => FrameOutOfOrderSnafu.fail(), + } + } +} + +#[cfg(feature = "java_connect")] +#[instrument] +pub async fn connect(mut addrs: (String, u16)) -> Result { + use tokio::net::lookup_host; + use tracing::{debug, info}; + use trust_dns_resolver::TokioAsyncResolver; + + let resolver = TokioAsyncResolver::tokio_from_system_conf()?; + if let Ok(records) = resolver + .srv_lookup(format!("_minecraft._tcp.{}", addrs.0)) + .await + { + if let Some(record) = records.iter().next() { + let record = record.target().to_utf8(); + debug!("Found SRV record: {} -> {}", addrs.0, record); + addrs.0 = record; + } + } + + // lookup_host can return multiple but we just need one so we discard the rest + let socket_addrs = match lookup_host(addrs.clone()).await?.next() { + Some(socket_addrs) => socket_addrs, + None => { + info!("DNS lookup failed for address"); + return DNSLookupFailedSnafu { address: addrs.0 }.fail(); + } + }; + + match TcpStream::connect(socket_addrs).await { + Ok(stream) => { + info!("Connected to SLP server"); + Ok(SlpProtocol::new(addrs.0, addrs.1, stream)) + } + Err(error) => { + info!("Failed to connect to SLP server: {}", error); + Err(error.into()) } } } diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs new file mode 100644 index 0000000..1ab7fd6 --- /dev/null +++ b/src/protocol/frame.rs @@ -0,0 +1,160 @@ +use std::{array::TryFromSliceError, io::Cursor}; + +use bytes::Buf; +use mc_varint::{VarInt, VarIntRead}; +use snafu::{Backtrace, OptionExt, Snafu}; +use tracing::trace; + +use crate::mc_string::{decode_mc_string, McStringError}; + +#[derive(Snafu, Debug)] +pub enum FrameError { + #[snafu(display("frame is missing data"))] + Incomplete { backtrace: Backtrace }, + #[snafu(display("io error: {source}"), context(false))] + Io { + source: std::io::Error, + backtrace: Backtrace, + }, + #[snafu(display("frame declares it has negative length"))] + InvalidLength { backtrace: Backtrace }, + #[snafu(display("cannot parse frame with id {id}"))] + InvalidFrame { id: i32, backtrace: Backtrace }, + #[snafu(display("failed to decode string: {source}"), context(false))] + StringDecodeFailed { + #[snafu(backtrace)] + source: McStringError, + }, + #[snafu( + display("failed to decode ping response payload: {source}"), + context(false) + )] + PingResponseDecodeFailed { + source: TryFromSliceError, + backtrace: Backtrace, + }, +} + +#[derive(Debug)] +#[non_exhaustive] +pub enum Frame { + Handshake { + protocol: VarInt, + address: String, + port: u16, + // should be 1 for status + state: VarInt, + }, + StatusRequest, + StatusResponse { + json: String, + }, + PingRequest { + payload: i64, + }, + PingResponse { + payload: i64, + }, +} + +/// Controls what packets a server can recieve +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ServerState { + /// Waiting for the Handshake packet + Handshake, + /// Ready to respond to status and ping requests + Status, +} + +impl Frame { + pub const PROTOCOL_VERSION: i32 = 754; + pub const HANDSHAKE_ID: i32 = 0x00; + pub const STATUS_REQUEST_ID: i32 = 0x00; + pub const STATUS_RESPONSE_ID: i32 = 0x00; + pub const PING_REQUEST_ID: i32 = 0x01; + pub const PING_RESPONSE_ID: i32 = 0x01; + + /// Checks if an entire message can be decoded from `buf`, advancing the cursor past the header + pub fn check(buf: &mut Cursor<&[u8]>) -> Result<(), FrameError> { + let available_data = buf.get_ref().len(); + + // the varint at the beginning contains the size of the rest of the frame + let remaining_data_len: usize = + i32::from(buf.read_var_int().ok().context(IncompleteSnafu)?) + .try_into() + .ok() + .context(InvalidLengthSnafu)?; + let header_len = buf.position() as usize; + let total_len = header_len + remaining_data_len; + + // if we don't have enough data the frame isn't valid yet + let is_valid = available_data >= total_len; + + if is_valid { + trace!("Valid frame, packet size: {total_len}, header size: {header_len}, body size: {remaining_data_len}, downloaded: {available_data}"); + Ok(()) + } else { + trace!("Invalid frame, packet size: {total_len}, downloaded: {available_data}"); + IncompleteSnafu.fail() + } + } + + /// Parse the body of a frame, after the message has already been validated with `check`. + /// + /// # Arguments + /// + /// * `src` - The buffer containing the message + /// * `server_state` - Switches between which type of frame to accept. Set to None to accept frames for the client. + pub fn parse( + cursor: &mut Cursor<&[u8]>, + server_state: Option, + ) -> Result { + let id = i32::from(cursor.read_var_int()?); + + match server_state { + Some(ServerState::Handshake) => { + if id == Self::HANDSHAKE_ID { + let protocol = cursor.read_var_int()?; + let address = decode_mc_string(cursor)?; + let port = cursor.get_u16(); + let state = cursor.read_var_int()?; + return Ok(Frame::Handshake { + protocol, + address, + port, + state, + }); + } + } + Some(ServerState::Status) => { + match id { + Self::STATUS_REQUEST_ID => { + return Ok(Frame::StatusRequest); + } + Self::PING_REQUEST_ID => { + // ping request a contains (usually) meaningless Java long + let payload = cursor.get_i64(); + return Ok(Frame::PingRequest { payload }); + } + _ => {} + } + } + None => { + match id { + Self::STATUS_RESPONSE_ID => { + let json = decode_mc_string(cursor)?; + return Ok(Frame::StatusResponse { json }); + } + Self::PING_RESPONSE_ID => { + // ping response contains the same Java long as the request + let payload = cursor.get_i64(); + return Ok(Frame::PingResponse { payload }); + } + _ => {} + } + } + } + + InvalidFrameSnafu { id }.fail() + } +}