From 2af5b97bbfd28222510412bc2da53a0ce0d729cb Mon Sep 17 00:00:00 2001 From: Heng Date: Thu, 22 Aug 2024 21:15:21 +0800 Subject: [PATCH] Optimize circuit by int_div_unsafe (#8) * fix failure on empty chip * optimize int_div_unsafe --- src/circuit/ecc_chip.rs | 15 +++++----- src/circuit/fq12.rs | 7 ++--- src/circuit/integer_chip.rs | 56 ++++++++++++++++++++++++++++++++----- src/context.rs | 2 ++ src/range_info.rs | 9 ++++-- 5 files changed, 66 insertions(+), 23 deletions(-) diff --git a/src/circuit/ecc_chip.rs b/src/circuit/ecc_chip.rs index 48b109d..17845a7 100644 --- a/src/circuit/ecc_chip.rs +++ b/src/circuit/ecc_chip.rs @@ -844,16 +844,15 @@ pub trait EccChipBaseOps: ) -> Result, UnsafeError> { let diff_x = self.base_integer_chip().int_sub(&a.x, &b.x); let diff_y = self.base_integer_chip().int_sub(&a.y, &b.y); - let (x_eq, tangent) = self.base_integer_chip().int_div(&diff_y, &diff_x); + let tangent = self.base_integer_chip().int_div_unsafe(&diff_y, &diff_x); - // x cannot be same - let succeed = self.base_integer_chip().base_chip().try_assert_false(&x_eq); - let res = self.lambda_to_point_non_zero(&tangent, a, b); + match tangent { + Some(tangent) => { + let res = self.lambda_to_point_non_zero(&tangent, a, b); - if succeed { - Ok(res) - } else { - Err(UnsafeError::AddSameOrNegPoint) + Ok(res) + } + None => Err(UnsafeError::AddSameOrNegPoint), } } diff --git a/src/circuit/fq12.rs b/src/circuit/fq12.rs index bfd51e3..37e47ee 100644 --- a/src/circuit/fq12.rs +++ b/src/circuit/fq12.rs @@ -102,7 +102,7 @@ pub trait Fq2ChipOps: EccBaseIntegerChipWrapper { let t0 = self.base_integer_chip().int_square(&x.0); let t1 = self.base_integer_chip().int_square(&x.1); let t0 = self.base_integer_chip().int_add(&t0, &t1); - let t = self.base_integer_chip().int_unsafe_invert(&t0); + let t = self.base_integer_chip().int_unsafe_invert(&t0).unwrap(); let c0 = self.base_integer_chip().int_mul(&x.0, &t); let c1 = self.base_integer_chip().int_mul(&x.1, &t); let c1 = self.base_integer_chip().int_neg(&c1); @@ -303,10 +303,7 @@ pub trait Fq6ChipOps: Fq2ChipOps + Fq2BnSpecificO pub trait Fq12ChipOps: Fq6ChipOps + Fq6BnSpecificOps { fn fq12_reduce(&mut self, x: &AssignedFq12) -> AssignedFq12 { - ( - self.fq6_reduce(&x.0), - self.fq6_reduce(&x.1), - ) + (self.fq6_reduce(&x.0), self.fq6_reduce(&x.1)) } fn fq12_assert_one(&mut self, x: &AssignedFq12) { let one = self.fq12_assign_one(); diff --git a/src/circuit/integer_chip.rs b/src/circuit/integer_chip.rs index 427e35c..70d9c62 100644 --- a/src/circuit/integer_chip.rs +++ b/src/circuit/integer_chip.rs @@ -35,12 +35,17 @@ pub trait IntegerChipOps { a: &AssignedInteger, b: &AssignedInteger, ) -> AssignedInteger; - fn int_unsafe_invert(&mut self, x: &AssignedInteger) -> AssignedInteger; + fn int_unsafe_invert(&mut self, x: &AssignedInteger) -> Option>; fn int_div( &mut self, a: &AssignedInteger, b: &AssignedInteger, ) -> (AssignedCondition, AssignedInteger); + fn int_div_unsafe( + &mut self, + a: &AssignedInteger, + b: &AssignedInteger, + ) -> Option>; fn is_pure_zero(&mut self, a: &AssignedInteger) -> AssignedCondition; fn is_pure_w_modulus(&mut self, a: &AssignedInteger) -> AssignedCondition; fn is_int_zero(&mut self, a: &AssignedInteger) -> AssignedCondition; @@ -79,7 +84,7 @@ impl IntegerContext { ) { assert!(a.times < self.info().overflow_limit); assert!(b.times < self.info().overflow_limit); - assert!(rem.times == 1); + assert!(rem.times < self.info().overflow_limit); let info = self.info(); let one = N::one(); @@ -482,12 +487,49 @@ impl IntegerChipOps for IntegerContext { rem } - fn int_unsafe_invert(&mut self, x: &AssignedInteger) -> AssignedInteger { - //TODO: optimize + fn int_unsafe_invert(&mut self, x: &AssignedInteger) -> Option> { let one = self.assign_int_constant(W::one()); - let (c, v) = self.int_div(&one, x); - self.ctx.borrow_mut().assert_false(&c); - v + self.int_div_unsafe(&one, x) + } + + fn int_div_unsafe( + &mut self, + a: &AssignedInteger, + b: &AssignedInteger, + ) -> Option> { + let info = self.info(); + + let mut b = b.clone(); + + // Ensure b > a, so c * b > a and we can find the d that c * b = d * w + a + if b.times <= a.times { + let assigned_w = self.assign_w(&info.w_modulus); + while b.times < a.times { + b = self.int_add(&b, &assigned_w); + } + } + + let a_bn = self.get_w_bn(&a); + let b_bn = self.get_w_bn(&b); + + let b_inv: Option = bn_to_field::(&b_bn).invert().into(); + + match b_inv { + Some(b_inv) => { + let c = bn_to_field::(&a_bn) * b_inv; + let c_bn = field_to_bn(&c); + let d_bn = (&b_bn * &c_bn - &a_bn) / &info.w_modulus; + + let c = self.assign_w(&c_bn); + let d = self.assign_d(&d_bn); + + self.add_constraints_for_mul_equation_on_limbs(&b, &c, &d.0, &a); + self.add_constraints_for_mul_equation_on_native(&b, &c, &d.1, &a); + + Some(c) + } + None => None, + } } fn int_div( diff --git a/src/context.rs b/src/context.rs index 82588a8..db7de0d 100644 --- a/src/context.rs +++ b/src/context.rs @@ -325,6 +325,7 @@ impl Records { let threads = 16; let chunk_size = (self.base_height + threads - 1) / threads; + let chunk_size = if chunk_size == 0 { 1 } else { chunk_size }; let chunk_num = chunk_size * threads; self.inner .base_adv_record @@ -419,6 +420,7 @@ impl Records { let threads = 16; let chunk_size = (self.base_height + threads - 1) / threads; + let chunk_size = if chunk_size == 0 { 1 } else { chunk_size }; let chunk_num = chunk_size * threads; self.inner .range_adv_record diff --git a/src/range_info.rs b/src/range_info.rs index 91e1e07..b17dd13 100644 --- a/src/range_info.rs +++ b/src/range_info.rs @@ -254,7 +254,7 @@ impl RangeInfo { let lcm = self .n_modulus .lcm(&(BigUint::from(1u64) << (self.limb_bits * self.mul_check_limbs))); - let max_rem = &self.w_ceil - 1u64; + let max_rem = &self.w_ceil * (self.overflow_limit - 1) - 1u64; assert!(lcm > &max_a * max_b); assert!(lcm > &max_d * &self.w_modulus + &max_rem); @@ -273,7 +273,7 @@ impl RangeInfo { .iter() .reduce(|acc, x| acc.max(x)) .unwrap(); - let max_rem_i = &self.limb_modulus - 1u64; + let max_rem_i = &self.limb_modulus * (self.overflow_limit - 1) - 1u64; assert!( &borrow * &self.limb_modulus - &borrow >= self.limbs * max_d_j * max_w_j + max_rem_i @@ -285,7 +285,10 @@ impl RangeInfo { let max_v = &self.limb_modulus * common_modulus - 1u64; let max_a_j = &self.limb_modulus * (self.overflow_limit - 1); let max_b_j = &max_a_j; - assert!(&max_v * &self.limb_modulus >= &max_a_j * max_b_j * self.limbs + &self.limb_modulus * &borrow); + assert!( + &max_v * &self.limb_modulus + >= &max_a_j * max_b_j * self.limbs + &self.limb_modulus * &borrow + ); // To avoid overflow // max(v) * limb_modulus < n_modulus