diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 5f3814e63..aeaa3992e 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -62,4 +62,4 @@ jobs: env: RAYON_NUM_THREADS: 8 RUST_LOG: debug - run: cargo run --release --package ceno_zkvm --example fibonacci_elf --target ${{ matrix.target }} -- \ No newline at end of file + run: cargo run --release --package ceno_zkvm --example fibonacci_elf --target ${{ matrix.target }} diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index c3c06c11c..d1ef6d2e3 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, ops::Not}; use super::rv32im::EmuContext; use crate::{ @@ -79,13 +79,7 @@ impl VMState { pub fn iter_until_halt(&mut self) -> impl Iterator> + '_ { let emu = Emulator::new(); - from_fn(move || { - if self.halted() { - None - } else { - Some(self.step(&emu)) - } - }) + from_fn(move || self.halted().not().then(|| self.step(&emu))) } fn step(&mut self, emu: &Emulator) -> Result { diff --git a/ceno_zkvm/examples/fibonacci_elf.rs b/ceno_zkvm/examples/fibonacci_elf.rs index 168a4b36e..a7c4ac127 100644 --- a/ceno_zkvm/examples/fibonacci_elf.rs +++ b/ceno_zkvm/examples/fibonacci_elf.rs @@ -19,7 +19,7 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate}; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use std::{ collections::{HashMap, HashSet}, - panic, + panic::{self, PanicHookInfo}, time::Instant, }; use tracing_flame::FlameLayer; @@ -35,6 +35,27 @@ struct Args { max_steps: Option, } +/// Temporarily override the panic hook +/// +/// We restore the original hook after we are done. +fn with_panic_hook(hook: Box) + Sync + Send + 'static>, f: F) -> R +where + F: FnOnce() -> R, +{ + // Save the current panic hook + let original_hook = panic::take_hook(); + + // Set the new panic hook + panic::set_hook(hook); + + let result = f(); + + // Restore the original panic hook + panic::set_hook(original_hook); + + result +} + fn main() { let args = Args::parse(); @@ -125,7 +146,7 @@ fn main() { let pk = zkvm_cs .clone() - .key_gen::(pp.clone(), vp.clone(), zkvm_fixed_traces.clone()) + .key_gen::(pp, vp, zkvm_fixed_traces.clone()) .expect("keygen failed"); let vk = pk.get_vk(); @@ -153,14 +174,14 @@ fn main() { record.insn().codes().kind == EANY && record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() }) - .and_then(|halt_record| halt_record.rs2()) + .and_then(StepRecord::rs2) .map(|rs2| rs2.value); let final_access = vm.tracer().final_accesses(); let end_cycle: u32 = vm.tracer().cycle().try_into().unwrap(); let pi = PublicValues::new( - exit_code.unwrap_or(0), + exit_code.unwrap_or_default(), vm.program().entry, Tracer::SUBCYCLES_PER_INSN as u32, vm.get_pc().into(), @@ -188,7 +209,7 @@ fn main() { MemFinalRecord { addr: rec.addr, value: vm.peek_register(index), - cycle: *final_access.get(&vma).unwrap_or(&0), + cycle: final_access.get(&vma).copied().unwrap_or_default(), } } else { // The table is padded beyond the number of registers. @@ -209,7 +230,7 @@ fn main() { MemFinalRecord { addr: rec.addr, value: vm.peek_memory(vma), - cycle: *final_access.get(&vma).unwrap_or(&0), + cycle: final_access.get(&vma).copied().unwrap_or_default(), } }) .collect_vec(); @@ -218,7 +239,12 @@ fn main() { // Find the final public IO cycles. let io_final = io_init .iter() - .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) + .map(|rec| { + final_access + .get(&rec.addr.into()) + .copied() + .unwrap_or_default() + }) .collect_vec(); // assign table circuits @@ -269,18 +295,16 @@ fn main() { } let transcript = Transcript::new(b"riscv"); - // change public input maliciously should cause verifier to reject proof + // Maliciously changing the public input should cause the verifier to reject the proof. zkvm_proof.raw_pi[0] = vec![::BaseField::ONE]; zkvm_proof.raw_pi[1] = vec![::BaseField::ONE]; - // capture panic message, if have - let default_hook = panic::take_hook(); - panic::set_hook(Box::new(|_info| { - // by default it will print msg to stdout/stderr - // we override it to avoid print msg since we will capture the msg by our own - })); - let result = panic::catch_unwind(|| verifier.verify_proof(zkvm_proof, transcript)); - panic::set_hook(default_hook); + // capture panic message, if any + // by default it will print msg to stdout/stderr + // we override it to avoid print msg since we will capture the msg by ourselves + let result = with_panic_hook(Box::new(|_info| ()), || { + panic::catch_unwind(|| verifier.verify_proof(zkvm_proof, transcript)) + }); match result { Ok(res) => { res.expect_err("verify proof should return with error"); @@ -322,11 +346,11 @@ fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) { tracing::debug!( "Memory range (accessed): {:?}", - format_segments(vm.platform(), accessed_addrs.iter().copied()) + format_segments(vm.platform(), &accessed_addrs) ); tracing::debug!( "Memory range (handled): {:?}", - format_segments(vm.platform(), handled_addrs.iter().copied()) + format_segments(vm.platform(), &handled_addrs) ); for addr in &accessed_addrs { @@ -334,11 +358,12 @@ fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) { } } -fn format_segments( +fn format_segments<'a>( platform: &Platform, - addrs: impl Iterator, -) -> HashMap> { + addrs: impl IntoIterator, +) -> HashMap> { addrs + .into_iter() .into_grouping_map_by(|addr| format_segment(platform, addr.0)) .minmax() } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 8c9b35524..64722bd25 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -46,11 +46,7 @@ impl MmuConfig { io_addrs: &[Addr], ) { assert!( - chain( - static_mem_init.iter().map(|record| record.addr), - io_addrs.iter().copied(), - ) - .all_unique(), + chain(static_mem_init.iter().map(|record| &record.addr), io_addrs,).all_unique(), "memory addresses must be unique" ); @@ -142,14 +138,9 @@ impl MemPadder { new_len: usize, records: Vec, ) -> Vec { - if records.is_empty() { - self.padded(new_len, records) - } else { - self.padded(new_len, records) - .into_iter() - .sorted_by_key(|record| record.addr) - .collect() - } + let mut padded = self.padded(new_len, records); + padded.sort_by_key(|record| record.addr); + padded } /// Pad `records` to `new_len` using unused addresses. diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3143f2287..57e9200c6 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,5 +1,4 @@ use ff_ext::ExtensionField; -use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; use serde::{Deserialize, Serialize}; use std::{collections::BTreeMap, fmt::Debug}; @@ -133,15 +132,15 @@ impl> ZKVMProof { .iter() .map(|pv| { if pv.len() == 1 { - // this is constant poly, and always evaluate to same constant value + // this is constant poly, and always evaluates to same constant value E::from(pv[0]) } else { - // set 0 as placeholder. will be evaluate lazily + // set 0 as placeholder. will be evaluated lazily // Or the vector is empty, i.e. the constant 0 polynomial. E::ZERO } }) - .collect_vec(); + .collect(); Self { raw_pi, pi_evals, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 30d8d9a6d..9729be88c 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1241,7 +1241,7 @@ impl TowerProver { virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator); } } - tracing::debug!("generated tower proof at round {}/{}", round, max_round_index); + tracing::debug!("generated tower proof at round {round}/{max_round_index}"); let wrap_batch_span = entered_span!("wrap_batch"); // NOTE: at the time of adding this span, visualizing it with the flamegraph layer diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index bdce2bd09..d7e8cfd6c 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -60,12 +60,12 @@ impl> ZKVMVerifier transcript: Transcript, does_halt: bool, ) -> Result { - // require ecall/halt proof to exist, depending whether we expect a halt. + // require ecall/halt proof to exist, depending on whether we expect a halt. let num_instances = vm_proof .opcode_proofs .get(&HaltInstruction::::name()) .map(|(_, p)| p.num_instances) - .unwrap_or(0); + .unwrap_or_default(); if num_instances != (does_halt as usize) { return Err(ZKVMError::VerifyError(format!( "ecall/halt num_instances={}, expected={}", @@ -117,12 +117,12 @@ impl> ZKVMVerifier } } - for (name, (_, proof)) in vm_proof.opcode_proofs.iter() { + for (name, (_, proof)) in &vm_proof.opcode_proofs { tracing::debug!("read {}'s commit", name); PCS::write_commitment(&proof.wits_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; } - for (name, (_, proof)) in vm_proof.table_proofs.iter() { + for (name, (_, proof)) in &vm_proof.table_proofs { tracing::debug!("read {}'s commit", name); PCS::write_commitment(&proof.wits_commit, &mut transcript) .map_err(ZKVMError::PCSError)?;