feat!: Refactor RowId/SampleId/NamespacedDataId related API (#236)
* feat!: Refactor RowId/SampleId/NamespacedDataId related API

* fix docs

* Add DataAvailabilityHeader::new_unchecked

* use usize::pow

* better doc

* remove RowId::size/SampleId::size/NamespacedDataId::size

* rename

* fix doc
oblique authored Mar 8, 2024
1 parent 0f89531 commit debe5fe
Showing 14 changed files with 406 additions and 351 deletions.
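To make the breaking change concrete, here is a rough before/after sketch of the `SampleId` API, assembled from the hunks below (the surrounding bindings are illustrative, not part of the commit):

    // Before: a sample was addressed by a flat index plus the square size,
    // and fields were reached through nested ids.
    let id = SampleId::new(index, square_len, block_height)?;
    let height = id.row.block_height;

    // After: a sample is addressed by explicit (row, column) coordinates,
    // and the layout is hidden behind accessor methods.
    let id = SampleId::new(row_index, column_index, block_height)?;
    let height = id.block_height();
    let (row, col) = (id.row_index(), id.column_index());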
53 changes: 29 additions & 24 deletions node/src/daser.rs
@@ -142,14 +142,13 @@ where

     async fn sample_block(&mut self, header: &ExtendedHeader) -> Result<(Vec<Cid>, bool)> {
         let now = Instant::now();
-        let block_len = header.dah.square_len() * header.dah.square_len();
-        let indexes = random_indexes(block_len, self.max_samples_needed);
+        let indexes = random_indexes(header.dah.square_width(), self.max_samples_needed);
         let mut futs = FuturesUnordered::new();

-        for index in indexes {
+        for (row_index, column_index) in indexes {
             let fut = self
                 .p2p
-                .get_sample(index, header.dah.square_len(), header.height().value());
+                .get_sample(row_index, column_index, header.height().value());
             futs.push(fut);
         }

@@ -158,7 +157,7 @@ where

         while let Some(res) = futs.next().await {
             match res {
-                Ok(sample) => cids.push(convert_cid(&sample.sample_id.into())?),
+                Ok(sample) => cids.push(convert_cid(&sample.id.into())?),
                 // Validation is done at Bitswap level, through `ShwapMultihasher`.
                 // If the sample is not valid, it will never be delivered to us
                 // as the data of the CID. Because of that, the only signal
@@ -179,18 +178,24 @@ where
     }
 }

-fn random_indexes(block_len: usize, max_samples_needed: usize) -> HashSet<usize> {
-    // If block length is smaller than `max_samples_needed`, we are going
+fn random_indexes(square_width: u16, max_samples_needed: usize) -> HashSet<(u16, u16)> {
+    let samples_in_block = usize::from(square_width).pow(2);
+
+    // If block size is smaller than `max_samples_needed`, we are going
     // to sample the whole block. Randomness is not needed for this.
-    if block_len <= max_samples_needed {
-        return (0..block_len).collect();
+    if samples_in_block <= max_samples_needed {
+        return (0..square_width)
+            .flat_map(|row| (0..square_width).map(move |col| (row, col)))
+            .collect();
     }

     let mut indexes = HashSet::with_capacity(max_samples_needed);
     let mut rng = rand::thread_rng();

     while indexes.len() < max_samples_needed {
-        indexes.insert(rng.gen::<usize>() % block_len);
+        let row = rng.gen::<u16>() % square_width;
+        let col = rng.gen::<u16>() % square_width;
+        indexes.insert((row, col));
     }

     indexes
@@ -251,10 +256,10 @@ mod tests {
         handle: &mut MockP2pHandle,
         gen: &mut ExtendedHeaderGenerator,
         store: &InMemoryStore,
-        square_len: usize,
+        square_width: usize,
         simulate_invalid_sampling: bool,
     ) {
-        let eds = generate_eds(square_len);
+        let eds = generate_eds(square_width);
         let dah = DataAvailabilityHeader::from_eds(&eds);
         let header = gen.next_with_dah(dah);
         let height = header.height().value();
@@ -263,7 +268,7 @@

         let mut cids = Vec::new();

-        for i in 0..(square_len * square_len).min(MAX_SAMPLES_NEEDED) {
+        for i in 0..(square_width * square_width).min(MAX_SAMPLES_NEEDED) {
             let (cid, respond_to) = handle.expect_get_shwap_cid().await;

             // Simulate invalid sample by triggering BitswapQueryTimeout
@@ -273,7 +278,7 @@
             }

             let sample_id: SampleId = cid.try_into().unwrap();
-            assert_eq!(sample_id.row.block_height, height);
+            assert_eq!(sample_id.block_height(), height);

             let sample = gen_sample_of_cid(sample_id, &eds, store).await;
             let sample_bytes = sample.encode_vec().unwrap();
@@ -294,15 +299,15 @@
         eds: &ExtendedDataSquare,
         store: &InMemoryStore,
     ) -> Sample {
-        let header = store
-            .get_by_height(sample_id.row.block_height)
-            .await
-            .unwrap();
-
-        let row = sample_id.row.index as usize;
-        let col = sample_id.index as usize;
-        let index = row * header.dah.square_len() + col;
-
-        Sample::new(AxisType::Row, index, eds, header.height().value()).unwrap()
+        let header = store.get_by_height(sample_id.block_height()).await.unwrap();
+
+        Sample::new(
+            sample_id.row_index(),
+            sample_id.column_index(),
+            AxisType::Row,
+            eds,
+            header.height().value(),
+        )
+        .unwrap()
     }
 }
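The daser change above is the heart of the refactor: data availability sampling now draws distinct (row, column) coordinates instead of flat share indexes. A self-contained sketch of the same scheme, runnable against the `rand` crate (`random_coords` and `main` are ours; the commit's version is `random_indexes` above):

    use std::collections::HashSet;

    use rand::Rng;

    /// Pick up to `max_samples_needed` distinct (row, col) coordinates
    /// from a `square_width` x `square_width` extended data square.
    fn random_coords(square_width: u16, max_samples_needed: usize) -> HashSet<(u16, u16)> {
        let samples_in_block = usize::from(square_width).pow(2);

        // Small squares are sampled exhaustively; randomness is not needed.
        if samples_in_block <= max_samples_needed {
            return (0..square_width)
                .flat_map(|row| (0..square_width).map(move |col| (row, col)))
                .collect();
        }

        let mut coords = HashSet::with_capacity(max_samples_needed);
        let mut rng = rand::thread_rng();

        // The HashSet deduplicates, so loop until enough distinct
        // coordinates are collected.
        while coords.len() < max_samples_needed {
            let row = rng.gen_range(0..square_width);
            let col = rng.gen_range(0..square_width);
            coords.insert((row, col));
        }

        coords
    }

    fn main() {
        // A 16x16 extended square, sampled in at most 10 places.
        let coords = random_coords(16, 10);
        assert_eq!(coords.len(), 10);
        println!("sampling {coords:?}");
    }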
30 changes: 21 additions & 9 deletions node/src/node.rs
@@ -213,28 +213,40 @@ where
         Ok(self.p2p.get_verified_headers_range(from, amount).await?)
     }

-    /// Request a [`Row`] from the network.
+    /// Request a verified [`Row`] from the network.
     ///
-    /// The result was not verified and [`Row::verify`] must be called.
+    /// # Errors
+    ///
+    /// On failure to receive a verified [`Row`] within a certain time, the
+    /// `NodeError::P2p(P2pError::BitswapQueryTimeout)` error will be returned.
     pub async fn request_row(&self, row_index: u16, block_height: u64) -> Result<Row> {
         Ok(self.p2p.get_row(row_index, block_height).await?)
     }

-    /// Request a [`Sample`] from the network.
+    /// Request a verified [`Sample`] from the network.
     ///
-    /// The result was not verified and [`Sample::verify`] must be called.
+    /// # Errors
+    ///
+    /// On failure to receive a verified [`Sample`] within a certain time, the
+    /// `NodeError::P2p(P2pError::BitswapQueryTimeout)` error will be returned.
     pub async fn request_sample(
         &self,
-        index: usize,
-        square_len: usize,
+        row_index: u16,
+        column_index: u16,
         block_height: u64,
     ) -> Result<Sample> {
-        Ok(self.p2p.get_sample(index, square_len, block_height).await?)
+        Ok(self
+            .p2p
+            .get_sample(row_index, column_index, block_height)
+            .await?)
     }

-    /// Request a [`NamespacedData`] from the network.
+    /// Request a verified [`NamespacedData`] from the network.
     ///
-    /// The result was not verified and [`NamespacedData::verify`] must be called.
+    /// # Errors
+    ///
+    /// On failure to receive a verified [`NamespacedData`] within a certain time, the
+    /// `NodeError::P2p(P2pError::BitswapQueryTimeout)` error will be returned.
     pub async fn request_namespaced_data(
         &self,
         namespace: Namespace,
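With coordinates in the public API, callers can handle the documented timeout directly. A minimal usage sketch, assuming an initialized `node` and the error paths named in the doc comments above (imports and generic bounds elided):

    // Request the share at row 0, column 0 of the given block.
    match node.request_sample(0, 0, block_height).await {
        Ok(_sample) => println!("sample retrieved and verified"),
        // Documented failure mode: not received within the Bitswap timeout.
        Err(NodeError::P2p(P2pError::BitswapQueryTimeout)) => println!("sample unavailable"),
        Err(e) => println!("request failed: {e}"),
    }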
6 changes: 3 additions & 3 deletions node/src/p2p.rs
@@ -442,11 +442,11 @@ impl P2p {
     /// failed.
     pub async fn get_sample(
         &self,
-        index: usize,
-        square_len: usize,
+        row_index: u16,
+        column_index: u16,
         block_height: u64,
     ) -> Result<Sample> {
-        let cid = sample_cid(index, square_len, block_height)?;
+        let cid = sample_cid(row_index, column_index, block_height)?;
         let data = self.get_shwap_cid(cid, Some(GET_SAMPLE_TIMEOUT)).await?;
         Ok(Sample::decode(&data[..])?)
     }
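Because a `SampleId` now carries its own coordinates, a CID fully identifies a share without knowing the square size. A small round-trip sketch based on the accessors used in the tests below (`unwrap` for brevity):

    // Build the CID for row 0, column 0 of the block at height 1 ...
    let cid = sample_cid(0, 0, 1).unwrap();
    // ... and recover the id from the CID alone; no `square_len` required.
    let sample_id: SampleId = cid.try_into().unwrap();
    assert_eq!(sample_id.block_height(), 1);
    assert_eq!((sample_id.row_index(), sample_id.column_index()), (0, 0));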
20 changes: 10 additions & 10 deletions node/src/p2p/shwap.rs
@@ -48,14 +48,14 @@ where
                 let ns_data =
                     NamespacedData::decode(input).map_err(MultihasherError::custom_fatal)?;

-                let hash = convert_cid(&ns_data.namespaced_data_id.into())
+                let hash = convert_cid(&ns_data.id.into())
                     .map_err(MultihasherError::custom_fatal)?
                     .hash()
                     .to_owned();

                 let header = self
                     .header_store
-                    .get_by_height(ns_data.namespaced_data_id.row.block_height)
+                    .get_by_height(ns_data.id.block_height())
                     .await
                     .map_err(MultihasherError::custom_fatal)?;

@@ -68,14 +68,14 @@ where
             ROW_ID_MULTIHASH_CODE => {
                 let row = Row::decode(input).map_err(MultihasherError::custom_fatal)?;

-                let hash = convert_cid(&row.row_id.into())
+                let hash = convert_cid(&row.id.into())
                     .map_err(MultihasherError::custom_fatal)?
                     .hash()
                     .to_owned();

                 let header = self
                     .header_store
-                    .get_by_height(row.row_id.block_height)
+                    .get_by_height(row.id.block_height())
                     .await
                     .map_err(MultihasherError::custom_fatal)?;

@@ -87,14 +87,14 @@ where
             SAMPLE_ID_MULTIHASH_CODE => {
                 let sample = Sample::decode(input).map_err(MultihasherError::custom_fatal)?;

-                let hash = convert_cid(&sample.sample_id.into())
+                let hash = convert_cid(&sample.id.into())
                     .map_err(MultihasherError::custom_fatal)?
                     .hash()
                     .to_owned();

                 let header = self
                     .header_store
-                    .get_by_height(sample.sample_id.row.block_height)
+                    .get_by_height(sample.id.block_height())
                     .await
                     .map_err(MultihasherError::custom_fatal)?;

@@ -114,8 +114,8 @@ pub(super) fn row_cid(row_index: u16, block_height: u64) -> Result<Cid> {
     convert_cid(&row_id.into())
 }

-pub(super) fn sample_cid(index: usize, square_len: usize, block_height: u64) -> Result<Cid> {
-    let sample_id = SampleId::new(index, square_len, block_height).map_err(P2pError::Cid)?;
+pub(super) fn sample_cid(row_index: u16, column_index: u16, block_height: u64) -> Result<Cid> {
+    let sample_id = SampleId::new(row_index, column_index, block_height).map_err(P2pError::Cid)?;
     convert_cid(&sample_id.into())
 }

@@ -153,9 +153,9 @@ mod tests {
         let mut gen = ExtendedHeaderGenerator::new();
         let header = gen.next_with_dah(dah.clone());

-        let sample = Sample::new(AxisType::Row, 0, &eds, header.header.height.value()).unwrap();
+        let sample = Sample::new(0, 0, AxisType::Row, &eds, header.header.height.value()).unwrap();
         let sample_bytes = sample.encode_vec().unwrap();
-        let cid = sample_cid(0, eds.square_len(), 1).unwrap();
+        let cid = sample_cid(0, 0, 1).unwrap();

         sample.verify(&dah).unwrap();
         store.append_single(header).await.unwrap();
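All three arms above follow one pattern: decode the incoming bytes, recompute the multihash from the container's own id, and check the content against the block's DAH before Bitswap accepts it. A condensed, illustrative sketch of the `Sample` arm, assuming an async `header_store` like the one used above (imports and exact trait bounds elided):

    async fn hash_sample(&self, input: &[u8]) -> Result<Multihash, MultihasherError> {
        let sample = Sample::decode(input).map_err(MultihasherError::custom_fatal)?;

        // The hash is derived from the decoded sample's own id, so bytes
        // that decode to different coordinates can never answer this query.
        let hash = convert_cid(&sample.id.into())
            .map_err(MultihasherError::custom_fatal)?
            .hash()
            .to_owned();

        // Reject samples that do not verify against the block's DAH.
        let header = self
            .header_store
            .get_by_height(sample.id.block_height())
            .await
            .map_err(MultihasherError::custom_fatal)?;
        sample
            .verify(&header.dah)
            .map_err(MultihasherError::custom_fatal)?;

        Ok(hash)
    }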
2 changes: 1 addition & 1 deletion rpc/tests/share.rs
@@ -149,7 +149,7 @@ async fn get_eds() {
     let header = client.header_get_by_height(submitted_height).await.unwrap();
     let eds = client.share_get_eds(&header).await.unwrap();

-    for i in 0..header.dah.square_len() {
+    for i in 0..header.dah.square_width() {
         let row_root = eds.row_nmt(i).unwrap().root();
         assert_eq!(row_root, header.dah.row_root(i).unwrap());

48 changes: 24 additions & 24 deletions types/src/byzantine.rs
@@ -45,7 +45,7 @@ pub struct BadEncodingFraudProof {
     // For non-nil shares MerkleProofs are computed.
     shares: Vec<Option<ShareWithProof>>,
     // Index represents the row/col index where ErrByzantineRow/ErrByzantineColl occurred.
-    index: usize,
+    index: u16,
     // Axis represents the axis that verification failed on.
     axis: AxisType,
 }
@@ -70,8 +70,8 @@ impl FraudProof for BadEncodingFraudProof {
             );
         }

-        let merkle_row_roots = &header.dah.row_roots;
-        let merkle_col_roots = &header.dah.column_roots;
+        let merkle_row_roots = header.dah.row_roots();
+        let merkle_col_roots = header.dah.column_roots();

         // NOTE: shouldn't ever happen as header should be validated before
         if merkle_row_roots.len() != merkle_col_roots.len() {
@@ -82,7 +82,7 @@
             );
         }

-        if self.index >= merkle_row_roots.len() {
+        if usize::from(self.index) >= merkle_row_roots.len() {
             bail_validation!(
                 "fraud proof index ({}) >= dah rows len ({})",
                 self.index,
@@ -99,8 +99,8 @@
         }

         let root = match self.axis {
-            AxisType::Row => merkle_row_roots[self.index].clone(),
-            AxisType::Col => merkle_col_roots[self.index].clone(),
+            AxisType::Row => merkle_row_roots[usize::from(self.index)].clone(),
+            AxisType::Col => merkle_col_roots[usize::from(self.index)].clone(),
         };

         // verify if the root can be converted to a cid and back
@@ -250,6 +250,8 @@ impl TryFrom<RawBadEncodingFraudProof> for BadEncodingFraudProof {
             .map_err(|_| Error::InvalidAxis(value.axis))?
             .try_into()?;

+        let index = u16::try_from(value.index).map_err(|_| Error::EdsInvalidDimentions)?;
+
         let shares = value
             .shares
             .into_iter()
@@ -266,7 +268,7 @@
             header_hash: value.header_hash.try_into()?,
             block_height: value.height.try_into()?,
             shares,
-            index: value.index as usize,
+            index,
             axis,
         })
     }
@@ -310,13 +312,15 @@ pub(crate) mod test_utils {
     ) -> (ExtendedHeader, BadEncodingFraudProof) {
         let mut rng = rand::thread_rng();

-        let square_len = eds.square_len();
+        let square_width = eds.square_width();
         let axis = rng.gen_range(0..1).try_into().unwrap();
-        let axis_idx = rng.gen_range(0..square_len);
+        let axis_idx = rng.gen_range(0..square_width);

         // invalidate more than a half shares in axis
-        let shares_to_break = square_len / 2 + 1;
-        for share_idx in index::sample(&mut rng, square_len, shares_to_break) {
+        let shares_to_break = square_width / 2 + 1;
+        for share_idx in index::sample(&mut rng, square_width.into(), shares_to_break.into()) {
+            let share_idx = u16::try_from(share_idx).unwrap();
+
             let share = match axis {
                 AxisType::Row => eds.share_mut(axis_idx, share_idx).unwrap(),
                 AxisType::Col => eds.share_mut(share_idx, axis_idx).unwrap(),
@@ -339,19 +343,19 @@ pub(crate) mod test_utils {
     pub(crate) fn befp_from_header_and_eds(
         eh: &ExtendedHeader,
         eds: &ExtendedDataSquare,
-        axis_idx: usize,
+        axis_idx: u16,
         axis: AxisType,
     ) -> BadEncodingFraudProof {
-        let square_len = eds.square_len();
+        let square_width = eds.square_width();
         let mut nmt = eds.axis_nmt(axis, axis_idx).unwrap();
-        let mut shares_with_proof: Vec<_> = Vec::with_capacity(square_len);
+        let mut shares_with_proof: Vec<_> = Vec::with_capacity(square_width.into());

         // collect the shares for fraud proof
-        for share_idx in 0..square_len {
-            let (share, proof) = nmt.get_index_with_proof(share_idx);
+        for share_idx in 0..square_width {
+            let (share, proof) = nmt.get_index_with_proof(share_idx.into());

             // it doesn't matter which is row and which is column as ods is first quadrant
-            let ns = if is_ods_square(axis_idx, share_idx, square_len) {
+            let ns = if is_ods_square(axis_idx, share_idx, square_width) {
                 Namespace::from_raw(&share[..NS_SIZE]).unwrap()
             } else {
                 Namespace::PARITY_SHARE
@@ -384,14 +388,10 @@

 #[cfg(test)]
 mod tests {
-    use crate::{
-        test_utils::{corrupt_eds, generate_eds, ExtendedHeaderGenerator},
-        DataAvailabilityHeader,
-    };
-
     use self::test_utils::befp_from_header_and_eds;
-
     use super::*;
+    use crate::test_utils::{corrupt_eds, generate_eds, ExtendedHeaderGenerator};
+    use crate::DataAvailabilityHeader;

     #[cfg(target_arch = "wasm32")]
     use wasm_bindgen_test::wasm_bindgen_test as test;
@@ -464,7 +464,7 @@ mod tests {
         let mut eds = generate_eds(8);
         let (mut eh, proof) = corrupt_eds(&mut gen, &mut eds);

-        eh.dah.row_roots = vec![];
+        eh.dah = DataAvailabilityHeader::new_unchecked(Vec::new(), eh.dah.column_roots().to_vec());

         proof.validate(&eh).unwrap_err();
     }
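The row and column roots are now read through accessor methods, so the test above can no longer fabricate a malformed header by assigning to a field; `DataAvailabilityHeader::new_unchecked` (added by this commit) builds one while skipping validation. The same call, expanded with comments (argument order as in the hunk above):

    // Deliberately drop every row root while keeping the column roots.
    // `new_unchecked` performs no validation, so the broken DAH is only
    // rejected later, when the fraud proof is validated against it.
    eh.dah = DataAvailabilityHeader::new_unchecked(
        Vec::new(),                      // row roots: empty on purpose
        eh.dah.column_roots().to_vec(),  // column roots: unchanged
    );
    proof.validate(&eh).unwrap_err();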
