Skip to content

Commit

Permalink
feat!: Typestating of high-layer initiator
Browse files Browse the repository at this point in the history
This is a rather minimal version in that the API is only altered as
necessary -- setting c_r is not deferred yet. Note that this already not
only reduces the size of the Done initiator, but also frees it from
lifetime constraints (because at that point it doesn't need to know the
setup details any more).
  • Loading branch information
chrysn committed Nov 5, 2023
1 parent 4fef28a commit 08d1649
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 48 deletions.
12 changes: 6 additions & 6 deletions examples/coap/src/bin/coapclient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn main() {
// Send Message 1 over CoAP and convert the response to byte
let mut msg_1_buf = Vec::from([0xf5u8]); // EDHOC message_1 when transported over CoAP is prepended with CBOR true
let c_i = generate_connection_identifier_cbor();
let message_1 = initiator.prepare_message_1(c_i).unwrap();
let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap();
msg_1_buf.extend_from_slice(&message_1.content[..message_1.len]);
println!("message_1 len = {}", msg_1_buf.len());

Expand All @@ -37,15 +37,15 @@ fn main() {
println!("response_vec = {:02x?}", response.message.payload);
println!("message_2 len = {}", response.message.payload.len());

let c_r = initiator.process_message_2(
let m2result = initiator.process_message_2(
&response.message.payload[..]
.try_into()
.expect("wrong length"),
);

if c_r.is_ok() {
let mut msg_3 = Vec::from([c_r.unwrap()]);
let (message_3, prk_out) = initiator.prepare_message_3().unwrap();
if let Ok((initiator, c_r)) = m2result {
let mut msg_3 = Vec::from([c_r]);
let (mut initiator, message_3, prk_out) = initiator.prepare_message_3().unwrap();
msg_3.extend_from_slice(&message_3.content[..message_3.len]);
println!("message_3 len = {}", msg_3.len());

Expand Down Expand Up @@ -76,6 +76,6 @@ fn main() {
println!("OSCORE secret after key update: {:02x?}", _oscore_secret);
println!("OSCORE salt after key update: {:02x?}", _oscore_salt);
} else {
panic!("Message 2 processing error: {:#?}", c_r);
panic!("Message 2 processing error: {:#?}", m2result);
}
}
114 changes: 72 additions & 42 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

pub use {
edhoc_consts::State as EdhocState, edhoc_consts::*, edhoc_crypto::default_crypto,
edhoc_crypto_trait::Crypto as CryptoTrait, EdhocInitiatorState as EdhocInitiator,
EdhocResponderState as EdhocResponder,
edhoc_crypto_trait::Crypto as CryptoTrait, EdhocResponderState as EdhocResponder,
};

#[cfg(any(feature = "ead-none", feature = "ead-zeroconf"))]
Expand All @@ -15,14 +14,35 @@ use edhoc::*;

use edhoc_consts::*;

#[derive(Default, Copy, Clone, Debug)]
pub struct EdhocInitiatorState<'a> {
#[derive(Default, Debug)]
pub struct EdhocInitiator<'a> {
state: State, // opaque state
i: &'a [u8], // private authentication key of I
cred_i: &'a [u8], // I's full credential
cred_r: Option<&'a [u8]>, // R's full credential (if provided)
}

#[derive(Default, Debug)]
pub struct EdhocInitiatorWaitM2<'a> {
state: State, // opaque state
i: &'a [u8], // private authentication key of I
cred_i: &'a [u8], // I's full credential
cred_r: Option<&'a [u8]>, // R's full credential (if provided)
}

#[derive(Default, Debug)]
pub struct EdhocInitiatorBuildM3<'a> {
state: State, // opaque state
i: &'a [u8], // private authentication key of I
cred_i: &'a [u8], // I's full credential
cred_r: Option<&'a [u8]>, // R's full credential (if provided)
}

#[derive(Default, Debug)]
pub struct EdhocInitiatorDone {
state: State, // opaque state
}

#[derive(Default, Copy, Clone, Debug)]
pub struct EdhocResponderState<'a> {
state: State, // opaque state
Expand Down Expand Up @@ -146,16 +166,16 @@ impl<'a> EdhocResponderState<'a> {
}
}

impl<'a> EdhocInitiatorState<'a> {
impl<'a> EdhocInitiator<'a> {
pub fn new(
state: State,
i: &'a [u8],
cred_i: &'a [u8],
cred_r: Option<&'a [u8]>,
) -> EdhocInitiatorState<'a> {
) -> EdhocInitiator<'a> {
assert!(i.len() == P256_ELEM_LEN);

EdhocInitiatorState {
EdhocInitiator {
state,
i,
cred_i,
Expand All @@ -164,24 +184,31 @@ impl<'a> EdhocInitiatorState<'a> {
}

pub fn prepare_message_1(
self: &mut EdhocInitiatorState<'a>,
self: EdhocInitiator<'a>,
c_i: u8,
) -> Result<BufferMessage1, EDHOCError> {
) -> Result<(EdhocInitiatorWaitM2<'a>, BufferMessage1), EDHOCError> {
let (x, g_x) = default_crypto().p256_generate_key_pair();

match i_prepare_message_1(self.state, &mut default_crypto(), x, g_x, c_i) {
Ok((state, message_1)) => {
self.state = state;
Ok(message_1)
}
Ok((state, message_1)) => Ok((
EdhocInitiatorWaitM2 {
state,
i: self.i,
cred_i: self.cred_i,
cred_r: self.cred_r,
},
message_1,
)),
Err(error) => Err(error),
}
}
}

impl<'a> EdhocInitiatorWaitM2<'a> {
pub fn process_message_2(
self: &mut EdhocInitiatorState<'a>,
self,
message_2: &BufferMessage2,
) -> Result<u8, EDHOCError> {
) -> Result<(EdhocInitiatorBuildM3<'a>, u8), EDHOCError> {
match i_process_message_2(
self.state,
&mut default_crypto(),
Expand All @@ -191,33 +218,41 @@ impl<'a> EdhocInitiatorState<'a> {
.try_into()
.expect("Wrong length of initiator private key"),
) {
Ok((state, c_r, _kid)) => {
self.state = state;
Ok(c_r)
}
Ok((state, c_r, _kid)) => Ok((
EdhocInitiatorBuildM3 {
state,
i: self.i,
cred_i: self.cred_i,
cred_r: self.cred_r,
},
c_r,
)),
Err(error) => Err(error),
}
}
}

impl<'a> EdhocInitiatorBuildM3<'a> {
pub fn prepare_message_3(
self: &mut EdhocInitiatorState<'a>,
) -> Result<(BufferMessage3, [u8; SHA256_DIGEST_LEN]), EDHOCError> {
self,
) -> Result<(EdhocInitiatorDone, BufferMessage3, [u8; SHA256_DIGEST_LEN]), EDHOCError> {
match i_prepare_message_3(
self.state,
&mut default_crypto(),
&get_id_cred(self.cred_i),
self.cred_i,
) {
Ok((state, message_3, prk_out)) => {
self.state = state;
Ok((message_3, prk_out))
Ok((EdhocInitiatorDone { state }, message_3, prk_out))
}
Err(error) => Err(error),
}
}
}

impl EdhocInitiatorDone {
pub fn edhoc_exporter(
self: &mut EdhocInitiatorState<'a>,
&mut self,
label: u8,
context: &[u8],
length: usize,
Expand All @@ -242,7 +277,7 @@ impl<'a> EdhocInitiatorState<'a> {
}

pub fn edhoc_key_update(
self: &mut EdhocInitiatorState<'a>,
&mut self,
context: &[u8],
) -> Result<[u8; SHA256_DIGEST_LEN], EDHOCError> {
let mut context_buf = [0x00u8; MAX_KDF_CONTEXT_LEN];
Expand Down Expand Up @@ -311,8 +346,8 @@ mod test {
#[test]
fn test_new_initiator() {
let state: EdhocState = Default::default();
let _initiator = EdhocInitiatorState::new(state, I, CRED_I, Some(CRED_R));
let _initiator = EdhocInitiatorState::new(state, I, CRED_I, None);
let _initiator = EdhocInitiator::new(state, I, CRED_I, Some(CRED_R));
let _initiator = EdhocInitiator::new(state, I, CRED_I, None);
}

#[test]
Expand All @@ -325,7 +360,7 @@ mod test {
#[test]
fn test_prepare_message_1() {
let state: EdhocState = Default::default();
let mut initiator = EdhocInitiatorState::new(state, I, CRED_I, Some(CRED_R));
let mut initiator = EdhocInitiator::new(state, I, CRED_I, Some(CRED_R));

let c_i = generate_connection_identifier_cbor();
let message_1 = initiator.prepare_message_1(c_i);
Expand Down Expand Up @@ -359,15 +394,14 @@ mod test {
#[test]
fn test_handshake() {
let state_initiator: EdhocState = Default::default();
let mut initiator = EdhocInitiatorState::new(state_initiator, I, CRED_I, Some(CRED_R));
let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, Some(CRED_R));
let state_responder: EdhocState = Default::default();
let mut responder = EdhocResponderState::new(state_responder, R, CRED_R, Some(CRED_I));

let c_i: u8 = generate_connection_identifier_cbor();
let result = initiator.prepare_message_1(c_i); // to update the state
assert!(result.is_ok());
let (initiator, result) = initiator.prepare_message_1(c_i).unwrap(); // to update the state

let error = responder.process_message_1(&result.unwrap());
let error = responder.process_message_1(&result);
assert!(error.is_ok());

let c_r = generate_connection_identifier_cbor();
Expand All @@ -377,13 +411,9 @@ mod test {
let message_2 = ret.unwrap();

assert!(c_r != 0xff);
let _c_r = initiator.process_message_2(&message_2);
assert!(_c_r.is_ok());

let ret = initiator.prepare_message_3();
assert!(ret.is_ok());
let (initiator, _) = initiator.process_message_2(&message_2).unwrap();

let (message_3, i_prk_out) = ret.unwrap();
let (mut initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap();

let r_prk_out = responder.process_message_3(&message_3);
assert!(r_prk_out.is_ok());
Expand Down Expand Up @@ -437,7 +467,7 @@ mod test {
#[test]
fn test_ead_zeroconf() {
let state_initiator: EdhocState = Default::default();
let mut initiator = EdhocInitiatorState::new(state_initiator, I, CRED_I, None);
let mut initiator = EdhocInitiator::new(state_initiator, I, CRED_I, None);
let state_responder: EdhocState = Default::default();
let mut responder = EdhocResponderState::new(state_responder, R, CRED_V_TV, Some(CRED_I));

Expand Down Expand Up @@ -467,7 +497,7 @@ mod test {
));

let c_i = generate_connection_identifier_cbor();
let message_1 = initiator.prepare_message_1(c_i).unwrap();
let (initiator, message_1) = initiator.prepare_message_1(c_i).unwrap();
assert_eq!(
ead_initiator_state.protocol_state,
EADInitiatorProtocolState::WaitEAD2
Expand All @@ -486,14 +516,14 @@ mod test {
EADResponderProtocolState::Completed
);

initiator.process_message_2(&message_2).unwrap();
let (initiator, _) = initiator.process_message_2(&message_2).unwrap();

assert_eq!(
ead_initiator_state.protocol_state,
EADInitiatorProtocolState::Completed
);

let (message_3, i_prk_out) = initiator.prepare_message_3().unwrap();
let (initiator, message_3, i_prk_out) = initiator.prepare_message_3().unwrap();

let r_prk_out = responder.process_message_3(&message_3).unwrap();
assert_eq!(i_prk_out, r_prk_out);
Expand Down

0 comments on commit 08d1649

Please sign in to comment.