Skip to content

Pid to str no result #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 106 additions & 15 deletions dice-mfg-msgs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#![cfg_attr(not(any(test, feature = "std")), no_std)]
#![cfg_attr(not(feature = "std"), no_std)]

#[cfg(feature = "std")]
use const_oid::db::rfc4519::COMMON_NAME;
Expand Down Expand Up @@ -90,6 +90,30 @@ pub enum PlatformIdError {
Malformed,
}

// `thiserror` is used to derive a `Display` impl when we have access to the
// standard library. When we build for `no_std` we use this compatible
// `Display` impl to satisfy the serde try_from container attribute used on the
// `PlatformId` type.
#[cfg(not(feature = "std"))]
impl fmt::Display for PlatformIdError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
PlatformIdError::BadSize => {
write!(f, "PlatformId string is the wrong length")
}
PlatformIdError::Invalid { .. } => {
write!(f, "invalid character in PlatformId")
}
PlatformIdError::InvalidPrefix => {
write!(f, "Unknown prefix on PlatformId")
}
PlatformIdError::Malformed => {
write!(f, "PlatformId string contains non-UTF8 characters")
}
}
}
}

// see RFD 308 § 4.3.1
// 0XV2:PPP-PPPPPPP:RRR:LLLWWYYSSSS
const PREFIX_LEN: usize = 4;
Expand All @@ -107,6 +131,8 @@ pub const PLATFORM_ID_MAX_LEN: usize = PLATFORM_ID_V1_LEN;
Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize, SerializedSize,
)]
#[repr(C)]
// have serde use the TryFrom to validate the PlatformId when deserializing
#[serde(try_from = "[u8; PLATFORM_ID_MAX_LEN]")]
pub struct PlatformId([u8; PLATFORM_ID_MAX_LEN]);

fn validate_0xv2(s: &str) -> Result<(), PlatformIdError> {
Expand Down Expand Up @@ -263,6 +289,24 @@ impl TryFrom<&[u8]> for PlatformId {
}
}

impl TryFrom<[u8; PLATFORM_ID_MAX_LEN]> for PlatformId {
type Error = PlatformIdError;

fn try_from(b: [u8; PLATFORM_ID_MAX_LEN]) -> Result<Self, Self::Error> {
let pid = core::str::from_utf8(&b[..])
.map_err(|_| PlatformIdError::Malformed)?;
let pid = pid.trim_end_matches('\0');

Self::try_from(pid)
}
}

impl fmt::Display for PlatformId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}

#[cfg(feature = "std")]
#[cfg_attr(any(test, feature = "std"), derive(thiserror::Error))]
#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -324,10 +368,13 @@ impl PlatformId {
}
}

pub fn as_str(&self) -> Result<&str, PlatformIdError> {
Ok(core::str::from_utf8(self.as_bytes())
.map_err(|_| PlatformIdError::Malformed)?
.trim_end_matches('\0'))
pub fn as_str(&self) -> &str {
// We discard the Result here / `expect` becuase the constructor /
// try_from, and deserialization methods check the validity of the
// string when the instance is created.
core::str::from_utf8(self.as_bytes())
.expect("malformed platform id string")
.trim_end_matches('\0')
}
}

Expand Down Expand Up @@ -618,8 +665,8 @@ mod tests {
let pid = PlatformId::try_from(pid)?;
let mut dest = [0u8; PLATFORM_ID_MAX_LEN];

assert_eq!(pid.as_str()?.len(), PLATFORM_ID_V2_LEN);
dest[..pid.as_str()?.len()].copy_from_slice(pid.as_str()?.as_bytes());
assert_eq!(pid.as_str().len(), PLATFORM_ID_V2_LEN);
dest[..pid.as_str().len()].copy_from_slice(pid.as_str().as_bytes());

assert_eq!(&dest[..PREFIX_LEN], PLATFORM_ID_V2_PREFIX.as_bytes());
assert_eq!(&pid.as_bytes()[4..16], &dest[4..16]);
Expand All @@ -632,14 +679,14 @@ mod tests {
let pid = PlatformId::try_from(PID_V1_GOOD)?;
let mut dest = [0u8; PLATFORM_ID_MAX_LEN];

dest[..pid.as_str()?.len()].copy_from_slice(pid.as_str()?.as_bytes());
dest[..pid.as_str().len()].copy_from_slice(pid.as_str().as_bytes());
// mut no more
let dest = dest;

assert_eq!(pid.as_bytes(), &dest[..pid.as_str()?.len()]);
assert_eq!(pid.as_bytes(), &dest[..pid.as_str().len()]);

let pid = PlatformId::try_from(&dest[..])?;
assert_eq!(pid.as_str()?, PID_V1_GOOD);
assert_eq!(pid.as_str(), PID_V1_GOOD);

Ok(())
}
Expand All @@ -649,14 +696,56 @@ mod tests {
let pid = PlatformId::try_from(RFD308_V2_GOOD)?;
let mut dest = [0u8; PLATFORM_ID_MAX_LEN];

dest[..pid.as_str()?.len()].copy_from_slice(pid.as_str()?.as_bytes());
dest[..pid.as_str().len()].copy_from_slice(pid.as_str().as_bytes());
// mut no more
let dest = dest;

assert_eq!(pid.as_bytes(), &dest[..pid.as_str()?.len()]);
assert_eq!(pid.as_bytes(), &dest[..pid.as_str().len()]);

let pid = PlatformId::try_from(&dest[..])?;
assert_eq!(pid.as_str()?, PID_V2_GOOD);
assert_eq!(pid.as_str(), PID_V2_GOOD);

Ok(())
}

#[test]
fn pid_serialize() -> Result<(), PlatformIdError> {
let mut buf = [0u8; PlatformId::MAX_SIZE];

let pid = PlatformId::try_from(RFD308_V2_GOOD)?;
let count = hubpack::serialize(&mut buf, &pid).unwrap();

assert_eq!(count, buf.len());
Ok(())
}

const RFD308_V2_SERIALIZED: [u8; 32] = [
b'P', b'D', b'V', b'2', b':', b'P', b'P', b'P', b'-', b'P', b'P', b'P',
b'P', b'P', b'P', b'P', b':', b'R', b'R', b'R', b':', b'S', b'S', b'S',
b'S', b'S', b'S', b'S', b'S', b'S', b'S', b'S',
];

#[test]
fn pid_deserialize_good() -> Result<(), PlatformIdError> {
let (pid, _) =
hubpack::deserialize::<PlatformId>(&RFD308_V2_SERIALIZED)
.expect("deserialization failed for \"good\" test data");
let pid_expected = PlatformId::try_from(RFD308_V2_GOOD)
.expect("failed to create PlatformId from \"good\" test data");

assert_eq!(pid, pid_expected);
Ok(())
}

#[test]
fn pid_deserialize_bad() -> Result<(), PlatformIdError> {
// make a local copy of the good serialized value
let mut pid = RFD308_V2_SERIALIZED;
// set one character to an invalid value
pid[22] = b's';

let res = hubpack::deserialize::<PlatformId>(&pid);
assert_eq!(res, Err(hubpack::error::Error::Custom));

Ok(())
}
Expand Down Expand Up @@ -684,9 +773,11 @@ z1UhDy+0wtYKr4IhWWw3E8v3Y9JcjeT1s43Nc/wG

let platform_id = PlatformId::try_from(&cert_chain)
.context("PlatformId from cert chain")?;
let platform_id = platform_id.as_str().context("PlatformId to str")?;

Ok(assert_eq!(platform_id, "PDV2:PPP-PPPPPPP:RRR:SSSSSSSSSSS"))
Ok(assert_eq!(
platform_id.as_str(),
"PDV2:PPP-PPPPPPP:RRR:SSSSSSSSSSS"
))
}

#[cfg(feature = "std")]
Expand Down
5 changes: 1 addition & 4 deletions dice-mfg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ impl MfgDriver {
pub fn set_platform_id(&mut self, pid: PlatformId) -> Result<()> {
let mut retry = self.max_retry;
loop {
match pid.as_str() {
Ok(s) => print!("setting platform id to: \"{s}\" ... "),
Err(e) => return Err(Error::InvalidPlatformId(e).into()),
}
print!("setting platform id to: \"{pid}\"");
io::stdout().flush()?;

let hash = self.send_msg(&MfgMessage::PlatformId(pid))?;
Expand Down
5 changes: 2 additions & 3 deletions dice-mfg/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,9 @@ fn main() -> Result<()> {

let (cert, csr) = if let Some(ref w) = work_dir {
// use workdir to hold CSR if provided
let id = platform_id.as_str()?;
(
w.join(format!("{id}.cert.pem")),
w.join(format!("{id}.csr.pem")),
w.join(format!("{platform_id}.cert.pem")),
w.join(format!("{platform_id}.csr.pem")),
)
} else {
// otherwise use a tempdir
Expand Down
7 changes: 1 addition & 6 deletions verifier-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,7 @@ fn main() -> Result<()> {

let platform_id = PlatformId::try_from(&cert_chain)
.context("PlatformId from attestation cert chain")?;
let platform_id = platform_id
.as_str()
.map_err(|_| anyhow!("Invalid PlatformId"))?;
let platform_id = platform_id.as_str();

println!("{platform_id}");
}
Expand Down Expand Up @@ -276,9 +274,6 @@ fn main() -> Result<()> {
)?
}
};
let platform_id = platform_id
.as_str()
.map_err(|_| anyhow!("Invalid PlatformId"))?;

println!("{platform_id}");
}
Expand Down
Loading