Skip to content

Commit

Permalink
enable sp1 aggregation
Browse files Browse the repository at this point in the history
Signed-off-by: smtmfft <[email protected]>
  • Loading branch information
smtmfft committed Sep 28, 2024
1 parent 9da8834 commit bdcb5f3
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 45 deletions.
102 changes: 70 additions & 32 deletions provers/sp1/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

use once_cell::sync::Lazy;
use raiko_lib::{
input::{AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput},
input::{
AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput,
ZkAggregationGuestInput,
},
prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult},
Measurement,
};
Expand All @@ -14,14 +17,15 @@ use sp1_sdk::{
action,
network::client::NetworkClient,
proto::network::{ProofMode, UnclaimReason},
SP1Proof, SP1ProofWithPublicValues, SP1VerifyingKey,
};
use sp1_sdk::{HashableKey, ProverClient, SP1Stdin};
use std::{
borrow::BorrowMut,
env, fs,
path::{Path, PathBuf},
};
use tracing::{debug, info};
use tracing::{debug, error, info};

pub const ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest");
pub const AGGREGATION_ELF: &[u8] = include_bytes!("../../guest/elf/sp1-aggregation");
Expand Down Expand Up @@ -76,9 +80,15 @@ impl From<Sp1Response> for Proof {
fn from(value: Sp1Response) -> Self {
Self {
proof: value.proof,
quote: None,
input: None,
uuid: None,
quote: value
.sp1_proof
.as_ref()
.map(|p| serde_json::to_string(&p.proof).unwrap()),
input: value
.sp1_proof
.as_ref()
.map(|p| B256::from_slice(p.public_values.as_slice())),
uuid: value.vkey.map(|v| serde_json::to_string(&v).unwrap()),
kzg_proof: None,
}
}
Expand All @@ -87,6 +97,9 @@ impl From<Sp1Response> for Proof {
#[derive(Clone, Serialize, Deserialize)]
pub struct Sp1Response {
pub proof: Option<String>,
/// for aggregation
pub sp1_proof: Option<SP1ProofWithPublicValues>,
pub vkey: Option<SP1VerifyingKey>,
}

pub struct Sp1Prover;
Expand Down Expand Up @@ -159,7 +172,13 @@ impl Prover for Sp1Prover {
.map_err(|e| ProverError::GuestError(format!("Sp1: network proof failed {e:?}")))?
};

let proof_bytes = prove_result.bytes();
let proof_bytes = match param.recursion {
RecursionMode::Compressed => {
info!("Compressed proof is used in aggregation mode only");
vec![]
}
_ => prove_result.bytes(),
};
if param.verify {
let time = Measurement::start("verify", false);
let pi_hash = prove_result
Expand Down Expand Up @@ -194,6 +213,8 @@ impl Prover for Sp1Prover {
Ok::<_, ProverError>(
Sp1Response {
proof: proof_string,
sp1_proof: Some(prove_result),
vkey: Some(vk),
}
.into(),
)
Expand Down Expand Up @@ -231,10 +252,38 @@ impl Prover for Sp1Prover {
let param = Sp1Param::deserialize(config.get("sp1").unwrap()).unwrap();
let mode = param.prover.clone().unwrap_or_else(get_env_mock);

println!("param: {param:?}");
info!("aggregate proof with param: {param:?}");

let block_inputs: Vec<B256> = input
.proofs
.iter()
.map(|proof| proof.input.unwrap())
.collect::<Vec<_>>();
let block_proof_vk = serde_json::from_str::<SP1VerifyingKey>(
&input.proofs.first().unwrap().uuid.clone().unwrap(),
)
.map_err(|e| ProverError::GuestError(format!("Failed to parse SP1 vk: {e}")))?;
let stark_vk = block_proof_vk.vk.clone();
let image_id = block_proof_vk.hash_u32();
let aggregation_input = ZkAggregationGuestInput {
image_id: image_id,
block_inputs,
};

let mut stdin = SP1Stdin::new();
stdin.write(&input);
stdin.write(&aggregation_input);
for proof in input.proofs.iter() {
let sp1_proof = serde_json::from_str::<SP1Proof>(&proof.quote.clone().unwrap())
.map_err(|e| ProverError::GuestError(format!("Failed to parse SP1 proof: {e}")))?;
match sp1_proof {
SP1Proof::Compressed(block_proof) => {
stdin.write_proof(block_proof.into(), stark_vk.clone());
}
_ => {
error!("unsupported proof type for aggregation: {:?}", sp1_proof);
}
}
}

// Generate the proof for the given program.
let client = param
Expand All @@ -248,29 +297,11 @@ impl Prover for Sp1Prover {

let (pk, vk) = client.setup(AGGREGATION_ELF);

let prove_action = action::Prove::new(client.prover.as_ref(), &pk, stdin.clone());
let prove_result = if !matches!(mode, ProverMode::Network) {
tracing::debug!("Proving locally with recursion mode: {:?}", param.recursion);
match param.recursion {
RecursionMode::Core => prove_action.run(),
RecursionMode::Compressed => prove_action.compressed().run(),
RecursionMode::Plonk => prove_action.plonk().run(),
}
.map_err(|e| ProverError::GuestError(format!("Sp1: local proving failed: {e}")))?
} else {
let network_prover = sp1_sdk::NetworkProver::new();

let proof_id = network_prover
.request_proof(AGGREGATION_ELF, stdin, param.recursion.clone().into())
.await
.map_err(|e| {
ProverError::GuestError(format!("Sp1: requesting proof failed: {e}"))
})?;
network_prover
.wait_proof::<sp1_sdk::SP1ProofWithPublicValues>(&proof_id, None)
.await
.map_err(|e| ProverError::GuestError(format!("Sp1: network proof failed {e:?}")))?
};
let prove_result = client
.prove(&pk, stdin)
.plonk()
.run()
.expect("proving failed");

let proof_bytes = prove_result.bytes();
if param.verify {
Expand Down Expand Up @@ -300,7 +331,14 @@ impl Prover for Sp1Prover {
),
);

Ok::<_, ProverError>(Sp1Response { proof }.into())
Ok::<_, ProverError>(
Sp1Response {
proof: proof,
sp1_proof: None,
vkey: None,
}
.into(),
)
}
}

Expand Down
23 changes: 13 additions & 10 deletions provers/sp1/guest/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions provers/sp1/guest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ path = "src/benchmark/bn254_mul.rs"

[dependencies]
raiko-lib = { path = "../../../lib", features = ["std", "sp1"] }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1", branch = "dev", features = [
"verify",
] }
sp1-zkvm = { version = "2.0.0", features = ["verify"] }
sp1-core = { version = "1.1.1" }
sha2 = { git = "https://github.com/sp1-patches/RustCrypto-hashes", package = "sha2", branch = "patch-v0.10.8" }
secp256k1 = { git = "https://github.com/sp1-patches/rust-secp256k1", branch = "patch-secp256k1-v0.29.0" }
Expand Down
Binary file modified provers/sp1/guest/elf/sp1-aggregation
Binary file not shown.
Binary file modified provers/sp1/guest/elf/sp1-guest
Binary file not shown.

0 comments on commit bdcb5f3

Please sign in to comment.