Skip to content

Commit

Permalink
added bytes to LookupResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
joalopez1206 committed Dec 11, 2024
1 parent f70bf7c commit 340ae03
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 50 deletions.
49 changes: 25 additions & 24 deletions src/async_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ impl AsyncResolver {
// return the error, it should go to the next part of the code
};
if let Some(cache_lookup) = cache.clone().get(query.clone()) {
let new_lookup_response = LookupResponse::new(cache_lookup.clone());
let new_lookup_response = LookupResponse::new(cache_lookup.clone(), cache_lookup.to_bytes());

return Ok(new_lookup_response);
}
Expand Down Expand Up @@ -436,7 +436,8 @@ impl AsyncResolver {
Err(_) => Err(ClientError::TemporaryError("no DNS message found")),
};

let dns_response = lookup_response.unwrap().to_dns_msg();
let lookup_response = lookup_response.expect("error!");
let dns_response = lookup_response.to_dns_msg();

let key_bytes = self.config.get_key();
let shared_key_name = self.config.get_key_name();
Expand All @@ -455,7 +456,7 @@ impl AsyncResolver {
);

match rcode {
Rcode::NOERROR => Ok(LookupResponse::new(dns_response)),
Rcode::NOERROR => Ok(LookupResponse::new(dns_response, lookup_response.get_bytes())),
Rcode::FORMERR => Err(ClientError::FormatError("The name server was unable to interpret the query."))?,
Rcode::SERVFAIL => Err(ClientError::ServerFailure("The name server was unable to process this query due to a problem with the name server."))?,
Rcode::NXDOMAIN => Err(ClientError::NameError("The domain name referenced in the query does not exist."))?,
Expand Down Expand Up @@ -909,7 +910,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1138,7 +1139,7 @@ mod async_resolver_test {
header.set_qr(true);
header.set_rcode(Rcode::FORMERR);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1185,7 +1186,7 @@ mod async_resolver_test {
header.set_qr(true);
header.set_rcode(Rcode::SERVFAIL);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1231,7 +1232,7 @@ mod async_resolver_test {
header.set_qr(true);
header.set_rcode(Rcode::NXDOMAIN);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1278,7 +1279,7 @@ mod async_resolver_test {
header.set_qr(true);
header.set_rcode(Rcode::NOTIMP);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1325,7 +1326,7 @@ mod async_resolver_test {
header.set_qr(true);
header.set_rcode(Rcode::REFUSED);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1372,7 +1373,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1411,7 +1412,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_lookup = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_lookup {
Expand Down Expand Up @@ -1450,7 +1451,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1489,7 +1490,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1528,7 +1529,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1567,7 +1568,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1606,7 +1607,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1645,7 +1646,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1684,7 +1685,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1723,7 +1724,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1762,7 +1763,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1801,7 +1802,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1840,7 +1841,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1879,7 +1880,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down Expand Up @@ -1918,7 +1919,7 @@ mod async_resolver_test {
let mut header = dns_response.get_header();
header.set_qr(true);
dns_response.set_header(header);
let lookup_response = LookupResponse::new(dns_response);
let lookup_response = LookupResponse::new(dns_response, vec![]);
let result_vec_rr = resolver.check_error_from_msg(Ok(lookup_response));

if let Ok(lookup_response) = result_vec_rr {
Expand Down
49 changes: 29 additions & 20 deletions src/async_resolver/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct LookupStrategy {
/// Resolver configuration.
config: ResolverConfig,
/// Reference to the response of the query.
response_msg: Arc<std::sync::Mutex<Result<DnsMessage, ResolverError>>>,
response_msg: Arc<std::sync::Mutex<Result<(DnsMessage, Vec<u8>), ResolverError>>>,
}

impl LookupStrategy {
Expand Down Expand Up @@ -115,7 +115,7 @@ impl LookupStrategy {
// appropriate response to its caller.
pub fn received_appropriate_response(&self) -> bool {
let response_arc = self.response_msg.lock().unwrap();
if let Ok(dns_msg) = response_arc.as_ref() {
if let Ok((dns_msg, bytes)) = response_arc.as_ref() {
match dns_msg.get_header().get_rcode().into() {
Rcode::SERVFAIL => return false,
Rcode::NOTIMP => return false,
Expand Down Expand Up @@ -191,7 +191,7 @@ impl LookupStrategy {
) -> Result<LookupResponse, ResolverError> {
let response_arc= self.response_msg.clone();
let protocol = self.config.get_protocol();
let mut dns_msg_result: Result<DnsMessage, ResolverError>;
let mut dns_msg_result: Result<(DnsMessage, Vec<u8>), ResolverError>;
{
// Guard reference to modify the response
let mut response_guard = response_arc.lock().unwrap(); // TODO: add error handling
Expand All @@ -205,12 +205,16 @@ impl LookupStrategy {
.await
.unwrap_or_else(
|_| {Err(ResolverError::Message("Execute Strategy Timeout Error".into()))}
);
);
*response_guard = dns_msg_result.clone();
//*response_guard = dns_msg_result.clone();
}
if self.received_appropriate_response() {
return dns_msg_result.and_then(
|dns_msg| Ok(LookupResponse::new(dns_msg))
|dns_msg| {
let (dns_msg, bytes) = dns_msg;
Ok(LookupResponse::new(dns_msg, bytes))
}
)
}
if let ConnectionProtocol::UDP = protocol {
Expand All @@ -231,7 +235,10 @@ impl LookupStrategy {
*response_guard = dns_msg_result.clone();
}
dns_msg_result.and_then(
|dns_msg| Ok(LookupResponse::new(dns_msg))
|dns_msg| {
let (dns_msg, bytes) = dns_msg;
Ok(LookupResponse::new(dns_msg, bytes))
}
)
}
}
Expand All @@ -247,7 +254,7 @@ async fn send_query_by_protocol(
query: &DnsMessage,
protocol: ConnectionProtocol,
server_info: &ServerInfo,
) -> Result<DnsMessage, ResolverError> {
) -> Result<(DnsMessage, Vec<u8>), ResolverError> {
let query_id = query.get_query_id();
let dns_query = query.clone();
let dns_msg_result;
Expand All @@ -257,16 +264,17 @@ async fn send_query_by_protocol(
udp_connection.set_timeout(timeout);
let response_result = udp_connection.send(dns_query).await;
dns_msg_result = parse_response(response_result, query_id);
dns_msg_result
}
ConnectionProtocol::TCP => {
let mut tcp_connection = server_info.get_tcp_connection().clone();
tcp_connection.set_timeout(timeout);
let response_result = tcp_connection.send(dns_query).await;
dns_msg_result = parse_response(response_result, query_id);
dns_msg_result
}
_ => {dns_msg_result = Err(ResolverError::Message("Invalid Protocol".into()))}, // TODO: specific add error handling
};
dns_msg_result
_ => Err(ResolverError::Message("Invalid Protocol".into())), // TODO: specific add error handling
}
}

/// Parse the received response datagram to a `DnsMessage`.
Expand All @@ -288,25 +296,26 @@ async fn send_query_by_protocol(
/// excessively long TTL, say greater than 1 week, either discard
/// the whole response, or limit all TTLs in the response to 1
/// week.
fn parse_response(response_result: Result<Vec<u8>, ClientError>, query_id:u16) -> Result<DnsMessage, ResolverError> {
let dns_msg = response_result.map_err(Into::into)
.and_then(|response_message| {
DnsMessage::from_bytes(&response_message)
.map_err(|_| ResolverError::Parse("The name server was unable to interpret the query.".to_string()))
})?;
fn parse_response(response_result: Result<Vec<u8>, ClientError>, query_id:u16) -> Result<(DnsMessage, Vec<u8>), ResolverError> {
let response_msg = response_result.map_err(Into::<ResolverError>::into)?;

let dns_msg = DnsMessage::from_bytes(&response_msg).
map_err(|_| ResolverError::Parse("The name server was unable to interpret the query.".to_string()))?;

let header = dns_msg.get_header();

// check Header
header.format_check()
.map_err(|e| ResolverError::Parse(format!("Error formated Header: {}", e)))?;
header
.format_check()
.map_err(|e| ResolverError::Parse(format!("Error formated Header: {}", e)))?;

// Check ID
if dns_msg.get_query_id() != query_id {
return Err(ResolverError::Parse("Error expected ID from query".to_string()))
}

if header.get_qr() {
return Ok(dns_msg);
return Ok((dns_msg, response_msg));
}
Err(ResolverError::Parse("Message is a query. A response was expected.".to_string()))
}
Expand Down Expand Up @@ -555,7 +564,7 @@ mod async_resolver_test {
let response_dns_msg = parse_response(response_result,query_id);
println!("[###############] {:?}",response_dns_msg);
assert!(response_dns_msg.is_ok());
if let Ok(dns_msg) = response_dns_msg {
if let Ok((dns_msg,_)) = response_dns_msg {
assert_eq!(dns_msg.get_header().get_qr(), true); // response (1)
assert_eq!(dns_msg.get_header().get_ancount(), 1);
assert_eq!(dns_msg.get_header().get_rcode(), Rcode::NOERROR);
Expand Down
Loading

0 comments on commit 340ae03

Please sign in to comment.