Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Avoid ejecting the scalar value when loading and storing a circuit. #2534

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ jobs:
resource_class: << pipeline.parameters.twoxlarge >>
steps:
- run_serial:
flags: --test '*' -- --skip keccak --skip psd --skip sha --skip instruction::is --skip instruction::equal --skip instruction::commit
flags: --test '*' -- --skip keccak --skip psd --skip sha --skip instruction::is --skip instruction::equal --skip instruction::commit --skip instruction::cast
workspace_member: synthesizer/program
cache_key: snarkvm-synthesizer-program-cache

Expand Down Expand Up @@ -782,6 +782,16 @@ jobs:
workspace_member: synthesizer/program
cache_key: snarkvm-synthesizer-program-cache

synthesizer-program-integration-instruction-cast:
docker:
- image: cimg/rust:1.76.0 # Attention - Change the MSRV in Cargo.toml and rust-toolchain as well
resource_class: << pipeline.parameters.xlarge >>
steps:
- run_serial:
flags: instruction::cast --test '*'
workspace_member: synthesizer/program
cache_key: snarkvm-synthesizer-program-cache

synthesizer-snark:
docker:
- image: cimg/rust:1.76.0 # Attention - Change the MSRV in Cargo.toml and rust-toolchain as well
Expand Down Expand Up @@ -962,6 +972,7 @@ workflows:
- synthesizer-program-integration-instruction-is
- synthesizer-program-integration-instruction-equal
- synthesizer-program-integration-instruction-commit
- synthesizer-program-integration-instruction-cast
- synthesizer-snark
- utilities
- utilities-derives
Expand Down
28 changes: 28 additions & 0 deletions ledger/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2346,6 +2346,34 @@ finalize is_id:
ledger.advance_to_next_block(&block_3).unwrap();
}

#[test]
fn test_deployment_with_cast_from_field_to_scalar() {
// Initialize an RNG.
let rng = &mut TestRng::default();

// Initialize the test environment.
let crate::test_helpers::TestEnv { ledger, private_key, .. } = crate::test_helpers::sample_test_env(rng);

// Construct a program that casts a field to a scalar.
let program = Program::<CurrentNetwork>::from_str(
r"
program test_cast_field_to_scalar.aleo;
function foo:
input r0 as field.public;
cast r0 into r1 as scalar;",
)
.unwrap();

// Deploy the program.
let deployment = ledger.vm().deploy(&private_key, &program, None, 0, None, rng).unwrap();

// Verify the deployment under different RNGs to ensure the deployment is valid.
for _ in 0..20 {
let rng = &mut TestRng::default();
assert!(ledger.vm().check_transaction(&deployment, None, rng).is_ok());
}
}

// These tests require the proof targets to be low enough to be able to generate **valid** solutions.
// This requires the 'test' feature to be enabled for the `console` dependency.
#[cfg(feature = "test")]
Expand Down
18 changes: 17 additions & 1 deletion synthesizer/process/src/stack/registers/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,23 @@ impl<N: Network, A: circuit::Aleo<Network = N>> RegistersLoadCircuit<N, A> for R
match self.register_types.get_type(stack, register) {
// Ensure the stack value matches the register type.
Ok(register_type) => {
stack.matches_register_type(&circuit::Eject::eject_value(&circuit_value), &register_type)?
// Check if the register type and circuit value are both scalar.
let register_type_is_scalar =
matches!(register_type, RegisterType::Plaintext(PlaintextType::Literal(LiteralType::Scalar)));
let circuit_value_is_scalar = matches!(
circuit_value,
circuit::Value::Plaintext(circuit::Plaintext::Literal(circuit::Literal::Scalar(_), _))
);

// Check if the register type matches the type in the circuit value.
// We do a special check for scalar values, as there is a possibility of an overflow via
// field to scalar conversion in deployment verification.
match register_type_is_scalar && circuit_value_is_scalar {
true => {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment above states that We do a special check for scalar values however there is no check here on L174 for the scalar case. Can you clarify what exactly should be checked here?

false => {
stack.matches_register_type(&circuit::Eject::eject_value(&circuit_value), &register_type)?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about other types that might have scalars internally?

e.g. Structs that have scalar elements or arrays of scalars?

}
}
}
// Ensure the register is defined.
Err(error) => bail!("Register '{register}' is not a member of the function: {error}"),
Expand Down
2 changes: 1 addition & 1 deletion synthesizer/process/src/stack/registers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod store;
use crate::{CallStack, RegisterTypes, RegistersCall};
use console::{
network::prelude::*,
program::{Entry, Literal, Plaintext, Register, Value},
program::{Entry, Literal, LiteralType, Plaintext, PlaintextType, Register, RegisterType, Value},
types::{Address, Field},
};
use synthesizer_program::{
Expand Down
19 changes: 18 additions & 1 deletion synthesizer/process/src/stack/registers/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,24 @@ impl<N: Network, A: circuit::Aleo<Network = N>> RegistersStoreCircuit<N, A> for
match self.register_types.get_type(stack, register) {
// Ensure the stack value matches the register type.
Ok(register_type) => {
stack.matches_register_type(&circuit::Eject::eject_value(&circuit_value), &register_type)?
// Check if the register type and circuit value are both scalar.
let register_type_is_scalar = matches!(
register_type,
RegisterType::Plaintext(PlaintextType::Literal(LiteralType::Scalar))
);
let circuit_value_is_scalar = matches!(
circuit_value,
circuit::Value::Plaintext(circuit::Plaintext::Literal(circuit::Literal::Scalar(_), _))
);

// Check if the register type matches the type in the circuit value.
// We do a special check for scalar values, as there is a possibility of an overflow via
// field to scalar conversion in deployment verification.
match register_type_is_scalar && circuit_value_is_scalar {
true => {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment above states that We do a special check for scalar values however there is no check here on L107 for the scalar case. Can you clarify what exactly should be checked here?

false => stack
.matches_register_type(&circuit::Eject::eject_value(&circuit_value), &register_type)?,
}
}
// Ensure the register is defined.
Err(error) => bail!("Register '{register}' is missing a type definition: {error}"),
Expand Down
19 changes: 19 additions & 0 deletions synthesizer/program/src/logic/instruction/operation/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,25 @@ pub struct CastOperation<N: Network, const VARIANT: u8> {
}

impl<N: Network, const VARIANT: u8> CastOperation<N, VARIANT> {
/// Initializes a new `cast` instruction.
#[inline]
pub fn new(operands: Vec<Operand<N>>, destination: Register<N>, cast_type: CastType<N>) -> Result<Self> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this new method only use for testing purposes?

If so, I would advise as an independent reviewer that this method be denoted as CastOperation::new_testing_only with a #[cfg(test)] guard on it.

// Ensure the number of operands is within the bounds.
let max_operands = match cast_type {
CastType::GroupYCoordinate
| CastType::GroupXCoordinate
| CastType::Plaintext(PlaintextType::Literal(_)) => 1,
CastType::Plaintext(PlaintextType::Struct(_)) => N::MAX_STRUCT_ENTRIES,
CastType::Plaintext(PlaintextType::Array(_)) => N::MAX_ARRAY_ELEMENTS,
CastType::Record(_) | CastType::ExternalRecord(_) => N::MAX_RECORD_ENTRIES,
};
if operands.is_empty() || operands.len() > max_operands {
bail!("The number of operands must be nonzero and <= {max_operands}");
}
// Return the instruction.
Ok(Self { operands, destination, cast_type })
}

/// Returns the opcode.
#[inline]
pub const fn opcode() -> Opcode {
Expand Down
221 changes: 221 additions & 0 deletions synthesizer/program/tests/instruction/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright (C) 2019-2023 Aleo Systems Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

include!("../helpers/macros.rs");

use crate::helpers::sample::{sample_finalize_registers, sample_registers};

use circuit::{AleoV0, Eject};
use console::{
network::MainnetV0,
prelude::*,
program::{Identifier, Literal, LiteralType, Plaintext, PlaintextType, Register, Value},
};
use snarkvm_synthesizer_program::{
Cast,
CastLossy,
CastOperation,
CastType,
Opcode,
Operand,
Program,
RegistersLoad,
RegistersLoadCircuit,
};
use synthesizer_process::{Process, Stack};

type CurrentNetwork = MainnetV0;
type CurrentAleo = AleoV0;

const ITERATIONS: usize = 25;

fn valid_cast_types<N: Network>() -> &'static [CastType<N>] {
&[
CastType::Plaintext(PlaintextType::Literal(LiteralType::Address)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::Boolean)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::Field)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::Group)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::I8)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::I16)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::I32)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::I64)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::I128)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::U8)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::U16)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::U32)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::U64)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::U128)),
CastType::Plaintext(PlaintextType::Literal(LiteralType::Scalar)),
]
}

/// Samples the stack. Note: Do not replicate this for real program use, it is insecure.
#[allow(clippy::type_complexity)]
fn sample_stack(
opcode: Opcode,
type_: LiteralType,
mode: circuit::Mode,
cast_type: CastType<CurrentNetwork>,
) -> Result<(Stack<CurrentNetwork>, Vec<Operand<CurrentNetwork>>, Register<CurrentNetwork>)> {
// Initialize the opcode.
let opcode = opcode.to_string();

// Initialize the function name.
let function_name = Identifier::<CurrentNetwork>::from_str("run")?;

// Initialize the registers.
let r0 = Register::Locator(0);
let r1 = Register::Locator(1);

// Initialize the program.
let program = Program::from_str(&format!(
"program testing.aleo;
function {function_name}:
input {r0} as {type_}.{mode};
{opcode} {r0} into {r1} as {cast_type};
async {function_name} {r0} into r2;
output r2 as testing.aleo/{function_name}.future;
finalize {function_name}:
input {r0} as {type_}.public;
{opcode} {r0} into {r1} as {cast_type};
"
))?;

// Initialize the operands.
let operands = vec![Operand::Register(r0)];

// Initialize the stack.
let stack = Stack::new(&Process::load()?, &program)?;

Ok((stack, operands, r1))
}

fn check_cast<const VARIANT: u8>(
operation: impl FnOnce(
Vec<Operand<CurrentNetwork>>,
Register<CurrentNetwork>,
CastType<CurrentNetwork>,
) -> CastOperation<CurrentNetwork, VARIANT>,
opcode: Opcode,
literal: &Literal<CurrentNetwork>,
mode: &circuit::Mode,
cast_type: CastType<CurrentNetwork>,
) {
println!("Checking '{opcode}' for '{literal}.{mode}'");

// Initialize the types.
let type_ = literal.to_type();

// Initialize the stack.
let (stack, operands, destination) = sample_stack(opcode, type_, *mode, cast_type.clone()).unwrap();

// Initialize the operation.
let operation = operation(operands, destination.clone(), cast_type.clone());
// Initialize the function name.
let function_name = Identifier::from_str("run").unwrap();
// Initialize a destination operand.
let destination_operand = Operand::Register(destination);

// Attempt to evaluate the valid operand case.
let mut evaluate_registers = sample_registers(&stack, &function_name, &[(literal, None)]).unwrap();
let result_a = operation.evaluate(&stack, &mut evaluate_registers);

// Attempt to execute the valid operand case.
let mut execute_registers = sample_registers(&stack, &function_name, &[(literal, Some(*mode))]).unwrap();
let result_b = operation.execute::<CurrentAleo>(&stack, &mut execute_registers);
let circuit_is_satisfied = <CurrentAleo as circuit::Environment>::is_satisfied();

// Attempt to finalize the valid operand case.
let mut finalize_registers = sample_finalize_registers(&stack, &function_name, &[literal]).unwrap();
let result_c = operation.finalize(&stack, &mut finalize_registers);

// Check that either all operations failed, or all operations succeeded.
let all_failed = result_a.is_err() && (result_b.is_err() || !circuit_is_satisfied) && result_c.is_err();
let all_succeeded = result_a.is_ok() && (result_b.is_ok() && circuit_is_satisfied) && result_c.is_ok();
assert!(
all_failed || all_succeeded,
"The results of the evaluation, execution, and finalization should either all succeed or all fail"
);

// If all operations succeeded, check that the outputs are consistent.
if all_succeeded {
// Retrieve the output of evaluation.
let output_a = evaluate_registers.load(&stack, &destination_operand).unwrap();

// Retrieve the output of execution.
let output_b = execute_registers.load_circuit(&stack, &destination_operand).unwrap();

// Retrieve the output of finalization.
let output_c = finalize_registers.load(&stack, &destination_operand).unwrap();

// Check that the outputs are consistent.
assert_eq!(output_a, output_b.eject_value(), "The results of the evaluation and execution are inconsistent");
assert_eq!(output_a, output_c, "The results of the evaluation and finalization are inconsistent");

// Check that the output type is consistent with the declared type.
match output_a {
Value::Plaintext(Plaintext::Literal(literal, _)) => {
assert_eq!(
CastType::Plaintext(PlaintextType::Literal(literal.to_type())),
cast_type,
"The output type is inconsistent with the declared type"
);
}
_ => unreachable!("The output type is inconsistent with the declared type"),
}
}

// Reset the circuit.
<CurrentAleo as circuit::Environment>::reset();
}

macro_rules! test_cast_operation {
($name: tt, $cast:ident, $iterations:expr) => {
paste::paste! {
#[test]
fn [<test _ $name _ is _ consistent>]() {
// Initialize the operation.
let operation = |operands, destination, destination_type| $cast::<CurrentNetwork>::new(operands, destination, destination_type).unwrap();
// Initialize the opcode.
let opcode = $cast::<CurrentNetwork>::opcode();

// Prepare the rng.
let mut rng = TestRng::default();

// Prepare the test.
let modes = [circuit::Mode::Public, circuit::Mode::Private];

for _ in 0..$iterations {
let literals = sample_literals!(CurrentNetwork, &mut rng);
for literal in literals.iter() {
for mode in modes.iter() {
for cast_type in valid_cast_types() {
check_cast(
operation,
opcode,
literal,
mode,
cast_type.clone(),
);
}
}
}
}
}
}
};
}

test_cast_operation!(cast, Cast, ITERATIONS);
test_cast_operation!(cast_lossy, CastLossy, ITERATIONS);
1 change: 1 addition & 0 deletions synthesizer/program/tests/instruction/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

mod assert;
mod cast;
mod commit;
mod hash;
mod is;