diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index c03e8ade04..af23e85dca 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -263,38 +263,12 @@ impl NetworkDriver for VirtioNetDriver { }; let mut header = Box::new_in(::default(), DeviceAlloc); - // If a checksum isn't necessary, we have inform the host within the header - // see Virtio specification 5.1.6.2 - if !self.checksums.tcp.tx() || !self.checksums.udp.tx() { + + if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) { header.flags = HdrF::NEEDS_CSUM; - let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> = - EthernetFrame::new_unchecked(&packet); - let packet_header_len: u16; - let protocol; - match ethernet_frame.ethertype() { - smoltcp::wire::EthernetProtocol::Ipv4 => { - let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload()); - packet_header_len = packet.header_len().into(); - protocol = Some(packet.next_header()); - } - smoltcp::wire::EthernetProtocol::Ipv6 => { - let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload()); - packet_header_len = packet.header_len().try_into().unwrap(); - protocol = Some(packet.next_header()); - } - _ => { - packet_header_len = 0; - protocol = None; - } - } header.csum_start = - (u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into(); - header.csum_offset = match protocol { - Some(smoltcp::wire::IpProtocol::Tcp) => 16, - Some(smoltcp::wire::IpProtocol::Udp) => 6, - _ => 0, - } - .into(); + (u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into(); + header.csum_offset = csum_offset.into(); } let buff_tkn = AvailBufferToken::new( @@ -778,6 +752,87 @@ impl VirtioNetDriver { Ok(()) } + + /// Sets the TCP or UDP checksum field to the checksum of the psuedo-header if necessary or returns None otherwise. + fn should_request_checksum + AsMut<[u8]>>( + &self, + frame: T, + ) -> Option<(u16, u16)> { + if !self.checksums.tcp.tx() || !self.checksums.udp.tx() { + // If a checksum calculation by the host is necessary, we have to inform the host within the header + // see Virtio specification 5.1.6.2 + let mut ethernet_frame = EthernetFrame::new_unchecked(frame); + // If the Ethernet protocol is not one of these two, we default to not asking for checksum, + // as otherwise the frame will be corrupted by the device trying to write the checksum. + if let ip @ (smoltcp::wire::EthernetProtocol::Ipv4 + | smoltcp::wire::EthernetProtocol::Ipv6) = ethernet_frame.ethertype() + { + let ip_header_len: u16; + let ip_packet_len: usize; + let protocol; + let pseudo_header_checksum; + match ip { + smoltcp::wire::EthernetProtocol::Ipv4 => { + let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut()); + ip_header_len = ip_packet.header_len().into(); + ip_packet_len = ip_packet.total_len().into(); + protocol = ip_packet.next_header(); + pseudo_header_checksum = + partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet); + } + smoltcp::wire::EthernetProtocol::Ipv6 => { + let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut()); + ip_header_len = ip_packet.header_len().try_into().expect( + "VIRTIO does not support IP headers that are longer than u16::MAX bytes.", + ); + ip_packet_len = ip_packet.total_len(); + protocol = ip_packet.next_header(); + pseudo_header_checksum = + partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet); + } + _ => unreachable!(), + } + // Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field. + if let smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp = protocol { + let ip_payload = + &mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len]; + + // We do not care about the offset of the checksum for the protocol if we don't require checksum + // from the host, so we use None to signal that checksum from the host is not neeeded. + let csum_offset = match protocol { + smoltcp::wire::IpProtocol::Tcp => { + if !self.checksums.tcp.tx() { + let mut tcp_packet = + smoltcp::wire::TcpPacket::new_unchecked(ip_payload); + tcp_packet.set_checksum(pseudo_header_checksum); + Some(16) + } else { + None + } + } + smoltcp::wire::IpProtocol::Udp => { + if !self.checksums.tcp.tx() { + let mut udp_packet = + smoltcp::wire::UdpPacket::new_unchecked(ip_payload); + udp_packet.set_checksum(pseudo_header_checksum); + Some(6) + } else { + None + } + } + _ => None, + }; + csum_offset.map(|csum_offset| (ip_header_len, csum_offset)) + } else { + None + } + } else { + None + } + } else { + None + } + } } pub mod constants { @@ -802,3 +857,66 @@ pub mod error { IncompatibleFeatureSets(virtio::net::F, virtio::net::F), } } + +/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header +/// and their results are meant to be combined with the TCP payload to calculate the real checksum. +/// They are only useful for the VIRTIO driver with the checksum offloading feature. +/// +/// The calculations here can theoritically be made faster by exploiting the properties described in +/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071). +mod partial_checksum { + use core::iter; + + use smoltcp::wire::{Ipv4Packet, Ipv6Packet}; + + /// Calculates the checksum for the IPv4 pseudo-header as described in + /// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion. + pub(super) fn ipv4_pseudo_header_partial_checksum>( + packet: &Ipv4Packet, + ) -> u16 { + let src_addr = packet.src_addr(); + let dst_addr = packet.dst_addr(); + let address_words = src_addr + .as_bytes() + .iter() + .chain(dst_addr.as_bytes()) + .copied() + .array_chunks::<{ size_of::() }>() + .map(u16::from_be_bytes); + let padded_protocol = u16::from(u8::from(packet.next_header())); + let payload_len = packet.total_len() - u16::from(packet.header_len()); + address_words + .chain(iter::once(padded_protocol)) + .chain(iter::once(payload_len)) + .fold(0u16, ones_complement_add) + } + + /// Calculates the checksum for the IPv6 pseudo-header as described in + /// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion. + pub(super) fn ipv6_pseudo_header_partial_checksum>( + packet: &Ipv6Packet, + ) -> u16 { + warn!("The IPv6 partial checksum implementation is untested!"); + let src_addr = packet.src_addr(); + let dst_addr = packet.dst_addr(); + let payload_len = packet.payload_len(); + let padded_protocol = u16::from(u8::from(packet.next_header())); + + src_addr + .as_bytes() + .iter() + .chain(dst_addr.as_bytes()) + .copied() + .array_chunks::<{ size_of::() }>() + .map(u16::from_be_bytes) + .chain(iter::once(payload_len)) + .chain(iter::once(padded_protocol)) + .fold(0u16, ones_complement_add) + } + + /// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1). + fn ones_complement_add(lhs: u16, rhs: u16) -> u16 { + let (sum, overflow) = u16::overflowing_add(lhs, rhs); + sum + u16::from(overflow) + } +} diff --git a/src/lib.rs b/src/lib.rs index 62caa3d277..44672e8050 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ )] #![cfg_attr(target_arch = "x86_64", feature(abi_x86_interrupt))] #![feature(allocator_api)] +#![feature(iter_array_chunks)] #![feature(linked_list_cursors)] #![feature(map_try_insert)] #![feature(maybe_uninit_as_bytes)]