Skip to content

Commit

Permalink
[serialization] impl SerializeCanonical, DeserializeCanonical for Con…
Browse files Browse the repository at this point in the history
…straintSystem (#11)
  • Loading branch information
bergkvist authored Feb 17, 2025
1 parent 5a9ffd6 commit b3f3076
Show file tree
Hide file tree
Showing 25 changed files with 515 additions and 56 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ generic-array = "0.14.7"
getset = "0.1.2"
groestl_crypto = { package = "groestl", version = "0.10.1" }
hex-literal = "0.4.1"
inventory = "0.3.19"
itertools = "0.13.0"
lazy_static = "1.5.0"
paste = "1.0.15"
Expand Down
1 change: 1 addition & 0 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ derive_more.workspace = true
digest.workspace = true
either.workspace = true
getset.workspace = true
inventory.workspace = true
itertools.workspace = true
rand.workspace = true
stackalloc.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/constraint_system/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};

pub type ChannelId = usize;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)]
pub struct Flush {
pub oracles: Vec<OracleId>,
pub channel_id: ChannelId,
Expand Down
20 changes: 18 additions & 2 deletions crates/core/src/constraint_system/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ mod prove;
pub mod validate;
mod verify;

use binius_field::TowerField;
use binius_field::{serialization, BinaryField128b, DeserializeCanonical, TowerField};
use binius_macros::SerializeCanonical;
use channel::{ChannelId, Flush};
pub use prove::prove;
pub use verify::verify;
Expand All @@ -21,7 +22,7 @@ use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId};
///
/// As a result, a ConstraintSystem allows us to validate all of these
/// constraints against a witness, as well as enabling generic prove/verify
#[derive(Debug, Clone)]
#[derive(Debug, Clone, SerializeCanonical)]
pub struct ConstraintSystem<F: TowerField> {
pub oracles: MultilinearOracleSet<F>,
pub table_constraints: Vec<ConstraintSet<F>>,
Expand All @@ -30,6 +31,21 @@ pub struct ConstraintSystem<F: TowerField> {
pub max_channel_id: ChannelId,
}

impl DeserializeCanonical for ConstraintSystem<BinaryField128b> {
fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result<Self, serialization::Error>
where
Self: Sized,
{
Ok(Self {
oracles: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
table_constraints: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
non_zero_oracle_ids: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
flushes: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
max_channel_id: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
})
}
}

impl<F: TowerField> ConstraintSystem<F> {
pub const fn no_base_constraints(self) -> Self {
self
Expand Down
2 changes: 2 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ pub mod tower;
pub mod transcript;
pub mod transparent;
pub mod witness;

pub use inventory;
13 changes: 7 additions & 6 deletions crates/core/src/oracle/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use core::iter::IntoIterator;
use std::sync::Arc;

use binius_field::{Field, TowerField};
use binius_macros::{DeserializeCanonical, SerializeCanonical};
use binius_math::{ArithExpr, CompositionPolyOS};
use binius_utils::bail;
use itertools::Itertools;
Expand All @@ -15,23 +16,23 @@ use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId};
pub type TypeErasedComposition<P> = Arc<dyn CompositionPolyOS<P>>;

/// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube
#[derive(Debug, Clone)]
#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)]
pub struct Constraint<F: Field> {
pub name: Arc<str>,
pub name: String,
pub composition: ArithExpr<F>,
pub predicate: ConstraintPredicate<F>,
}

/// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality to zero
/// on the hypercube (zerocheck)
#[derive(Clone, Debug)]
#[derive(Clone, Debug, SerializeCanonical, DeserializeCanonical)]
pub enum ConstraintPredicate<F: Field> {
Sum(F),
Zero,
}

/// Constraint set is a group of constraints that operate over the same set of oracle-identified multilinears
#[derive(Debug, Clone)]
#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)]
pub struct ConstraintSet<F: Field> {
pub n_vars: usize,
pub oracle_ids: Vec<OracleId>,
Expand All @@ -41,7 +42,7 @@ pub struct ConstraintSet<F: Field> {
// A deferred constraint constructor that instantiates index composition after the superset of oracles is known
#[allow(clippy::type_complexity)]
struct UngroupedConstraint<F: Field> {
name: Arc<str>,
name: String,
oracle_ids: Vec<OracleId>,
composition: ArithExpr<F>,
predicate: ConstraintPredicate<F>,
Expand Down Expand Up @@ -82,7 +83,7 @@ impl<F: Field> ConstraintSetBuilder<F> {
composition: ArithExpr<F>,
) {
self.constraints.push(UngroupedConstraint {
name: name.to_string().into(),
name: name.to_string(),
oracle_ids: oracle_ids.into_iter().collect(),
composition,
predicate: ConstraintPredicate::Zero,
Expand Down
117 changes: 101 additions & 16 deletions crates/core/src/oracle/multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

use std::{array, fmt::Debug, sync::Arc};

use binius_field::{Field, TowerField};
use binius_field::{
serialization, BinaryField128b, DeserializeCanonical, Field, SerializeCanonical, TowerField,
};
use binius_macros::{DeserializeCanonical, SerializeCanonical};
use binius_utils::bail;
use getset::{CopyGetters, Getters};

Expand Down Expand Up @@ -280,9 +283,20 @@ impl<F: TowerField> MultilinearOracleSetAddition<'_, F> {
///
/// The oracle set also tracks the committed polynomial in batches where each batch is committed
/// together with a polynomial commitment scheme.
#[derive(Default, Debug, Clone)]
#[derive(Default, Debug, Clone, SerializeCanonical)]
pub struct MultilinearOracleSet<F: TowerField> {
oracles: Vec<Arc<MultilinearPolyOracle<F>>>,
oracles: Vec<MultilinearPolyOracle<F>>,
}

impl binius_field::DeserializeCanonical for MultilinearOracleSet<BinaryField128b> {
fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result<Self, serialization::Error>
where
Self: Sized,
{
Ok(Self {
oracles: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
})
}
}

impl<F: TowerField> MultilinearOracleSet<F> {
Expand Down Expand Up @@ -323,12 +337,11 @@ impl<F: TowerField> MultilinearOracleSet<F> {
oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle<F>,
) -> OracleId {
let id = self.oracles.len();

self.oracles.push(Arc::new(oracle(id)));
self.oracles.push(oracle(id));
id
}

fn get_from_set(&self, id: OracleId) -> Arc<MultilinearPolyOracle<F>> {
fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle<F> {
self.oracles[id].clone()
}

Expand Down Expand Up @@ -401,7 +414,7 @@ impl<F: TowerField> MultilinearOracleSet<F> {
}

pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle<F> {
(*self.oracles[id]).clone()
self.oracles[id].clone()
}

pub fn n_vars(&self, id: OracleId) -> usize {
Expand Down Expand Up @@ -438,7 +451,7 @@ impl<F: TowerField> MultilinearOracleSet<F> {
/// other oracles. This is formalized in [DP23] Section 4.
///
/// [DP23]: <https://eprint.iacr.org/2023/1784>
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical)]
pub struct MultilinearPolyOracle<F: TowerField> {
pub id: OracleId,
pub name: Option<String>,
Expand All @@ -447,7 +460,22 @@ pub struct MultilinearPolyOracle<F: TowerField> {
pub variant: MultilinearPolyVariant<F>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
impl DeserializeCanonical for MultilinearPolyOracle<BinaryField128b> {
fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result<Self, serialization::Error>
where
Self: Sized,
{
Ok(Self {
id: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
name: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
n_vars: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
tower_level: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
variant: DeserializeCanonical::deserialize_canonical(&mut read_buf)?,
})
}
}

#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical)]
pub enum MultilinearPolyVariant<F: TowerField> {
Committed,
Transparent(TransparentPolyOracle<F>),
Expand All @@ -459,6 +487,33 @@ pub enum MultilinearPolyVariant<F: TowerField> {
ZeroPadded(OracleId),
}

impl DeserializeCanonical for MultilinearPolyVariant<BinaryField128b> {
fn deserialize_canonical(mut buf: impl bytes::Buf) -> Result<Self, serialization::Error>
where
Self: Sized,
{
Ok(match u8::deserialize_canonical(&mut buf)? {
0 => Self::Committed,
1 => Self::Transparent(DeserializeCanonical::deserialize_canonical(&mut buf)?),
2 => Self::Repeating {
id: DeserializeCanonical::deserialize_canonical(&mut buf)?,
log_count: DeserializeCanonical::deserialize_canonical(&mut buf)?,
},
3 => Self::Projected(DeserializeCanonical::deserialize_canonical(&mut buf)?),
4 => Self::Shifted(DeserializeCanonical::deserialize_canonical(&mut buf)?),
5 => Self::Packed(DeserializeCanonical::deserialize_canonical(&mut buf)?),
6 => Self::LinearCombination(DeserializeCanonical::deserialize_canonical(&mut buf)?),
7 => Self::ZeroPadded(DeserializeCanonical::deserialize_canonical(&mut buf)?),
variant_index => {
return Err(serialization::Error::UnknownEnumVariant {
name: "MultilinearPolyVariant",
index: variant_index,
});
}
})
}
}

/// A transparent multilinear polynomial oracle.
///
/// See the [`MultilinearPolyOracle`] documentation for context.
Expand All @@ -468,6 +523,28 @@ pub struct TransparentPolyOracle<F: Field> {
poly: Arc<dyn MultivariatePoly<F>>,
}

impl<F: TowerField> SerializeCanonical for TransparentPolyOracle<F> {
fn serialize_canonical(
&self,
mut write_buf: impl bytes::BufMut,
) -> Result<(), binius_field::serialization::Error> {
self.poly.erased_serialize_canonical(&mut write_buf)
}
}

impl DeserializeCanonical for TransparentPolyOracle<BinaryField128b> {
fn deserialize_canonical(
mut read_buf: impl bytes::Buf,
) -> Result<Self, binius_field::serialization::Error>
where
Self: Sized,
{
let poly: Box<dyn MultivariatePoly<BinaryField128b>> =
DeserializeCanonical::deserialize_canonical(&mut read_buf)?;
Ok(Self { poly: poly.into() })
}
}

impl<F: TowerField> TransparentPolyOracle<F> {
fn new(poly: Arc<dyn MultivariatePoly<F>>) -> Result<Self, Error> {
if poly.binary_tower_level() > F::TOWER_LEVEL {
Expand All @@ -494,13 +571,15 @@ impl<F: Field> PartialEq for TransparentPolyOracle<F> {

impl<F: Field> Eq for TransparentPolyOracle<F> {}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)]
pub enum ProjectionVariant {
FirstVars,
LastVars,
}

#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)]
#[derive(
Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical,
)]
pub struct Projected<F: TowerField> {
#[get_copy = "pub"]
id: OracleId,
Expand Down Expand Up @@ -530,14 +609,16 @@ impl<F: TowerField> Projected<F> {
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)]
pub enum ShiftVariant {
CircularLeft,
LogicalLeft,
LogicalRight,
}

#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)]
#[derive(
Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical,
)]
pub struct Shifted {
#[get_copy = "pub"]
id: OracleId,
Expand Down Expand Up @@ -579,7 +660,9 @@ impl Shifted {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)]
#[derive(
Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical,
)]
pub struct Packed {
#[get_copy = "pub"]
id: OracleId,
Expand All @@ -593,7 +676,9 @@ pub struct Packed {
log_degree: usize,
}

#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)]
#[derive(
Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical,
)]
pub struct LinearCombination<F: TowerField> {
#[get_copy = "pub"]
n_vars: usize,
Expand All @@ -606,7 +691,7 @@ impl<F: TowerField> LinearCombination<F> {
fn new(
n_vars: usize,
offset: F,
inner: impl IntoIterator<Item = (Arc<MultilinearPolyOracle<F>>, F)>,
inner: impl IntoIterator<Item = (MultilinearPolyOracle<F>, F)>,
) -> Result<Self, Error> {
let inner = inner
.into_iter()
Expand Down
13 changes: 12 additions & 1 deletion crates/core/src/polynomial/multivariate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};

use binius_field::{Field, PackedField};
use binius_field::{serialization, Field, PackedField};
use binius_math::{
ArithExpr, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
};
use binius_utils::bail;
use bytes::BufMut;
use itertools::Itertools;
use rand::{rngs::StdRng, SeedableRng};

Expand All @@ -28,6 +29,16 @@ pub trait MultivariatePoly<P>: Debug + Send + Sync {

/// Returns the maximum binary tower level of all constants in the arithmetic expression.
fn binary_tower_level(&self) -> usize;

/// Serialize a type erased MultivariatePoly.
/// Since not every MultivariatePoly implements serialization, this defaults to returning an error.
fn erased_serialize_canonical(
&self,
write_buf: &mut dyn BufMut,
) -> Result<(), serialization::Error> {
let _ = write_buf;
Err(serialization::Error::SerializationNotImplemented)
}
}

/// Identity composition function $g(X) = X$.
Expand Down
Loading

0 comments on commit b3f3076

Please sign in to comment.