Skip to content

Commit

Permalink
Rewrite kex as a state machine (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugeny authored Jan 5, 2025
1 parent 4f0a0d4 commit e0bc545
Show file tree
Hide file tree
Showing 27 changed files with 1,403 additions and 1,197 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,6 @@ jobs:

- name: Check with minimal dependency versions
run: |
rustup toolchain add 1.65.0
cargo +1.65.0 minimal-versions check --all-features --no-dev-deps
rustup toolchain add 1.73.0
# minimal-versions cannot correctly resolve the ssh-key dep on Rust <1.73
cargo +1.73.0 minimal-versions check --all-features --no-dev-deps
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ signature = "2.2"
ssh-encoding = { version = "0.2", features = [
"bytes",
] }
ssh-key = { version = "0.6.7-internal.6", features = [
ssh-key = { version = "=0.6.8+upstream-0.6.7", features = [
"ed25519",
"rsa",
"rsa-sha1",
Expand Down
11 changes: 10 additions & 1 deletion cryptovec/src/cryptovec.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
use std::fmt::Debug;
use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo};

use crate::platform::{self, memset, mlock, munlock};

/// A buffer which zeroes its memory on `.clear()`, `.resize()`, and
/// reallocations, to avoid copying secrets around.
#[derive(Debug)]
pub struct CryptoVec {
p: *mut u8, // `pub(crate)` allows access from platform modules
size: usize,
capacity: usize,
}

impl Debug for CryptoVec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.size == 0 {
return f.write_str("<empty>");
}
write!(f, "<{:?}>", self.size)
}
}

impl Unpin for CryptoVec {}
unsafe impl Send for CryptoVec {}
unsafe impl Sync for CryptoVec {}
Expand Down
11 changes: 1 addition & 10 deletions russh-keys/src/format/pkcs8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,7 @@ pub fn encode_pkcs8(key: &ssh_key::PrivateKey) -> Result<Vec<u8>, Error> {
sk.to_pkcs8_der()?
}
ssh_key::private::KeypairData::Rsa(ref pair) => {
// TODO: Implementation in ssh-key 0.6.7 is broken (fixed in 0.7.0-pre)
let sk = rsa::RsaPrivateKey::from_components(
rsa::BigUint::try_from(&pair.public.n)?,
rsa::BigUint::try_from(&pair.public.e)?,
rsa::BigUint::try_from(&pair.private.d)?,
vec![
rsa::BigUint::try_from(&pair.private.p)?,
rsa::BigUint::try_from(&pair.private.q)?,
],
)?;
let sk: rsa::RsaPrivateKey = pair.try_into()?;
sk.to_pkcs8_der()?
}
ssh_key::private::KeypairData::Ecdsa(ref pair) => match pair {
Expand Down
1 change: 1 addition & 0 deletions russh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ thiserror.workspace = true
russh-util = { version = "0.48.0", path = "../russh-util" }
des = { version = "0.8.1", optional = true }
tokio = { workspace = true, features = ["io-util", "sync", "time"] }
enum_dispatch = "0.3.13"

[dev-dependencies]
anyhow = "1.0.4"
Expand Down
13 changes: 7 additions & 6 deletions russh/src/cert.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use core::str;

use russh_keys::helpers::AlgorithmExt;
use russh_keys::key::PrivateKeyWithHashAlg;
use ssh_encoding::Decode;
use ssh_key::public::KeyData;
use ssh_key::{Algorithm, Certificate, HashAlg, PublicKey};
use ssh_key::{Certificate, HashAlg, PublicKey};
#[cfg(not(target_arch = "wasm32"))]
use {
russh_keys::helpers::AlgorithmExt, ssh_encoding::Decode, ssh_key::public::KeyData,
ssh_key::Algorithm,
};

#[derive(Debug)]
pub(crate) enum PublicKeyOrCertificate {
Expand All @@ -25,6 +25,7 @@ impl From<&PrivateKeyWithHashAlg> for PublicKeyOrCertificate {
}

impl PublicKeyOrCertificate {
#[cfg(not(target_arch = "wasm32"))]
pub fn decode(pubkey_algo: &str, buf: &[u8]) -> Result<Self, ssh_key::Error> {
let mut reader = buf;
match Algorithm::new_certificate_ext(pubkey_algo) {
Expand Down
20 changes: 10 additions & 10 deletions russh/src/cipher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use aes::{Aes128, Aes192, Aes256};
use byteorder::{BigEndian, ByteOrder};
use ctr::Ctr128BE;
use delegate::delegate;
use log::debug;
use log::trace;
use once_cell::sync::Lazy;
use ssh_encoding::Encode;
use tokio::io::{AsyncRead, AsyncReadExt};
Expand Down Expand Up @@ -210,12 +210,12 @@ pub(crate) trait SealingKey {
//
// The variables `payload`, `packet_length` and `padding_length` refer
// to the protocol fields of the same names.
debug!("writing, seqn = {:?}", buffer.seqn.0);
trace!("writing, seqn = {:?}", buffer.seqn.0);

let padding_length = self.padding_length(payload);
debug!("padding length {:?}", padding_length);
trace!("padding length {:?}", padding_length);
let packet_length = PADDING_LENGTH_LEN + payload.len() + padding_length;
debug!("packet_length {:?}", packet_length);
trace!("packet_length {:?}", packet_length);
let offset = buffer.buffer.len();

// Maximum packet length:
Expand Down Expand Up @@ -252,12 +252,12 @@ pub(crate) async fn read<R: AsyncRead + Unpin>(
let mut len = vec![0; cipher.packet_length_to_read_for_block_length()];

stream.read_exact(&mut len).await?;
debug!("reading, len = {:?}", len);
trace!("reading, len = {:?}", len);
{
let seqn = buffer.seqn.0;
buffer.buffer.clear();
buffer.buffer.extend(&len);
debug!("reading, seqn = {:?}", seqn);
trace!("reading, seqn = {:?}", seqn);
let len = cipher.decrypt_packet_length(seqn, &len);
let len = BigEndian::read_u32(&len) as usize;

Expand All @@ -266,26 +266,26 @@ pub(crate) async fn read<R: AsyncRead + Unpin>(
}

buffer.len = len + cipher.tag_len();
debug!("reading, clear len = {:?}", buffer.len);
trace!("reading, clear len = {:?}", buffer.len);
}
}

buffer.buffer.resize(buffer.len + 4);
debug!("read_exact {:?}", buffer.len + 4);
trace!("read_exact {:?}", buffer.len + 4);

let l = cipher.packet_length_to_read_for_block_length();

#[allow(clippy::indexing_slicing)] // length checked
stream.read_exact(&mut buffer.buffer[l..]).await?;

debug!("read_exact done");
trace!("read_exact done");
let seqn = buffer.seqn.0;
let ciphertext_len = buffer.buffer.len() - cipher.tag_len();
let (ciphertext, tag) = buffer.buffer.split_at_mut(ciphertext_len);
let plaintext = cipher.open(seqn, ciphertext, tag)?;

let padding_length = *plaintext.first().to_owned().unwrap_or(&0) as usize;
debug!("reading, padding_length {:?}", padding_length);
trace!("reading, padding_length {:?}", padding_length);
let plaintext_end = plaintext
.len()
.checked_sub(padding_length)
Expand Down
139 changes: 13 additions & 126 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,21 @@
//
use std::cell::RefCell;
use std::convert::TryInto;
use std::num::Wrapping;
use std::ops::Deref;

use bytes::Bytes;
use log::{debug, error, info, trace, warn};
use russh_keys::helpers::{map_err, sign_with_hash_alg, AlgorithmExt, EncodedExt};
use ssh_encoding::{Decode, Encode};

use super::IncomingSshPacket;
use crate::cert::PublicKeyOrCertificate;
use crate::client::{Handler, Msg, Prompt, Reply, Session};
use crate::keys::key::parse_public_key;
use crate::negotiation::Select;
use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage};
use crate::session::{Encrypted, EncryptedState, GlobalRequestResponse, Kex, KexInit};
use crate::session::{Encrypted, EncryptedState, GlobalRequestResponse};
use crate::{
auth, msg, negotiation, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams,
CryptoVec, Sig,
auth, msg, Channel, ChannelId, ChannelMsg, ChannelOpenFailure, ChannelParams, CryptoVec, Sig,
};

thread_local! {
Expand All @@ -41,136 +39,26 @@ impl Session {
pub(crate) async fn client_read_encrypted<H: Handler>(
&mut self,
client: &mut H,
seqn: &mut Wrapping<u32>,
buf: &[u8],
pkt: &mut IncomingSshPacket,
) -> Result<(), H::Error> {
#[allow(clippy::indexing_slicing)] // length checked
{
trace!(
"client_read_encrypted, buf = {:?}",
&buf[..buf.len().min(20)]
&pkt.buffer[..pkt.buffer.len().min(20)]
);
}
// Either this packet is a KEXINIT, in which case we start a key re-exchange.
if buf.first() == Some(&msg::KEXINIT) {
debug!("Received KEXINIT");
// Now, if we're encrypted:
if let Some(ref mut enc) = self.common.encrypted {
// If we're not currently re-keying, but buf is a rekey request
let kexinit = if let Some(Kex::Init(kexinit)) = enc.rekey.take() {
Some(kexinit)
} else if let Some(exchange) = enc.exchange.take() {
Some(KexInit::received_rekey(
exchange,
negotiation::Client::read_kex(
buf,
&self.common.config.as_ref().preferred,
None,
)?,
&enc.session_id,
))
} else {
None
};

if let Some(mut kexinit) = kexinit {
if let Some(ref mut algo) = kexinit.algo {
algo.strict_kex = algo.strict_kex || self.common.strict_kex;
}

let dhdone = kexinit.client_parse(
self.common.config.as_ref(),
&mut *self.common.cipher.local_to_remote,
buf,
&mut self.common.write_buffer,
)?;

if !enc.kex.skip_exchange() {
enc.rekey = Some(Kex::DhDone(dhdone));
}
}
} else {
unreachable!()
}
self.flush()?;
return Ok(());
}

if let Some(ref mut enc) = self.common.encrypted {
match enc.rekey.take() {
Some(Kex::DhDone(mut kexdhdone)) => {
return if kexdhdone.names.ignore_guessed {
kexdhdone.names.ignore_guessed = false;
enc.rekey = Some(Kex::DhDone(kexdhdone));
Ok(())
} else if buf.first() == Some(&msg::KEX_ECDH_REPLY) {
// We've sent ECDH_INIT, waiting for ECDH_REPLY

#[allow(clippy::indexing_slicing)] // length checked
let kex = kexdhdone
.server_key_check(true, client, &mut &buf[1..])
.await?;

enc.rekey = Some(Kex::Keys(kex));
self.common
.cipher
.local_to_remote
.write(&[msg::NEWKEYS], &mut self.common.write_buffer);
self.flush()?;
self.common.maybe_reset_seqn();
Ok(())
} else {
error!("Wrong packet received");
Err(crate::Error::Inconsistent.into())
};
}
Some(Kex::Keys(newkeys)) => {
if buf.first() != Some(&msg::NEWKEYS) {
return Err(crate::Error::Kex.into());
}
self.common.write_buffer.bytes = 0;
enc.last_rekey = russh_util::time::Instant::now();

// Ok, NEWKEYS received, now encrypted.
enc.flush_all_pending()?;
let mut pending = std::mem::take(&mut self.pending_reads);
for p in pending.drain(..) {
self.process_packet(client, &p).await?;
}
self.pending_reads = pending;
self.pending_len = 0;
self.common.newkeys(newkeys);
self.flush()?;

if self.common.strict_kex {
*seqn = Wrapping(0);
}

return Ok(());
}
Some(Kex::Init(k)) => {
enc.rekey = Some(Kex::Init(k));
self.pending_len += buf.len() as u32;
if self.pending_len > 2 * self.target_window_size {
return Err(crate::Error::Pending.into());
}
self.pending_reads.push(CryptoVec::from_slice(buf));
return Ok(());
}
rek => enc.rekey = rek,
}
}
self.process_packet(client, buf).await
self.process_packet(client, &pkt.buffer).await
}

async fn process_packet<H: Handler>(
pub(crate) async fn process_packet<H: Handler>(
&mut self,
client: &mut H,
buf: &[u8],
) -> Result<(), H::Error> {
// If we've successfully read a packet.
trace!("process_packet buf = {:?} bytes", buf.len());
trace!("buf = {:?}", buf);
let mut is_authenticated = false;
if let Some(ref mut enc) = self.common.encrypted {
match enc.state {
Expand Down Expand Up @@ -623,11 +511,10 @@ impl Session {
}
}
Some((&msg::CHANNEL_WINDOW_ADJUST, mut r)) => {
debug!("channel_window_adjust");
let channel_num = map_err!(ChannelId::decode(&mut r))?;
let amount = map_err!(u32::decode(&mut r))?;
let mut new_size = 0;
debug!("amount: {:?}", amount);
debug!("channel_window_adjust amount: {:?}", amount);
if let Some(ref mut enc) = self.common.encrypted {
if let Some(ref mut channel) = enc.channels.get_mut(&channel_num) {
channel.recipient_window_size += amount;
Expand Down Expand Up @@ -918,11 +805,11 @@ impl Session {
} => {
debug!("sending ssh-userauth service requset");
if !*sent {
let p = b"\x05\0\0\0\x0Cssh-userauth";
self.common
.cipher
.local_to_remote
.write(p, &mut self.common.write_buffer);
self.common.packet_writer.packet(|w| {
msg::SERVICE_REQUEST.encode(w)?;
"ssh-userauth".encode(w)?;
Ok(())
})?;
*sent = true
}
accepted
Expand Down
Loading

0 comments on commit e0bc545

Please sign in to comment.