Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jolt-core): Adding new multiplication instructions #303

Merged
merged 2 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jolt-core/src/jolt/instruction/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ pub mod bgeu;
pub mod bne;
pub mod lb;
pub mod lh;
pub mod mul;
pub mod mulhu;
pub mod mulu;
pub mod or;
pub mod sb;
pub mod sh;
Expand Down
108 changes: 108 additions & 0 deletions jolt-core/src/jolt/instruction/mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use ark_std::log2;
use rand::prelude::StdRng;
use rand::RngCore;
use serde::{Deserialize, Serialize};

use super::{JoltInstruction, SubtableIndices};
use crate::jolt::subtable::{
identity::IdentitySubtable, truncate_overflow::TruncateOverflowSubtable, LassoSubtable,
};
use crate::poly::field::JoltField;
use crate::utils::instruction_utils::{
assert_valid_parameters, concatenate_lookups, multiply_and_chunk_operands,
};

#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct MULInstruction<const WORD_SIZE: usize>(pub u64, pub u64);

impl<const WORD_SIZE: usize> JoltInstruction for MULInstruction<WORD_SIZE> {
fn operands(&self) -> (u64, u64) {
(self.0, self.1)
}

fn combine_lookups<F: JoltField>(&self, vals: &[F], C: usize, M: usize) -> F {
assert!(vals.len() == C);
concatenate_lookups(vals, C, log2(M) as usize)
}

fn g_poly_degree(&self, _: usize) -> usize {
1
}

fn subtables<F: JoltField>(
&self,
C: usize,
M: usize,
) -> Vec<(Box<dyn LassoSubtable<F>>, SubtableIndices)> {
let msb_chunk_index = C - (WORD_SIZE / log2(M) as usize) - 1;
vec![
(
Box::new(TruncateOverflowSubtable::<F, WORD_SIZE>::new()),
SubtableIndices::from(0..msb_chunk_index + 1),
),
(
Box::new(IdentitySubtable::new()),
SubtableIndices::from(msb_chunk_index + 1..C),
),
]
}

fn to_indices(&self, C: usize, log_M: usize) -> Vec<usize> {
assert_valid_parameters(WORD_SIZE, C, log_M);
multiply_and_chunk_operands(self.0 as u128, self.1 as u128, C, log_M)
}

fn lookup_entry(&self) -> u64 {
if WORD_SIZE == 32 {
let x = self.0 as i32;
let y = self.1 as i32;
x.wrapping_mul(y) as u32 as u64
} else if WORD_SIZE == 64 {
let x = self.0 as i64;
let y = self.1 as i64;
x.wrapping_mul(y) as u64
} else {
panic!("only implemented for u32 / u64")
}
}

fn random(&self, rng: &mut StdRng) -> Self {
Self(rng.next_u32() as u64, rng.next_u32() as u64)
}
}

#[cfg(test)]
mod test {
use ark_bn254::Fr;
use ark_std::test_rng;
use rand_chacha::rand_core::RngCore;

use super::MULInstruction;
use crate::{jolt::instruction::JoltInstruction, jolt_instruction_test};

#[test]
fn mul_instruction_32_e2e() {
let mut rng = test_rng();
const C: usize = 4;
const M: usize = 1 << 16;

for _ in 0..256 {
let (x, y) = (rng.next_u32() as u64, rng.next_u32() as u64);
let instruction = MULInstruction::<32>(x, y);
jolt_instruction_test!(instruction);
}
}

#[test]
fn mul_instruction_64_e2e() {
let mut rng = test_rng();
const C: usize = 8;
const M: usize = 1 << 16;

for _ in 0..256 {
let (x, y) = (rng.next_u32() as u64, rng.next_u32() as u64);
let instruction = MULInstruction::<64>(x, y);
jolt_instruction_test!(instruction);
}
}
}
105 changes: 105 additions & 0 deletions jolt-core/src/jolt/instruction/mulhu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use ark_std::log2;
use rand::prelude::StdRng;
use rand::RngCore;
use serde::{Deserialize, Serialize};

use super::{JoltInstruction, SubtableIndices};
use crate::jolt::subtable::{
identity::IdentitySubtable, truncate_overflow::TruncateOverflowSubtable, LassoSubtable,
};
use crate::poly::field::JoltField;
use crate::utils::instruction_utils::{
assert_valid_parameters, concatenate_lookups, multiply_and_chunk_operands,
};

#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct MULHUInstruction<const WORD_SIZE: usize>(pub u64, pub u64);

impl<const WORD_SIZE: usize> JoltInstruction for MULHUInstruction<WORD_SIZE> {
fn operands(&self) -> (u64, u64) {
(self.0, self.1)
}

fn combine_lookups<F: JoltField>(&self, vals: &[F], C: usize, M: usize) -> F {
assert!(vals.len() == C);
concatenate_lookups(vals, C, log2(M) as usize)
}

fn g_poly_degree(&self, _: usize) -> usize {
1
}

fn subtables<F: JoltField>(
&self,
C: usize,
M: usize,
) -> Vec<(Box<dyn LassoSubtable<F>>, SubtableIndices)> {
let msb_chunk_index = C - (WORD_SIZE / log2(M) as usize) - 1;
// Reversed the order of the subtable indices compared to MUL and MULU
vec![
(
Box::new(TruncateOverflowSubtable::<F, WORD_SIZE>::new()),
SubtableIndices::from(msb_chunk_index + 1..C),
),
(
Box::new(IdentitySubtable::new()),
SubtableIndices::from(0..msb_chunk_index + 1),
),
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our current parameters C = 4 and M = 2^16, the TruncateOverflowSubtable here is just all zeros (which is what we want), but I think this would be broken for C and M where C * log2(M) != 2 * WORD_SIZE
For now, let's just add an assert_eq! and remove the TruncateOverflowSubtable

Suggested change
let msb_chunk_index = C - (WORD_SIZE / log2(M) as usize) - 1;
// Reversed the order of the subtable indices compared to MUL and MULU
vec![
(
Box::new(TruncateOverflowSubtable::<F, WORD_SIZE>::new()),
SubtableIndices::from(msb_chunk_index + 1..C),
),
(
Box::new(IdentitySubtable::new()),
SubtableIndices::from(0..msb_chunk_index + 1),
),
]
assert_eq!(C * log2(M), 2 * WORD_SIZE);
vec![
(
Box::new(IdentitySubtable::new()),
SubtableIndices::from(0..C / 2),
),
]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(there are existing instructions that assume C = 4 and M = 2^16, so I think this is fine for now)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, went with your approach. I also had to change the concatenate looks param to take in vals.len() instead of C for it to not panic.

}

fn to_indices(&self, C: usize, log_M: usize) -> Vec<usize> {
assert_valid_parameters(WORD_SIZE, C, log_M);
multiply_and_chunk_operands(self.0 as u128, self.1 as u128, C, log_M)
}

fn lookup_entry(&self) -> u64 {
if WORD_SIZE == 32 {
(self.0).wrapping_mul(self.1) >> 32
} else if WORD_SIZE == 64 {
((self.0 as u128).wrapping_mul(self.1 as u128) >> 64) as u64
} else {
panic!("only implemented for u32 / u64")
}
}

fn random(&self, rng: &mut StdRng) -> Self {
Self(rng.next_u32() as u64, rng.next_u32() as u64)
}
}

#[cfg(test)]
mod test {
use ark_bn254::Fr;
use ark_std::test_rng;
use rand_chacha::rand_core::RngCore;

use super::MULHUInstruction;
use crate::{jolt::instruction::JoltInstruction, jolt_instruction_test};

#[test]
fn mulhu_instruction_32_e2e() {
let mut rng = test_rng();
const C: usize = 4;
const M: usize = 1 << 16;

for _ in 0..256 {
let (x, y) = (rng.next_u32() as u64, rng.next_u32() as u64);
let instruction = MULHUInstruction::<32>(x, y);
jolt_instruction_test!(instruction);
}
}

#[test]
fn mulhu_instruction_64_e2e() {
let mut rng = test_rng();
const C: usize = 8;
const M: usize = 1 << 16;

for _ in 0..256 {
let (x, y) = (rng.next_u64(), rng.next_u64());
let instruction = MULHUInstruction::<64>(x, y);
jolt_instruction_test!(instruction);
}
}
}
104 changes: 104 additions & 0 deletions jolt-core/src/jolt/instruction/mulu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use ark_std::log2;
use rand::prelude::StdRng;
use rand::RngCore;
use serde::{Deserialize, Serialize};

use super::{JoltInstruction, SubtableIndices};
use crate::jolt::subtable::{
identity::IdentitySubtable, truncate_overflow::TruncateOverflowSubtable, LassoSubtable,
};
use crate::poly::field::JoltField;
use crate::utils::instruction_utils::{
assert_valid_parameters, concatenate_lookups, multiply_and_chunk_operands,
};

#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct MULUInstruction<const WORD_SIZE: usize>(pub u64, pub u64);

impl<const WORD_SIZE: usize> JoltInstruction for MULUInstruction<WORD_SIZE> {
fn operands(&self) -> (u64, u64) {
(self.0, self.1)
}

fn combine_lookups<F: JoltField>(&self, vals: &[F], C: usize, M: usize) -> F {
assert!(vals.len() == C);
concatenate_lookups(vals, C, log2(M) as usize)
}

fn g_poly_degree(&self, _: usize) -> usize {
1
}

fn subtables<F: JoltField>(
&self,
C: usize,
M: usize,
) -> Vec<(Box<dyn LassoSubtable<F>>, SubtableIndices)> {
let msb_chunk_index = C - (WORD_SIZE / log2(M) as usize) - 1;
vec![
(
Box::new(TruncateOverflowSubtable::<F, WORD_SIZE>::new()),
SubtableIndices::from(0..msb_chunk_index + 1),
),
(
Box::new(IdentitySubtable::new()),
SubtableIndices::from(msb_chunk_index + 1..C),
),
]
}

fn to_indices(&self, C: usize, log_M: usize) -> Vec<usize> {
assert_valid_parameters(WORD_SIZE, C, log_M);
multiply_and_chunk_operands(self.0 as u128, self.1 as u128, C, log_M)
}

fn lookup_entry(&self) -> u64 {
if WORD_SIZE == 32 {
self.0.wrapping_mul(self.1) as u32 as u64
} else if WORD_SIZE == 64 {
self.0.wrapping_mul(self.1)
} else {
panic!("only implemented for u32 / u64")
}
}

fn random(&self, rng: &mut StdRng) -> Self {
Self(rng.next_u32() as u64, rng.next_u32() as u64)
}
}

#[cfg(test)]
mod test {
use ark_bn254::Fr;
use ark_std::test_rng;
use rand_chacha::rand_core::RngCore;

use super::MULUInstruction;
use crate::{jolt::instruction::JoltInstruction, jolt_instruction_test};

#[test]
fn mulu_instruction_32_e2e() {
let mut rng = test_rng();
const C: usize = 4;
const M: usize = 1 << 16;

for _ in 0..256 {
let (x, y) = (rng.next_u32() as u64, rng.next_u32() as u64);
let instruction = MULUInstruction::<32>(x, y);
jolt_instruction_test!(instruction);
}
}

#[test]
fn mulu_instruction_64_e2e() {
let mut rng = test_rng();
const C: usize = 8;
const M: usize = 1 << 16;

for _ in 0..256 {
let (x, y) = (rng.next_u32() as u64, rng.next_u32() as u64);
let instruction = MULUInstruction::<64>(x, y);
jolt_instruction_test!(instruction);
}
}
}
14 changes: 14 additions & 0 deletions jolt-core/src/utils/instruction_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ pub fn add_and_chunk_operands(x: u128, y: u128, C: usize, log_M: usize) -> Vec<u
.collect()
}

/// Chunks `z` into `C` chunks bitwise where `z = x * y`.
/// `log_M` is the number of bits for each of the `C` chunks of `z`.
pub fn multiply_and_chunk_operands(x: u128, y: u128, C: usize, log_M: usize) -> Vec<usize> {
let product_chunk_bits: usize = log_M;
let product_chunk_bit_mask: usize = (1 << product_chunk_bits) - 1;
let z: u128 = x * y;
(0..C)
.map(|i| {
let shift = ((C - i - 1) * product_chunk_bits) as u32;
z.checked_shr(shift).unwrap_or(0) as usize & product_chunk_bit_mask
})
.collect()
}

/// Splits `x`, `y` into `C` chunks and writes [ x_{C-1} || y_0, ..., x_0 || y_0 ]
/// where `x_{C-1}`` is the the big end of `x``, and `y_0`` is the small end of `y`.
pub fn chunk_and_concatenate_for_shift(x: u64, y: u64, C: usize, log_M: usize) -> Vec<usize> {
Expand Down
Loading