Skip to content

Commit

Permalink
Update sighash saving helper
Browse files Browse the repository at this point in the history
  • Loading branch information
ekrembal committed Sep 6, 2024
1 parent 55cb191 commit 78d4651
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 92 deletions.
122 changes: 31 additions & 91 deletions core/src/database/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,67 +468,40 @@ impl Database {
/// Verifier: saves the sighash and returns sec and agg nonces, if the sighash is already there and different, returns error
pub async fn save_sighashes_and_get_nonces(
&self,
tx: Option<&mut sqlx::Transaction<'_, Postgres>>,
deposit_outpoint: OutPoint,
index: usize,
sighashes: &[MuSigSigHash],
) -> Result<Option<Vec<(MuSigSecNonce, MuSigAggNonce)>>, BridgeError> {
let indices: Vec<i32> = sqlx::query_scalar::<_, i32>(
"SELECT internal_idx FROM nonces WHERE deposit_outpoint = $1 ORDER BY internal_idx ASC;",
)
.bind(OutPointDB(deposit_outpoint))
.fetch_all(&self.connection)
.await?;

// Start a batch query for updating sighashes
let mut update_query_builder = QueryBuilder::new("UPDATE nonces SET sighash = CASE");

// Create a set of updates for the corresponding sighash and internal_idx
for (sighash, idx) in sighashes.iter().zip(indices[index..].iter()) {
update_query_builder
.push(" WHEN internal_idx = ")
.push_bind(*idx)
.push(" THEN ")
.push_bind(sighash);
}

// Finish the CASE statement and add the WHERE clause for deposit_outpoint
update_query_builder
.push(" END WHERE deposit_outpoint = ")
// Update the sighashes
let mut query = QueryBuilder::new(
"UPDATE nonces
SET sighash = batch.sighash
FROM (",
);
let query = query.push_values(sighashes.iter().enumerate(), |mut builder, (i, sighash)| {
builder.push_bind((index + i) as i32).push_bind(sighash);
});

let query = query
.push(
") AS batch (internal_idx, sighash)
WHERE nonces.internal_idx = batch.internal_idx AND nonces.deposit_outpoint = ",
)
.push_bind(OutPointDB(deposit_outpoint))
.push(" AND internal_idx IN (")
.push_values(indices[index..].iter(), |mut b, idx| {
b.push_bind(*idx); // Add semicolon to ensure closure returns `()`
})
.push(")");

// Execute the batch update
update_query_builder
.build()
.execute(&self.connection)
.await?;
.push(" RETURNING sec_nonce, agg_nonce;")
.build_query_as();

// Now batch fetch the sec_nonce and agg_nonce after the update
let mut select_query_builder =
QueryBuilder::new("SELECT sec_nonce, agg_nonce FROM nonces WHERE deposit_outpoint = ");
select_query_builder
.push_bind(OutPointDB(deposit_outpoint))
.push(" AND internal_idx IN (")
.push_values(indices[index..].iter(), |mut b, idx| {
b.push_bind(*idx); // Ensure closure returns `()`
})
.push(") AND sighash IN (")
.push_values(sighashes, |mut b, sighash| {
b.push_bind(sighash); // Ensure closure returns `()`
})
.push(")");

// Execute the batch select query
let nonces: Vec<(MuSigSecNonce, MuSigAggNonce)> = select_query_builder
.build_query_as()
.fetch_all(&self.connection)
.await?;
let result: Result<Vec<(MuSigSecNonce, MuSigAggNonce)>, sqlx::Error> = match tx {
Some(tx) => query.fetch_all(&mut **tx).await,
None => query.fetch_all(&self.connection).await,
};

Ok(Some(nonces))
match result {
Ok(nonces) => Ok(Some(nonces)),
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(BridgeError::DatabaseError(e)),
}
}

/// Verifier: Save the agg nonces for signing
Expand Down Expand Up @@ -703,39 +676,6 @@ impl Database {
None => Ok(None),
}
}

// pub async fn save_kickoff_root(
// &self,
// deposit_outpoint: OutPoint,
// kickoff_root: [u8; 32],
// ) -> Result<(), BridgeError> {
// sqlx::query(
// "INSERT INTO kickoff_roots (deposit_outpoint, kickoff_merkle_root) VALUES ($1, $2);",
// )
// .bind(OutPointDB(deposit_outpoint))
// .bind(hex::encode(kickoff_root))
// .execute(&self.connection)
// .await?;

// Ok(())
// }

// pub async fn get_kickoff_root(
// &self,
// deposit_outpoint: OutPoint,
// ) -> Result<Option<[u8; 32]>, BridgeError> {
// let qr: Option<String> = sqlx::query_scalar(
// "SELECT kickoff_merkle_root FROM kickoff_roots WHERE deposit_outpoint = $1;",
// )
// .bind(OutPointDB(deposit_outpoint))
// .fetch_optional(&self.connection)
// .await?;

// match qr {
// Some(root) => Ok(Some(hex::decode(root)?.try_into()?)),
// None => Ok(None),
// }
// }
}

#[cfg(test)]
Expand Down Expand Up @@ -897,7 +837,7 @@ mod tests {
db.save_nonces(None, outpoint, &nonce_pairs).await.unwrap();
db.save_agg_nonces(outpoint, &agg_nonces).await.unwrap();
let db_sec_and_agg_nonces = db
.save_sighashes_and_get_nonces(outpoint, index, &sighashes)
.save_sighashes_and_get_nonces(None, outpoint, index, &sighashes)
.await
.unwrap()
.unwrap();
Expand Down Expand Up @@ -937,7 +877,7 @@ mod tests {
db.save_nonces(None, outpoint, &nonce_pairs).await.unwrap();
db.save_agg_nonces(outpoint, &agg_nonces).await.unwrap();
let db_sec_and_agg_nonces = db
.save_sighashes_and_get_nonces(outpoint, index, &sighashes)
.save_sighashes_and_get_nonces(None, outpoint, index, &sighashes)
.await
.unwrap()
.unwrap();
Expand Down Expand Up @@ -982,15 +922,15 @@ mod tests {
db.save_nonces(None, outpoint, &nonce_pairs).await.unwrap();
db.save_agg_nonces(outpoint, &agg_nonces).await.unwrap();
let _db_sec_and_agg_nonces = db
.save_sighashes_and_get_nonces(outpoint, index, &sighashes)
.save_sighashes_and_get_nonces(None, outpoint, index, &sighashes)
.await
.unwrap()
.unwrap();

// Accidentally try to save a different sighash
sighashes[0] = ByteArray32([2u8; 32]);
let _db_sec_and_agg_nonces = db
.save_sighashes_and_get_nonces(outpoint, index, &sighashes)
.save_sighashes_and_get_nonces(None, outpoint, index, &sighashes)
.await
.expect_err("Should return database sighash update error");
println!("Error: {:?}", _db_sec_and_agg_nonces);
Expand Down
4 changes: 3 additions & 1 deletion core/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ where
let nonces = self
.db
.save_sighashes_and_get_nonces(
None,
deposit_outpoint,
self.config.num_operators + 1,
&slash_or_take_sighashes,
Expand Down Expand Up @@ -376,7 +377,7 @@ where
// println!("Operator takes sighashes: {:?}", operator_takes_sighashes);
let nonces = self
.db
.save_sighashes_and_get_nonces(deposit_outpoint, 1, &operator_takes_sighashes)
.save_sighashes_and_get_nonces(None, deposit_outpoint, 1, &operator_takes_sighashes)
.await?
.ok_or(BridgeError::NoncesNotFound)?;
// println!("Nonces: {:?}", nonces);
Expand Down Expand Up @@ -482,6 +483,7 @@ where
let nonces = self
.db
.save_sighashes_and_get_nonces(
None,
deposit_outpoint,
0,
&[ByteArray32(move_tx_sighash.to_byte_array())],
Expand Down

0 comments on commit 78d4651

Please sign in to comment.