diff --git a/Cargo.lock b/Cargo.lock index d5f7f6b..d6c6147 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,7 +153,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "circomspect" -version = "0.5.6" +version = "0.6.0" dependencies = [ "anyhow", "atty", @@ -177,7 +177,7 @@ dependencies = [ [[package]] name = "circomspect-parser" -version = "2.0.2" +version = "2.0.8" dependencies = [ "circomspect-program-structure", "lalrpop", @@ -193,7 +193,7 @@ dependencies = [ [[package]] name = "circomspect-program-analysis" -version = "0.5.2" +version = "0.6.0" dependencies = [ "anyhow", "circomspect-parser", @@ -205,7 +205,7 @@ dependencies = [ [[package]] name = "circomspect-program-structure" -version = "2.0.2" +version = "2.0.8" dependencies = [ "anyhow", "atty", diff --git a/README.md b/README.md index c42caf7..3175f5a 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,8 @@ To output the results to a Sarif file (which can be read by the [VSCode Sarif Vi ![VSCode example image](https://github.com/trailofbits/circomspect/raw/main/doc/vscode.png) +Circomspect supports the same curves that Circom does: BN128, BLS12-381, and Ed448-Goldilocks. If you are using a different curve than the default (BN128) you can set the curve using the command line option `--curve`. + ## Analysis Passes The project currently implements analysis passes for the following types of issues. @@ -120,7 +122,7 @@ the prime. If not, there may be multiple correct representations of the input which could cause issues, since we typically expect the circuit output to be uniquely determined by the input. -For example, Suppose that we create a component `n2b` given by `Num2Bits(254)` and set the input to `1`. Now, both the binary representation of `1` _and_ the representation of `p + 1` will satisfy the circuit, since both are 254-bit numbers. If you cannot restrict the input size below 254 bits you should use the strict versions `Num2Bits_strict` and `Bits2Num_strict` to convert to and from binary representation. Circomspect will generate a warning if it cannot prove (using constant propagation) that the input size passed to `Num2Bits` or `Bits2Num` is less than 254 bits. +For example, Suppose that we create a component `n2b` given by `Num2Bits(254)` and set the input to `1`. Now, both the binary representation of `1` _and_ the representation of `p + 1` will satisfy the circuit over BN128, since both are 254-bit numbers. If you cannot restrict the input size below the prime size you should use the strict versions `Num2Bits_strict` and `Bits2Num_strict` to convert to and from binary representation. Circomspect will generate a warning if it cannot prove (using constant propagation) that the input size passed to `Num2Bits` or `Bits2Num` is less than the size of the prime in bits. #### Overly complex functions or templates (Warning) diff --git a/circom_algebra/Cargo.toml b/circom_algebra/Cargo.toml index ac48506..db85cfb 100644 --- a/circom_algebra/Cargo.toml +++ b/circom_algebra/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "circomspect-circom-algebra" version = "2.0.0" -authors = ["hermeGarcia "] edition = "2018" license = "LGPL-3.0-only" +authors = ["hermeGarcia "] description = "Support crate for the Circomspect static analyzer" repository = "https://github.com/trailofbits/circomspect" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 9aaa6f3..ea04d32 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "circomspect" -version = "0.5.6" +version = "0.6.0" edition = "2021" license = "LGPL-3.0-only" authors = ["Trail of Bits"] @@ -14,9 +14,9 @@ anyhow = "1.0" atty = "0.2.14" clap = { version = "3.2", features = ["derive"] } log = "0.4" -parser = { package = "circomspect-parser", version = "2.0.2", path = "../parser" } +parser = { package = "circomspect-parser", version = "2.0.8", path = "../parser" } pretty_env_logger = "0.4" -program_analysis = { package = "circomspect-program-analysis", version = "0.5.2", path = "../program_analysis" } -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure" } +program_analysis = { package = "circomspect-program-analysis", version = "0.6.0", path = "../program_analysis" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure" } serde_json = "1.0.81" termcolor = "1.1.3" diff --git a/cli/src/main.rs b/cli/src/main.rs index 27deab7..0e41a1c 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -1,6 +1,7 @@ use anyhow::Result; use clap::{CommandFactory, Parser}; use parser::ParseResult; +use program_structure::constants::Curve; use std::io::Write; use std::path::PathBuf; use std::process::ExitCode; @@ -15,8 +16,9 @@ use program_structure::function_data::FunctionInfo; use program_structure::report_writer::{StdoutWriter, ReportWriter, SarifWriter}; use program_structure::template_data::TemplateInfo; -const COMPILER_VERSION: &str = "2.0.3"; +const COMPILER_VERSION: &str = "2.0.8"; const DEFAULT_LEVEL: &str = "WARNING"; +const DEFAULT_CURVE: &str = "BN128"; #[derive(Parser, Debug)] /// A static analyzer and linter for Circom programs. @@ -40,10 +42,18 @@ struct Cli { /// Enable verbose output #[clap(short = 'v', long = "verbose")] verbose: bool, + + /// Set curve (BN128, BLS12_381, or GOLDILOCKS) + #[clap(short = 'c', long = "curve", name = "NAME", default_value = DEFAULT_CURVE)] + curve: Curve, } -fn generate_cfg(ast: Ast, reports: &mut ReportCollection) -> Result { - ast.into_cfg(reports).map_err(Report::from)?.into_ssa().map_err(Report::from) +fn generate_cfg( + ast: Ast, + curve: &Curve, + reports: &mut ReportCollection, +) -> Result { + ast.into_cfg(curve, reports).map_err(Report::from)?.into_ssa().map_err(Report::from) } fn analyze_cfg(cfg: &Cfg, reports: &mut ReportCollection) { @@ -52,8 +62,8 @@ fn analyze_cfg(cfg: &Cfg, reports: &mut ReportCollection) { } } -fn analyze_ast(ast: Ast, reports: &mut ReportCollection) { - match generate_cfg(ast, reports) { +fn analyze_ast(ast: Ast, curve: &Curve, reports: &mut ReportCollection) { + match generate_cfg(ast, curve, reports) { Ok(cfg) => { analyze_cfg(&cfg, reports); } @@ -67,6 +77,7 @@ fn analyze_definitions( functions: &FunctionInfo, templates: &TemplateInfo, file_library: &FileLibrary, + curve: &Curve, writer: &mut StdoutWriter, ) -> ReportCollection { let mut all_reports = ReportCollection::new(); @@ -75,7 +86,7 @@ fn analyze_definitions( for (name, function) in functions { log_message(&format!("analyzing function '{name}'")); let mut new_reports = ReportCollection::new(); - analyze_ast(function, &mut new_reports); + analyze_ast(function, curve, &mut new_reports); writer.write(&new_reports, file_library); all_reports.extend(new_reports); } @@ -83,7 +94,7 @@ fn analyze_definitions( for (name, template) in templates { log_message(&format!("analyzing template '{name}'")); let mut new_reports = ReportCollection::new(); - analyze_ast(template, &mut new_reports); + analyze_ast(template, curve, &mut new_reports); writer.write(&new_reports, file_library); all_reports.extend(new_reports); } @@ -139,6 +150,7 @@ fn main() -> ExitCode { &program.functions, &program.templates, &program.file_library, + &options.curve, &mut writer, )); program.file_library @@ -151,6 +163,7 @@ fn main() -> ExitCode { &library.functions, &library.templates, &library.file_library, + &options.curve, &mut writer, )); library.file_library diff --git a/parser/Cargo.toml b/parser/Cargo.toml index b412826..90b52d3 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "circomspect-parser" -version = "2.0.2" +version = "2.0.8" edition = "2018" build = "build.rs" license = "LGPL-3.0-only" @@ -15,7 +15,7 @@ num-bigint-dig = "0.6.0" num-traits = "0.2.6" [dependencies] -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure" } lalrpop = { version = "0.18.1", features = ["lexer"] } lalrpop-util = "0.18.1" log = "0.4" @@ -27,4 +27,4 @@ serde = "1.0.82" serde_derive = "1.0.91" [dev-dependencies] -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure" } diff --git a/parser/src/lang.lalrpop b/parser/src/lang.lalrpop index 0097867..83b75e6 100644 --- a/parser/src/lang.lalrpop +++ b/parser/src/lang.lalrpop @@ -27,6 +27,10 @@ ParsePragma : Version = { // maybe change to usize instead of BigInt => version, }; +// Pragma to indicate that we are allowing the definition of custom templates. +ParseCustomGates : () = { + "pragma" "custom_templates" ";" => () +} // Includes are added at the start of the file. // Their structure is the following: #include "path to the file" @@ -36,22 +40,14 @@ ParseInclude : Include = { }; // Parsing a program requires: -// Parsing the "pragma" instruction, if there is one -// Parsing "includes" instructions, if there are any, -// Parsing function and template definitions, +// Parsing the version pragma, if there is one +// Parsing the custom templates pragma, if there is one +// Parsing "includes" instructions, if there is anyone +// Parsing function and template definitions // Parsing the declaration of the main component pub ParseAst : AST = { - - => AST::new(Meta::new(s,e), Some(pragma), includes, definitions, Some(main)), - - - => AST::new(Meta::new(s,e), None, includes, definitions, Some(main)), - - - => AST::new(Meta::new(s,e), Some(pragma), includes, definitions, None), - - - => AST::new(Meta::new(s,e), None, includes, definitions, None), + + => AST::new(Meta::new(s,e), version, custom_gates.is_some(), includes, definitions, main), }; // ==================================================================== @@ -81,12 +77,12 @@ pub ParseDefinition : Definition = { Some(a) => build_function(Meta::new(s,e),name,a,args..arge,body), }, - "template" "(" ")" + "template" "(" ")" => match arg_names { None - => build_template(Meta::new(s,e), name, Vec::new(), args..arge, body, parallel.is_some()), + => build_template(Meta::new(s,e), name, Vec::new(), args..arge, body, parallel.is_some(), custom_gate.is_some()), Some(a) - => build_template(Meta::new(s,e), name, a, args..arge, body, parallel.is_some()), + => build_template(Meta::new(s,e), name, a, args..arge, body, parallel.is_some(), custom_gate.is_some()), }, }; @@ -134,7 +130,7 @@ SimpleSymbol : Symbol = { => Symbol { name, is_array: dims, - init: Option::None, + init: None, }, } @@ -143,15 +139,39 @@ ComplexSymbol : Symbol = { => Symbol { name, is_array: dims, - init: Option::Some(rhe), + init: Some(rhe), + }, +}; + +SignalConstraintSymbol : Symbol = { + "<==" + => Symbol { + name, + is_array: dims, + init: Some(rhe), }, }; +SignalSimpleSymbol : Symbol = { + "<--" + => Symbol { + name, + is_array: dims, + init: Some(rhe), + }, +}; + + SomeSymbol : Symbol = { ComplexSymbol, SimpleSymbol, } +SignalSymbol : Symbol = { + SimpleSymbol, + SignalConstraintSymbol, +} + // A declaration is the definition of a type followed by the initialization ParseDeclaration : Statement = { @@ -160,21 +180,28 @@ ParseDeclaration : Statement = { let meta = Meta::new(s,e); let xtype = VariableType::Var; symbols.push(symbol); - ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols) + ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignVar) }, "component" ",")*> => { let mut symbols = symbols; let meta = Meta::new(s,e); let xtype = VariableType::Component; symbols.push(symbol); - ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols) + ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignVar) }, - ",")*> + ",")*> => { let mut symbols = symbols; let meta = Meta::new(s,e); symbols.push(symbol); - ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols) + ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignConstraintSignal) + }, + ",")*> + => { + let mut symbols = symbols; + let meta = Meta::new(s,e); + symbols.push(symbol); + ast_shortcuts::split_declaration_into_single_nodes(meta,xtype,symbols,AssignOp::AssignSignal) }, }; ParseSubstitution : Statement = { @@ -253,18 +280,18 @@ ParseStatement0 : Statement = { ParseStmt0NB : Statement = { "if" "(" ")" - => build_conditional_block(Meta::new(s,e),cond,if_case,Option::None), + => build_conditional_block(Meta::new(s,e),cond,if_case,None), "if" "(" ")" - => build_conditional_block(Meta::new(s,e),cond,if_case,Option::None), + => build_conditional_block(Meta::new(s,e),cond,if_case,None), "if" "(" ")" > - => build_conditional_block(Meta::new(s,e),cond,if_case,Option::Some(else_case)), + => build_conditional_block(Meta::new(s,e),cond,if_case,Some(else_case)), }; ParseStatement1 : Statement = { "if" "(" ")" > - => build_conditional_block(Meta::new(s,e),cond,if_case,Option::Some(else_case)), + => build_conditional_block(Meta::new(s,e),cond,if_case,Some(else_case)), ParseStatement2 }; ParseStatement2 : Statement = { @@ -286,8 +313,7 @@ ParseStatement2 : Statement = { "===" ";" => build_constraint_equality(Meta::new(s,e),lhe,rhe), - "log" "(" ")" ";" - => build_log_call(Meta::new(s,e),arg), + ParseStatementLog, "assert" "(" ")" ";" => build_assert(Meta::new(s,e),arg), @@ -295,6 +321,14 @@ ParseStatement2 : Statement = { ParseBlock }; +ParseStatementLog : Statement = { + "log" "(" ")" ";" + => build_log_call(Meta::new(s,e),args), + + "log" "(" ")" ";" + => build_log_call(Meta::new(s,e),Vec::new()), +}; + ParseStatement3 : Statement = { ";" => dec, @@ -336,6 +370,34 @@ Listable: Vec = { }, }; +ParseString : LogArgument = { + + => { + build_log_string(e) + }, +}; + +ParseLogExp: LogArgument = { + + => { + build_log_expression(e) + } +} + +ParseLogArgument : LogArgument = { + ParseLogExp, + ParseString +}; + +LogListable: Vec = { + ",")*> + => { + let mut e = e; + e.push(tail); + e + }, +}; + InfixOpTier : Expression = { > => build_infix(Meta::new(s,e),lhe,infix_op,rhe), @@ -350,12 +412,25 @@ PrefixOpTier: Expression = { NextTier }; - pub ParseExpression: Expression = { + Expression14, + ParseExpression1, +} + +pub ParseExpression1: Expression = { Expression13, Expression12, }; +// parallel expr +Expression14: Expression = { + "parallel" + => { + build_parallel_op(Meta::new(s, e), expr) + }, + +} + // ops: e ? a : i Expression13 : Expression = { "?" ":" @@ -423,7 +498,7 @@ Expression0: Expression = { => build_number(Meta::new(s,e),value), - "(" ")" + "(" ")" }; diff --git a/parser/src/lib.rs b/parser/src/lib.rs index aff1a84..8370390 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -42,7 +42,7 @@ pub fn parse_files(file_paths: &Vec, compiler_version: &str) -> ParseRe match parse_file(&file_path, &mut file_stack, &mut file_library, &compiler_version) { Ok((file_id, program, mut warnings)) => { if let Some(main_component) = program.main_component { - main_components.push((file_id, main_component)); + main_components.push((file_id, main_component, program.custom_gates)); } definitions.insert(file_id, program.definitions); reports.append(&mut warnings); @@ -53,9 +53,15 @@ pub fn parse_files(file_paths: &Vec, compiler_version: &str) -> ParseRe } } match &main_components[..] { - [(main_id, main_component)] => { + [(main_id, main_component, custom_gates)] => { // TODO: This calls FillMeta::fill a second time. - match ProgramArchive::new(file_library, *main_id, main_component, &definitions) { + match ProgramArchive::new( + file_library, + *main_id, + main_component, + &definitions, + *custom_gates, + ) { Ok(program_archive) => ParseResult::Program(Box::new(program_archive), reports), Err((file_library, mut errors)) => { reports.append(&mut errors); diff --git a/program_analysis/Cargo.toml b/program_analysis/Cargo.toml index 14f5140..20a51a6 100644 --- a/program_analysis/Cargo.toml +++ b/program_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "circomspect-program-analysis" -version = "0.5.2" +version = "0.6.0" edition = "2021" license = "LGPL-3.0-only" authors = ["Trail of Bits"] @@ -12,9 +12,9 @@ anyhow = "1.0" log = "0.4" num-bigint-dig = "0.6.0" num-traits = "0.2.6" -parser = { package = "circomspect-parser", version = "2.0.2", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure" } +parser = { package = "circomspect-parser", version = "2.0.8", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure" } [dev-dependencies] -parser = { package = "circomspect-parser", version = "2.0.2", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure" } +parser = { package = "circomspect-parser", version = "2.0.8", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure" } diff --git a/program_analysis/src/bitwise_complement.rs b/program_analysis/src/bitwise_complement.rs index 6a09543..4eedf10 100644 --- a/program_analysis/src/bitwise_complement.rs +++ b/program_analysis/src/bitwise_complement.rs @@ -51,10 +51,17 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { visit_expression(size, reports); } } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + visit_expression(value, reports); + } + } + } IfThenElse { cond, .. } => visit_expression(cond, reports), Substitution { rhe, .. } => visit_expression(rhe, reports), Return { value, .. } => visit_expression(value, reports), - LogCall { arg, .. } => visit_expression(arg, reports), Assert { arg, .. } => visit_expression(arg, reports), ConstraintEquality { lhe, rhe, .. } => { visit_expression(lhe, reports); @@ -87,7 +94,7 @@ fn visit_expression(expr: &Expression, reports: &mut ReportCollection) { visit_expression(arg, reports); } } - Array { values, .. } => { + InlineArray { values, .. } => { for value in values { visit_expression(value, reports); } @@ -119,7 +126,7 @@ fn build_report(meta: &Meta) -> Report { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -136,8 +143,12 @@ mod tests { fn validate_reports(src: &str, expected_len: usize) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/constant_conditional.rs b/program_analysis/src/constant_conditional.rs index 54cddaf..954d753 100644 --- a/program_analysis/src/constant_conditional.rs +++ b/program_analysis/src/constant_conditional.rs @@ -67,7 +67,7 @@ fn build_report(meta: &Meta, value: bool) -> Report { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -105,8 +105,12 @@ mod tests { fn validate_reports(src: &str, expected_len: usize) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/field_arithmetic.rs b/program_analysis/src/field_arithmetic.rs index 7bc242c..7d9ef78 100644 --- a/program_analysis/src/field_arithmetic.rs +++ b/program_analysis/src/field_arithmetic.rs @@ -52,10 +52,17 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { visit_expression(size, reports); } } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + visit_expression(value, reports); + } + } + } IfThenElse { cond, .. } => visit_expression(cond, reports), Substitution { rhe, .. } => visit_expression(rhe, reports), Return { value, .. } => visit_expression(value, reports), - LogCall { arg, .. } => visit_expression(arg, reports), Assert { arg, .. } => visit_expression(arg, reports), ConstraintEquality { lhe, rhe, .. } => { visit_expression(lhe, reports); @@ -87,7 +94,7 @@ fn visit_expression(expr: &Expression, reports: &mut ReportCollection) { visit_expression(arg, reports); } } - Array { values, .. } => { + InlineArray { values, .. } => { for value in values { visit_expression(value, reports); } @@ -133,7 +140,7 @@ fn build_report(meta: &Meta) -> Report { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -152,8 +159,12 @@ mod tests { fn validate_reports(src: &str, expected_len: usize) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/field_comparisons.rs b/program_analysis/src/field_comparisons.rs index 2bc552e..b7a0a78 100644 --- a/program_analysis/src/field_comparisons.rs +++ b/program_analysis/src/field_comparisons.rs @@ -61,10 +61,17 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { visit_expression(size, reports); } } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + visit_expression(value, reports); + } + } + } IfThenElse { cond, .. } => visit_expression(cond, reports), Substitution { rhe, .. } => visit_expression(rhe, reports), Return { value, .. } => visit_expression(value, reports), - LogCall { arg, .. } => visit_expression(arg, reports), Assert { arg, .. } => visit_expression(arg, reports), ConstraintEquality { lhe, rhe, .. } => { visit_expression(lhe, reports); @@ -96,7 +103,7 @@ fn visit_expression(expr: &Expression, reports: &mut ReportCollection) { visit_expression(arg, reports); } } - Array { values, .. } => { + InlineArray { values, .. } => { for value in values { visit_expression(value, reports); } @@ -133,7 +140,7 @@ fn build_report(meta: &Meta) -> Report { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -158,8 +165,12 @@ mod tests { fn validate_reports(src: &str, expected_len: usize) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/nonstrict_binary_conversion.rs b/program_analysis/src/nonstrict_binary_conversion.rs index 262bee7..d67ca50 100644 --- a/program_analysis/src/nonstrict_binary_conversion.rs +++ b/program_analysis/src/nonstrict_binary_conversion.rs @@ -67,16 +67,17 @@ pub fn find_nonstrict_binary_conversion(cfg: &Cfg) -> ReportCollection { } debug!("running non-strict `Num2Bits` analysis pass"); let mut reports = ReportCollection::new(); + let prime_size = BigInt::from(cfg.constants().prime_size()); for basic_block in cfg.iter() { for stmt in basic_block.iter() { - visit_statement(stmt, &mut reports); + visit_statement(stmt, &prime_size, &mut reports); } } debug!("{} new reports generated", reports.len()); reports } -fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { +fn visit_statement(stmt: &Statement, prime_size: &BigInt, reports: &mut ReportCollection) { use AssignOp::*; use Expression::*; use Statement::*; @@ -96,9 +97,10 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { // We assume this is the `Num2Bits` circuit from Circomlib. if component_name == "Num2Bits" && args.len() == 1 { let arg = &args[0]; - // If the input size is known to be < 254, this initialization is safe. + // If the input size is known to be less than the prime size, this + // initialization is safe. if let Some(FieldElement { value }) = arg.value() { - if value < &BigInt::from(254u8) { + if value < prime_size { return; } } @@ -107,9 +109,10 @@ fn visit_statement(stmt: &Statement, reports: &mut ReportCollection) { // We assume this is the `Bits2Num` circuit from Circomlib. if component_name == "Bits2Num" && args.len() == 1 { let arg = &args[0]; - // If the input size is known to be < 254, this initialization is safe. + // If the input size is known to be less than the prime size, this + // initialization is safe. if let Some(FieldElement { value }) = arg.value() { - if value < &BigInt::from(254u8) { + if value < prime_size { return; } } @@ -137,7 +140,7 @@ fn build_bits2num(meta: &Meta) -> Report { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -156,13 +159,32 @@ mod tests { } "#; validate_reports(src, 1); + + let src = r#" + template F(n) { + signal input in; + signal output out[n]; + + var bits = 254; + component n2b = Num2Bits(bits - 1); + n2b.in === in; + for (var i = 0; i < n; i++) { + out[i] <== n2b.out[i]; + } + } + "#; + validate_reports(src, 0); } fn validate_reports(src: &str, expected_len: usize) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::Bn128, &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/side_effect_analysis.rs b/program_analysis/src/side_effect_analysis.rs index e61e0f0..ea0f239 100644 --- a/program_analysis/src/side_effect_analysis.rs +++ b/program_analysis/src/side_effect_analysis.rs @@ -415,7 +415,7 @@ fn dimensions_to_string(dimensions: &[Expression]) -> String { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -498,8 +498,12 @@ mod tests { fn validate_reports(src: &str, expected_len: usize) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/signal_assignments.rs b/program_analysis/src/signal_assignments.rs index 9458924..78dfe97 100644 --- a/program_analysis/src/signal_assignments.rs +++ b/program_analysis/src/signal_assignments.rs @@ -308,7 +308,7 @@ fn access_to_string(access: &[AccessType]) -> String { #[cfg(test)] mod tests { use parser::parse_definition; - use program_structure::cfg::IntoCfg; + use program_structure::{cfg::IntoCfg, constants::Curve}; use super::*; @@ -361,8 +361,12 @@ mod tests { // Build CFG. println!("{}", src); let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // Generate report collection. diff --git a/program_analysis/src/taint_analysis.rs b/program_analysis/src/taint_analysis.rs index e9db6ee..18d232b 100644 --- a/program_analysis/src/taint_analysis.rs +++ b/program_analysis/src/taint_analysis.rs @@ -171,6 +171,7 @@ mod tests { use parser::parse_definition; use program_structure::cfg::IntoCfg; + use program_structure::constants::Curve; use program_structure::report::ReportCollection; use super::*; @@ -231,8 +232,12 @@ mod tests { fn validate_taint(src: &str, taint_map: &HashMap<&str, HashSet>) { // Build CFG. let mut reports = ReportCollection::new(); - let cfg = - parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); let taint_analysis = run_taint_analysis(&cfg); diff --git a/program_structure/Cargo.toml b/program_structure/Cargo.toml index 919cb5c..166693e 100644 --- a/program_structure/Cargo.toml +++ b/program_structure/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "circomspect-program-structure" -version = "2.0.2" +version = "2.0.8" edition = "2018" license = "LGPL-3.0-only" description = "Support crate for the Circomspect static analyzer" diff --git a/program_structure/src/abstract_syntax_tree/ast.rs b/program_structure/src/abstract_syntax_tree/ast.rs index 185e463..c1acf69 100644 --- a/program_structure/src/abstract_syntax_tree/ast.rs +++ b/program_structure/src/abstract_syntax_tree/ast.rs @@ -89,6 +89,8 @@ impl Meta { pub struct AST { pub meta: Meta, pub compiler_version: Option, + pub custom_gates: bool, + pub custom_gates_declared: bool, pub includes: Vec, pub definitions: Vec, pub main_component: Option, @@ -97,11 +99,23 @@ impl AST { pub fn new( meta: Meta, compiler_version: Option, + custom_gates: bool, includes: Vec, definitions: Vec, main_component: Option, ) -> AST { - AST { meta, compiler_version, includes, definitions, main_component } + let custom_gates_declared = definitions.iter().any(|definition| { + matches!(definition, Definition::Template { is_custom_gate: true, .. }) + }); + AST { + meta, + compiler_version, + custom_gates, + custom_gates_declared, + includes, + definitions, + main_component, + } } } @@ -114,6 +128,7 @@ pub enum Definition { arg_location: FileLocation, body: Statement, parallel: bool, + is_custom_gate: bool, }, Function { meta: Meta, @@ -130,8 +145,9 @@ pub fn build_template( arg_location: FileLocation, body: Statement, parallel: bool, + is_custom_gate: bool, ) -> Definition { - Definition::Template { meta, name, args, arg_location, body, parallel } + Definition::Template { meta, name, args, arg_location, body, parallel, is_custom_gate } } pub fn build_function( @@ -187,7 +203,7 @@ pub enum Statement { }, LogCall { meta: Meta, - arg: Expression, + args: Vec, }, Block { meta: Meta, @@ -239,6 +255,10 @@ pub enum Expression { if_true: Box, if_false: Box, }, + ParallelOp { + meta: Meta, + rhe: Box, + }, Variable { meta: Meta, name: String, @@ -306,8 +326,21 @@ pub enum ExpressionPrefixOpcode { Complement, } -// Knowledge buckets +#[derive(Clone)] +pub enum LogArgument { + LogStr(String), + LogExp(Expression), +} + +pub fn build_log_string(acc: String) -> LogArgument { + LogArgument::LogStr(acc) +} +pub fn build_log_expression(expr: Expression) -> LogArgument { + LogArgument::LogExp(expr) +} + +// Knowledge buckets #[derive(Copy, Clone, PartialOrd, PartialEq, Ord, Eq)] pub enum TypeReduction { Variable, diff --git a/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs b/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs index 6f9cd61..868fed2 100644 --- a/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs +++ b/program_structure/src/abstract_syntax_tree/ast_shortcuts.rs @@ -49,7 +49,9 @@ pub fn split_declaration_into_single_nodes( meta: Meta, xtype: VariableType, symbols: Vec, + op: AssignOp, ) -> Statement { + // use crate::ast_shortcuts::VariableType::Var; let mut initializations = Vec::new(); for symbol in symbols { @@ -58,14 +60,25 @@ pub fn split_declaration_into_single_nodes( let name = symbol.name.clone(); let dimensions = symbol.is_array; let possible_init = symbol.init; - let single_declaration = build_declaration(with_meta, has_type, name, dimensions); + let single_declaration = build_declaration(with_meta, has_type, name, dimensions.clone()); initializations.push(single_declaration); if let Option::Some(init) = possible_init { - let substitution = - build_substitution(meta.clone(), symbol.name, vec![], AssignOp::AssignVar, init); + let substitution = build_substitution(meta.clone(), symbol.name, vec![], op, init); initializations.push(substitution); } + // If the variable is not initialialized it is default initialized to 0. + // We remove this because we don't want this assignment to be flagged as + // an unused assignment by the side-effect analysis. + // else if xtype == Var { + // let mut value = Expression::Number(meta.clone(), BigInt::from(0)); + // for dim_expr in dimensions.iter().rev() { + // value = build_uniform_array(meta.clone(), value, dim_expr.clone()); + // } + + // let substitution = build_substitution(meta.clone(), symbol.name, vec![], op, value); + // initializations.push(substitution); + // } } build_initialization_block(meta, xtype, initializations) } diff --git a/program_structure/src/abstract_syntax_tree/expression_builders.rs b/program_structure/src/abstract_syntax_tree/expression_builders.rs index b0dbeae..cfe98d5 100644 --- a/program_structure/src/abstract_syntax_tree/expression_builders.rs +++ b/program_structure/src/abstract_syntax_tree/expression_builders.rs @@ -29,6 +29,10 @@ pub fn build_inline_switch_op( } } +pub fn build_parallel_op(meta: Meta, rhe: Expression) -> Expression { + ParallelOp { meta, rhe: Box::new(rhe) } +} + pub fn build_variable(meta: Meta, name: String, access: Vec) -> Expression { Variable { meta, name, access } } diff --git a/program_structure/src/abstract_syntax_tree/expression_impl.rs b/program_structure/src/abstract_syntax_tree/expression_impl.rs index 4efdeca..13fdac6 100644 --- a/program_structure/src/abstract_syntax_tree/expression_impl.rs +++ b/program_structure/src/abstract_syntax_tree/expression_impl.rs @@ -9,6 +9,7 @@ impl Expression { | PrefixOp { meta, .. } | InlineSwitchOp { meta, .. } | Variable { meta, .. } + | ParallelOp { meta, .. } | Number(meta, ..) | Call { meta, .. } | ArrayInLine { meta, .. } => meta, @@ -21,6 +22,7 @@ impl Expression { | PrefixOp { meta, .. } | InlineSwitchOp { meta, .. } | Variable { meta, .. } + | ParallelOp { meta, .. } | Number(meta, ..) | Call { meta, .. } | ArrayInLine { meta, .. } => meta, @@ -61,6 +63,11 @@ impl Expression { use Expression::*; matches!(self, Call { .. }) } + + pub fn is_parallel(&self) -> bool { + use Expression::*; + matches!(self, ParallelOp { .. }) + } } impl FillMeta for Expression { @@ -73,6 +80,7 @@ impl FillMeta for Expression { Variable { meta, access, .. } => fill_variable(meta, access, file_id, element_id), InfixOp { meta, lhe, rhe, .. } => fill_infix(meta, lhe, rhe, file_id, element_id), PrefixOp { meta, rhe, .. } => fill_prefix(meta, rhe, file_id, element_id), + ParallelOp { meta, rhe, .. } => fill_parallel(meta, rhe, file_id, element_id), InlineSwitchOp { meta, cond, if_false, if_true, .. } => { fill_inline_switch_op(meta, cond, if_true, if_false, file_id, element_id) } @@ -147,6 +155,11 @@ fn fill_array_inline( } } +fn fill_parallel(meta: &mut Meta, rhe: &mut Expression, file_id: usize, element_id: &mut usize) { + meta.set_file_id(file_id); + rhe.fill(file_id, element_id); +} + impl Debug for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { write!(f, "{}", self) @@ -165,10 +178,11 @@ impl Display for Expression { } Ok(()) } - InfixOp { lhe, infix_op, rhe, .. } => write!(f, "({} {} {})", lhe, infix_op, rhe), - PrefixOp { prefix_op, rhe, .. } => write!(f, "{}({})", prefix_op, rhe), + ParallelOp { rhe, .. } => write!(f, "parallel {rhe}"), + InfixOp { lhe, infix_op, rhe, .. } => write!(f, "({lhe} {infix_op} {rhe})"), + PrefixOp { prefix_op, rhe, .. } => write!(f, "{prefix_op}({rhe})"), InlineSwitchOp { cond, if_true, if_false, .. } => { - write!(f, "({}? {} : {})", cond, if_true, if_false) + write!(f, "({cond}? {if_true} : {if_false})") } Call { id, args, .. } => write!(f, "{}({})", id, vec_to_string(args)), ArrayInLine { values, .. } => write!(f, "[{}]", vec_to_string(values)), diff --git a/program_structure/src/abstract_syntax_tree/statement_builders.rs b/program_structure/src/abstract_syntax_tree/statement_builders.rs index bbe6095..a7ddfc1 100644 --- a/program_structure/src/abstract_syntax_tree/statement_builders.rs +++ b/program_structure/src/abstract_syntax_tree/statement_builders.rs @@ -53,8 +53,31 @@ pub fn build_constraint_equality(meta: Meta, lhe: Expression, rhe: Expression) - ConstraintEquality { meta, lhe, rhe } } -pub fn build_log_call(meta: Meta, arg: Expression) -> Statement { - LogCall { meta, arg } +pub fn build_log_call(meta: Meta, args: Vec) -> Statement { + let mut new_args = Vec::new(); + for arg in args { + match arg { + LogArgument::LogExp(..) => { + new_args.push(arg); + } + LogArgument::LogStr(str) => { + new_args.append(&mut split_string(str)); + } + } + } + LogCall { meta, args: new_args } +} + +fn split_string(str: String) -> Vec { + let mut v = vec![]; + let sub_len = 230; + let mut cur = str; + while !cur.is_empty() { + let (chunk, rest) = cur.split_at(std::cmp::min(sub_len, cur.len())); + v.push(LogArgument::LogStr(chunk.to_string())); + cur = rest.to_string(); + } + v } pub fn build_assert(meta: Meta, arg: Expression) -> Statement { diff --git a/program_structure/src/abstract_syntax_tree/statement_impl.rs b/program_structure/src/abstract_syntax_tree/statement_impl.rs index ba4261d..66853ad 100644 --- a/program_structure/src/abstract_syntax_tree/statement_impl.rs +++ b/program_structure/src/abstract_syntax_tree/statement_impl.rs @@ -99,7 +99,7 @@ impl FillMeta for Statement { ConstraintEquality { meta, lhe, rhe } => { fill_constraint_equality(meta, lhe, rhe, file_id, element_id) } - LogCall { meta, arg, .. } => fill_log_call(meta, arg, file_id, element_id), + LogCall { meta, args, .. } => fill_log_call(meta, args, file_id, element_id), Block { meta, stmts, .. } => fill_block(meta, stmts, file_id, element_id), Assert { meta, arg, .. } => fill_assert(meta, arg, file_id, element_id), } @@ -191,9 +191,18 @@ fn fill_constraint_equality( rhe.fill(file_id, element_id); } -fn fill_log_call(meta: &mut Meta, arg: &mut Expression, file_id: usize, element_id: &mut usize) { +fn fill_log_call( + meta: &mut Meta, + args: &mut Vec, + file_id: usize, + element_id: &mut usize, +) { meta.set_file_id(file_id); - arg.fill(file_id, element_id); + for arg in args { + if let LogArgument::LogExp(e) = arg { + e.fill(file_id, element_id); + } + } } fn fill_block(meta: &mut Meta, stmts: &mut [Statement], file_id: usize, element_id: &mut usize) { @@ -244,7 +253,16 @@ impl Display for Statement { } write!(f, " {op} {rhe}") } - LogCall { arg, .. } => write!(f, "log({arg})"), + LogCall { args, .. } => { + write!(f, "log(")?; + for (index, arg) in args.iter().enumerate() { + if index > 0 { + write!(f, ", ")?; + } + write!(f, "{arg}")?; + } + write!(f, ")") + } // TODO: Remove this when switching to IR. Block { .. } => Ok(()), Assert { arg, .. } => write!(f, "assert({arg})"), @@ -294,3 +312,13 @@ impl Display for SignalType { } } } + +impl Display for LogArgument { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { + use LogArgument::*; + match self { + LogStr(message) => write!(f, "{message}"), + LogExp(value) => write!(f, "{value}"), + } + } +} diff --git a/program_structure/src/control_flow_graph/cfg.rs b/program_structure/src/control_flow_graph/cfg.rs index 5eecb4a..ee95bab 100644 --- a/program_structure/src/control_flow_graph/cfg.rs +++ b/program_structure/src/control_flow_graph/cfg.rs @@ -2,6 +2,7 @@ use log::debug; use std::collections::HashSet; use std::fmt; +use crate::constants::UsefulConstants; use crate::file_definition::FileID; use crate::ir::declarations::{Declaration, Declarations}; use crate::ir::degree_meta::{DegreeEnvironment, Degree}; @@ -37,6 +38,7 @@ impl fmt::Display for DefinitionType { pub struct Cfg { name: String, + constants: UsefulConstants, parameters: Parameters, declarations: Declarations, basic_blocks: Vec, @@ -47,13 +49,22 @@ pub struct Cfg { impl Cfg { pub(crate) fn new( name: String, + constants: UsefulConstants, definition_type: DefinitionType, parameters: Parameters, declarations: Declarations, basic_blocks: Vec, dominator_tree: DominatorTree, ) -> Cfg { - Cfg { name, parameters, declarations, basic_blocks, definition_type, dominator_tree } + Cfg { + name, + constants, + parameters, + declarations, + basic_blocks, + definition_type, + dominator_tree, + } } /// Returns the entry (first) block of the CFG. #[must_use] @@ -136,6 +147,11 @@ impl Cfg { &self.definition_type } + #[must_use] + pub fn constants(&self) -> &UsefulConstants { + &self.constants + } + /// Returns the parameter data for the corresponding function or template. #[must_use] pub fn parameters(&self) -> &Parameters { @@ -415,7 +431,7 @@ impl Cfg { /// Propagate constant values along the CFG. pub(crate) fn propagate_values(&mut self) { debug!("propagating constant values for `{}`", self.name()); - let mut env = ValueEnvironment::new(); + let mut env = ValueEnvironment::new(&self.constants); let mut rerun = true; while rerun { // Rerun value propagation if a single child node was updated. diff --git a/program_structure/src/control_flow_graph/lifting.rs b/program_structure/src/control_flow_graph/lifting.rs index 5de2034..45d2b46 100644 --- a/program_structure/src/control_flow_graph/lifting.rs +++ b/program_structure/src/control_flow_graph/lifting.rs @@ -4,6 +4,7 @@ use std::collections::HashSet; use crate::ast; use crate::ast::Definition; +use crate::constants::{UsefulConstants, Curve}; use crate::function_data::FunctionData; use crate::ir; use crate::ir::declarations::{Declaration, Declarations}; @@ -27,17 +28,20 @@ type Index = usize; type IndexSet = HashSet; type BasicBlockVec = NonEmptyVec; -/// This is a high level trait which simply wraps the implementation provided by `TryLift`. +/// This is a high level trait which simply wraps the implementation provided by +/// `TryLift`. We need to pass the prime to the CFG here, to be able to do value +/// propagation when converting to SSA. pub trait IntoCfg { - fn into_cfg(self, reports: &mut ReportCollection) -> CFGResult; + fn into_cfg(self, curve: &Curve, reports: &mut ReportCollection) -> CFGResult; } impl IntoCfg for T where - T: TryLift<(), IR = Cfg, Error = CFGError>, + T: TryLift, { - fn into_cfg(self, reports: &mut ReportCollection) -> CFGResult { - self.try_lift((), reports) + fn into_cfg(self, curve: &Curve, reports: &mut ReportCollection) -> CFGResult { + let constants = UsefulConstants::new(curve); + self.try_lift(constants, reports) } } @@ -58,45 +62,58 @@ impl From<&Parameters> for LiftingEnvironment { } } -impl TryLift<()> for &TemplateData { +impl TryLift for &TemplateData { type IR = Cfg; type Error = CFGError; - fn try_lift(&self, _: (), reports: &mut ReportCollection) -> CFGResult { + fn try_lift( + &self, + constants: UsefulConstants, + reports: &mut ReportCollection, + ) -> CFGResult { let name = self.get_name().to_string(); let parameters = Parameters::from(*self); let body = self.get_body().clone(); debug!("building CFG for template `{name}`"); - try_lift_impl(name, DefinitionType::Template, parameters, body, reports) + try_lift_impl(name, DefinitionType::Template, constants, parameters, body, reports) } } -impl TryLift<()> for &FunctionData { +impl TryLift for &FunctionData { type IR = Cfg; type Error = CFGError; - fn try_lift(&self, _: (), reports: &mut ReportCollection) -> CFGResult { + fn try_lift( + &self, + constants: UsefulConstants, + reports: &mut ReportCollection, + ) -> CFGResult { let name = self.get_name().to_string(); let parameters = Parameters::from(*self); let body = self.get_body().clone(); debug!("building CFG for function `{name}`"); - try_lift_impl(name, DefinitionType::Function, parameters, body, reports) + try_lift_impl(name, DefinitionType::Function, constants, parameters, body, reports) } } -impl TryLift<()> for Definition { +impl TryLift for Definition { type IR = Cfg; type Error = CFGError; - fn try_lift(&self, _: (), reports: &mut ReportCollection) -> CFGResult { + fn try_lift( + &self, + constants: UsefulConstants, + reports: &mut ReportCollection, + ) -> CFGResult { match self { Definition::Template { name, body, .. } => { debug!("building CFG for template `{name}`"); try_lift_impl( name.clone(), DefinitionType::Template, + constants, self.into(), body.clone(), reports, @@ -107,6 +124,7 @@ impl TryLift<()> for Definition { try_lift_impl( name.clone(), DefinitionType::Function, + constants, self.into(), body.clone(), reports, @@ -119,6 +137,7 @@ impl TryLift<()> for Definition { fn try_lift_impl( name: String, definition_type: DefinitionType, + constants: UsefulConstants, parameters: Parameters, mut body: ast::Statement, reports: &mut ReportCollection, @@ -131,15 +150,22 @@ fn try_lift_impl( let basic_blocks = build_basic_blocks(&body, &mut env, reports)?; let dominator_tree = DominatorTree::new(&basic_blocks); let declarations = Declarations::from(env); - let mut cfg = - Cfg::new(name, definition_type, parameters, declarations, basic_blocks, dominator_tree); + let mut cfg = Cfg::new( + name, + constants, + definition_type, + parameters, + declarations, + basic_blocks, + dominator_tree, + ); // 3. Propagate metadata to all child nodes. Since determining variable use // requires that variable types are available, type propagation must run // before caching variable use. // - // Note that the current implementation of value propagation only makes - // sense in SSA form. + // Note that the current implementations of value and degree propagation + // only make sense in SSA form. cfg.propagate_types(); cfg.cache_variable_use(); diff --git a/program_structure/src/control_flow_graph/ssa_impl.rs b/program_structure/src/control_flow_graph/ssa_impl.rs index 3834653..f895137 100644 --- a/program_structure/src/control_flow_graph/ssa_impl.rs +++ b/program_structure/src/control_flow_graph/ssa_impl.rs @@ -218,9 +218,17 @@ impl SSAStatement for Statement { visit_expression(lhe, env)?; visit_expression(rhe, env) } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + visit_expression(value, env)?; + } + } + Ok(()) + } IfThenElse { cond, .. } => visit_expression(cond, env), Return { value, .. } => visit_expression(value, env), - LogCall { arg, .. } => visit_expression(arg, env), Assert { arg, .. } => visit_expression(arg, env), }; // We need to update the node metadata to have a current view of @@ -343,7 +351,7 @@ fn visit_expression(expr: &mut Expression, env: &mut Environment) -> SSAResult<( } Ok(()) } - Array { values, .. } => { + InlineArray { values, .. } => { for value in values { visit_expression(value, env)?; } diff --git a/program_structure/src/control_flow_graph/unique_vars.rs b/program_structure/src/control_flow_graph/unique_vars.rs index cc97e10..eb594c3 100644 --- a/program_structure/src/control_flow_graph/unique_vars.rs +++ b/program_structure/src/control_flow_graph/unique_vars.rs @@ -4,7 +4,7 @@ use std::convert::{TryFrom, TryInto}; use super::errors::{CFGError, CFGResult}; use super::parameters::Parameters; -use crate::ast::{Access, Expression, Meta, Statement}; +use crate::ast::{Access, Expression, Meta, Statement, LogArgument}; use crate::environment::VarEnvironment; use crate::report::{Report, ReportCollection}; use crate::file_definition::{FileID, FileLocation}; @@ -231,6 +231,14 @@ fn visit_statement( } visit_expression(rhe, env); } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let LogExp(value) = arg { + visit_expression(value, env); + } + } + } Return { value, .. } => { visit_expression(value, env); } @@ -238,9 +246,6 @@ fn visit_statement( visit_expression(lhe, env); visit_expression(rhe, env); } - LogCall { arg, .. } => { - visit_expression(arg, env); - } Assert { arg, .. } => { visit_expression(arg, env); } @@ -312,6 +317,9 @@ fn visit_expression(expr: &mut Expression, env: &DeclarationEnvironment) { visit_expression(value, env); } } + ParallelOp { rhe, .. } => { + visit_expression(rhe, env); + } } } diff --git a/program_structure/src/intermediate_representation/expression_impl.rs b/program_structure/src/intermediate_representation/expression_impl.rs index 61b3876..032450c 100644 --- a/program_structure/src/intermediate_representation/expression_impl.rs +++ b/program_structure/src/intermediate_representation/expression_impl.rs @@ -1,11 +1,10 @@ -use circom_algebra::modular_arithmetic; use log::trace; use num_traits::Zero; use std::collections::HashSet; use std::fmt; use std::hash::{Hash, Hasher}; -use crate::constants::UsefulConstants; +use circom_algebra::modular_arithmetic; use super::declarations::Declarations; use super::degree_meta::{Degree, DegreeEnvironment, DegreeMeta, DegreeRange}; @@ -25,7 +24,7 @@ impl Expression { | Variable { meta, .. } | Number(meta, ..) | Call { meta, .. } - | Array { meta, .. } + | InlineArray { meta, .. } | Update { meta, .. } | Access { meta, .. } | Phi { meta, .. } => meta, @@ -42,7 +41,7 @@ impl Expression { | Variable { meta, .. } | Number(meta, ..) | Call { meta, .. } - | Array { meta, .. } + | InlineArray { meta, .. } | Update { meta, .. } | Access { meta, .. } | Phi { meta, .. } => meta, @@ -75,7 +74,7 @@ impl PartialEq for Expression { Call { name: self_id, args: self_args, .. }, Call { name: other_id, args: other_args, .. }, ) => self_id == other_id && self_args == other_args, - (Array { values: self_values, .. }, Array { values: other_values, .. }) => { + (InlineArray { values: self_values, .. }, InlineArray { values: other_values, .. }) => { self_values == other_values } ( @@ -116,7 +115,7 @@ impl Hash for Expression { Call { args, .. } => { args.hash(state); } - Array { values, .. } => { + InlineArray { values, .. } => { values.hash(state); } Access { var, access, .. } => { @@ -190,7 +189,7 @@ impl DegreeMeta for Expression { meta.degree_knowledge_mut().set_degree(&Constant.into()) } } - Array { meta, values } => { + InlineArray { meta, values } => { // The degree range of an array is the infimum of the ranges of all elements. for value in values.iter_mut() { value.propagate_degrees(env); @@ -269,7 +268,7 @@ impl TypeMeta for Expression { arg.propagate_types(vars); } } - Array { values, .. } => { + InlineArray { values, .. } => { for value in values { value.propagate_types(vars); } @@ -389,7 +388,7 @@ impl VariableMeta for Expression { locals_read .extend(args.iter().map(|name| VariableUse::new(meta, name, &Vec::new()))); } - Array { values, .. } => { + InlineArray { values, .. } => { for value in values { value.cache_variable_use(); locals_read.extend(value.locals_read().clone()); @@ -507,7 +506,7 @@ impl ValueMeta for Expression { match self { InfixOp { meta, lhe, infix_op, rhe, .. } => { let mut result = lhe.propagate_values(env) || rhe.propagate_values(env); - match infix_op.propagate_values(lhe.value(), rhe.value()) { + match infix_op.propagate_values(lhe.value(), rhe.value(), env) { Some(value) => { result = result || meta.value_knowledge_mut().set_reduces_to(value) } @@ -517,7 +516,7 @@ impl ValueMeta for Expression { } PrefixOp { meta, prefix_op, rhe } => { let mut result = rhe.propagate_values(env); - match prefix_op.propagate_values(rhe.value()) { + match prefix_op.propagate_values(rhe.value(), env) { Some(value) => { result = result || meta.value_knowledge_mut().set_reduces_to(value) } @@ -582,7 +581,7 @@ impl ValueMeta for Expression { } result } - Array { values, .. } => { + InlineArray { values, .. } => { // TODO: Handle inline arrays. let mut result = false; for value in values { @@ -682,9 +681,9 @@ impl ExpressionInfixOpcode { &self, lhv: Option<&ValueReduction>, rhv: Option<&ValueReduction>, + env: &ValueEnvironment, ) -> Option { - let constants = UsefulConstants::default(); - let p = constants.get_p(); + let p = env.prime(); use ValueReduction::*; match (lhv, rhv) { @@ -794,9 +793,12 @@ impl ExpressionPrefixOpcode { } } - fn propagate_values(&self, rhe: Option<&ValueReduction>) -> Option { - let constants = UsefulConstants::default(); - let p = constants.get_p(); + fn propagate_values( + &self, + rhe: Option<&ValueReduction>, + env: &ValueEnvironment, + ) -> Option { + let p = env.prime(); use ValueReduction::*; match rhe { @@ -846,7 +848,7 @@ impl fmt::Debug for Expression { write!(f, "({cond:?}? {if_true:?} : {if_false:?})") } Call { name: id, args, .. } => write!(f, "{}({})", id, vec_to_debug(args, ", ")), - Array { values, .. } => write!(f, "[{}]", vec_to_debug(values, ", ")), + InlineArray { values, .. } => write!(f, "[{}]", vec_to_debug(values, ", ")), Access { var, access, .. } => { let access = access .iter() @@ -888,7 +890,7 @@ impl fmt::Display for Expression { write!(f, "({cond}? {if_true} : {if_false})") } Call { name: id, args, .. } => write!(f, "{}({})", id, vec_to_display(args, ", ")), - Array { values, .. } => write!(f, "[{}]", vec_to_display(values, ", ")), + InlineArray { values, .. } => write!(f, "[{}]", vec_to_display(values, ", ")), Access { var, access, .. } => { write!(f, "{var}")?; for access in access { @@ -968,6 +970,8 @@ fn vec_to_display(elems: &[T], sep: &str) -> String { #[cfg(test)] mod tests { + use crate::constants::{UsefulConstants, Curve}; + use super::*; #[test] @@ -977,7 +981,8 @@ mod tests { use ValueReduction::*; let mut lhe = Number(Meta::default(), 7u64.into()); let mut rhe = Variable { meta: Meta::default(), name: VariableName::from_name("v") }; - let mut env = ValueEnvironment::new(); + let constants = UsefulConstants::new(&Curve::default()); + let mut env = ValueEnvironment::new(&constants); env.add_variable(&VariableName::from_name("v"), &FieldElement { value: 3u64.into() }); lhe.propagate_values(&mut env); rhe.propagate_values(&mut env); diff --git a/program_structure/src/intermediate_representation/ir.rs b/program_structure/src/intermediate_representation/ir.rs index 4ba5ea0..c171679 100644 --- a/program_structure/src/intermediate_representation/ir.rs +++ b/program_structure/src/intermediate_representation/ir.rs @@ -155,7 +155,7 @@ pub enum Statement { }, LogCall { meta: Meta, - arg: Expression, + args: Vec, }, Assert { meta: Meta, @@ -189,7 +189,7 @@ pub enum Expression { /// A function call node. Call { meta: Meta, name: String, args: Vec }, /// An inline array on the form `[value, ...]`. - Array { meta: Meta, values: Vec }, + InlineArray { meta: Meta, values: Vec }, /// An `Access` node represents an array access of the form `a[i]...[k]`. Access { meta: Meta, var: VariableName, access: Vec }, /// Updates of the form `var[i]...[k] = rhe` lift to IR statements of the @@ -389,3 +389,9 @@ pub enum ExpressionPrefixOpcode { BoolNot, Complement, } + +#[derive(Clone)] +pub enum LogArgument { + String(String), + Expr(Box), +} diff --git a/program_structure/src/intermediate_representation/lifting.rs b/program_structure/src/intermediate_representation/lifting.rs index f1270c4..96d27af 100644 --- a/program_structure/src/intermediate_representation/lifting.rs +++ b/program_structure/src/intermediate_representation/lifting.rs @@ -1,4 +1,4 @@ -use crate::ast; +use crate::ast::{self, LogArgument}; use crate::report::ReportCollection; use crate::ir; @@ -88,9 +88,12 @@ impl TryLift<()> for ast::Statement { rhe: rhe.try_lift((), reports)?, }) } - ast::Statement::LogCall { meta, arg } => Ok(ir::Statement::LogCall { + ast::Statement::LogCall { meta, args } => Ok(ir::Statement::LogCall { meta: meta.try_lift((), reports)?, - arg: arg.try_lift((), reports)?, + args: args + .iter() + .map(|arg| arg.try_lift((), reports)) + .collect::>>()?, }), ast::Statement::Assert { meta, arg } => Ok(ir::Statement::Assert { meta: meta.try_lift((), reports)?, @@ -172,13 +175,16 @@ impl TryLift<()> for ast::Expression { .map(|arg| arg.try_lift((), reports)) .collect::>>()?, }), - ast::Expression::ArrayInLine { meta, values } => Ok(ir::Expression::Array { + ast::Expression::ArrayInLine { meta, values } => Ok(ir::Expression::InlineArray { meta: meta.try_lift((), reports)?, values: values .iter() .map(|value| value.try_lift((), reports)) .collect::>>()?, }), + // TODO: We currently treat `ParallelOp` as transparent and simply + // lift the underlying expression. Should this be added to the IR? + ast::Expression::ParallelOp { rhe, .. } => rhe.try_lift((), reports), } } } @@ -320,6 +326,20 @@ impl TryLift<()> for ast::ExpressionInfixOpcode { } } +impl TryLift<()> for LogArgument { + type IR = ir::LogArgument; + type Error = IRError; + + fn try_lift(&self, _: (), reports: &mut ReportCollection) -> IRResult { + match self { + ast::LogArgument::LogStr(message) => Ok(ir::LogArgument::String(message.clone())), + ast::LogArgument::LogExp(value) => { + Ok(ir::LogArgument::Expr(Box::new(value.try_lift((), reports)?))) + } + } + } +} + #[cfg(test)] mod tests { use proptest::prelude::*; diff --git a/program_structure/src/intermediate_representation/statement_impl.rs b/program_structure/src/intermediate_representation/statement_impl.rs index a3d3534..1b34afc 100644 --- a/program_structure/src/intermediate_representation/statement_impl.rs +++ b/program_structure/src/intermediate_representation/statement_impl.rs @@ -59,9 +59,16 @@ impl Statement { } } } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + value.propagate_degrees(env); + } + } + } IfThenElse { cond, .. } => cond.propagate_degrees(env), Return { value, .. } => value.propagate_degrees(env), - LogCall { arg, .. } => arg.propagate_degrees(env), Assert { arg, .. } => arg.propagate_degrees(env), ConstraintEquality { lhe, rhe, .. } => { lhe.propagate_degrees(env); @@ -90,9 +97,18 @@ impl Statement { } result } + LogCall { args, .. } => { + let mut result = false; + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + result = result || value.propagate_values(env); + } + } + result + } IfThenElse { cond, .. } => cond.propagate_values(env), Return { value, .. } => value.propagate_values(env), - LogCall { arg, .. } => arg.propagate_values(env), Assert { arg, .. } => arg.propagate_values(env), ConstraintEquality { lhe, rhe, .. } => { lhe.propagate_values(env) || rhe.propagate_values(env) @@ -117,6 +133,14 @@ impl Statement { meta.type_knowledge_mut().set_variable_type(var_type); } } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + value.propagate_types(vars); + } + } + } ConstraintEquality { lhe, rhe, .. } => { lhe.propagate_types(vars); rhe.propagate_types(vars); @@ -127,9 +151,6 @@ impl Statement { Return { value, .. } => { value.propagate_types(vars); } - LogCall { arg, .. } => { - arg.propagate_types(vars); - } Assert { arg, .. } => { arg.propagate_types(vars); } @@ -165,7 +186,7 @@ impl fmt::Debug for Statement { }, Return { value, .. } => write!(f, "return {value:?}"), Assert { arg, .. } => write!(f, "assert({arg:?})"), - LogCall { arg, .. } => write!(f, "log({arg:?})"), + LogCall { args, .. } => write!(f, "log({:?})", vec_to_debug(args, ", ")), } } } @@ -201,7 +222,7 @@ impl fmt::Display for Statement { IfThenElse { cond, .. } => write!(f, "if {cond}"), Return { value, .. } => write!(f, "return {value}"), Assert { arg, .. } => write!(f, "assert({arg})"), - LogCall { arg, .. } => write!(f, "log({arg})"), + LogCall { args, .. } => write!(f, "log({})", vec_to_display(args, ", ")), } } } @@ -217,6 +238,26 @@ impl fmt::Display for AssignOp { } } +impl fmt::Display for LogArgument { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + use LogArgument::*; + match self { + String(message) => write!(f, "{message}"), + Expr(value) => write!(f, "{value}"), + } + } +} + +impl fmt::Debug for LogArgument { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + use LogArgument::*; + match self { + String(message) => write!(f, "{message:?}"), + Expr(value) => write!(f, "{value:?}"), + } + } +} + impl VariableMeta for Statement { fn cache_variable_use(&mut self) { let mut locals_read = VariableUses::new(); @@ -271,6 +312,17 @@ impl VariableMeta for Statement { } } } + LogCall { args, .. } => { + use LogArgument::*; + for arg in args { + if let Expr(value) = arg { + value.cache_variable_use(); + locals_read.extend(value.locals_read().clone()); + signals_read.extend(value.signals_read().clone()); + components_read.extend(value.components_read().clone()); + } + } + } IfThenElse { cond, .. } => { cond.cache_variable_use(); locals_read.extend(cond.locals_read().clone()); @@ -283,12 +335,6 @@ impl VariableMeta for Statement { signals_read.extend(value.signals_read().clone()); components_read.extend(value.components_read().clone()); } - LogCall { arg, .. } => { - arg.cache_variable_use(); - locals_read.extend(arg.locals_read().clone()); - signals_read.extend(arg.signals_read().clone()); - components_read.extend(arg.components_read().clone()); - } Assert { arg, .. } => { arg.cache_variable_use(); locals_read.extend(arg.locals_read().clone()); @@ -340,3 +386,13 @@ impl VariableMeta for Statement { self.meta().variable_knowledge().components_written() } } + +#[must_use] +fn vec_to_debug(elems: &[T], sep: &str) -> String { + elems.iter().map(|elem| format!("{elem:?}")).collect::>().join(sep) +} + +#[must_use] +fn vec_to_display(elems: &[T], sep: &str) -> String { + elems.iter().map(|elem| format!("{elem}")).collect::>().join(sep) +} diff --git a/program_structure/src/intermediate_representation/value_meta.rs b/program_structure/src/intermediate_representation/value_meta.rs index bd6aada..78bf6a0 100644 --- a/program_structure/src/intermediate_representation/value_meta.rs +++ b/program_structure/src/intermediate_representation/value_meta.rs @@ -2,16 +2,19 @@ use num_bigint::BigInt; use std::collections::HashMap; use std::fmt; +use crate::constants::UsefulConstants; + use super::ir::VariableName; -#[derive(Default, Clone)] +#[derive(Clone)] pub struct ValueEnvironment { + constants: UsefulConstants, reduces_to: HashMap, } impl ValueEnvironment { - pub fn new() -> ValueEnvironment { - ValueEnvironment::default() + pub fn new(constants: &UsefulConstants) -> ValueEnvironment { + ValueEnvironment { constants: constants.clone(), reduces_to: HashMap::new() } } /// Set the value of the given variable. Returns `true` on first update. @@ -33,6 +36,11 @@ impl ValueEnvironment { pub fn get_variable(&self, name: &VariableName) -> Option<&ValueReduction> { self.reduces_to.get(name) } + + /// Returns the prime used. + pub fn prime(&self) -> &BigInt { + self.constants.prime() + } } pub trait ValueMeta { diff --git a/program_structure/src/program_library/program_archive.rs b/program_structure/src/program_library/program_archive.rs index 841204e..a330836 100644 --- a/program_structure/src/program_library/program_archive.rs +++ b/program_structure/src/program_library/program_archive.rs @@ -20,6 +20,7 @@ pub struct ProgramArchive { pub template_keys: HashSet, pub public_inputs: Vec, pub initial_template_call: Expression, + pub custom_gates: bool, } impl ProgramArchive { pub fn new( @@ -27,6 +28,7 @@ impl ProgramArchive { file_id_main: FileID, main_component: &MainComponent, program_contents: &Contents, + custom_gates: bool, ) -> Result)> { let mut merger = Merger::new(); let mut reports = vec![]; @@ -57,6 +59,7 @@ impl ProgramArchive { function_keys, template_keys, public_inputs, + custom_gates, }) } else { Err((file_library, reports)) diff --git a/program_structure/src/program_library/program_merger.rs b/program_structure/src/program_library/program_merger.rs index 5126909..53ffdf5 100644 --- a/program_structure/src/program_library/program_merger.rs +++ b/program_structure/src/program_library/program_merger.rs @@ -25,7 +25,15 @@ impl Merger { let mut reports = vec![]; for definition in definitions { let (name, meta) = match definition { - Definition::Template { name, args, arg_location, body, meta, parallel } => { + Definition::Template { + name, + args, + arg_location, + body, + meta, + parallel, + is_custom_gate, + } => { if self.contains_function(name) || self.contains_template(name) { (Option::Some(name), meta) } else { @@ -38,6 +46,7 @@ impl Merger { arg_location.clone(), &mut self.fresh_id, *parallel, + *is_custom_gate, ); self.get_mut_template_info().insert(name.clone(), new_data); (Option::None, meta) diff --git a/program_structure/src/program_library/template_data.rs b/program_structure/src/program_library/template_data.rs index 870548e..6b4dcfb 100644 --- a/program_structure/src/program_library/template_data.rs +++ b/program_structure/src/program_library/template_data.rs @@ -18,6 +18,7 @@ pub struct TemplateData { input_signals: SignalInfo, output_signals: SignalInfo, is_parallel: bool, + is_custom_gate: bool, } impl TemplateData { @@ -31,6 +32,7 @@ impl TemplateData { param_location: FileLocation, elem_id: &mut usize, is_parallel: bool, + is_custom_gate: bool, ) -> TemplateData { body.fill(file_id, elem_id); let mut input_signals = SignalInfo::new(); @@ -46,6 +48,7 @@ impl TemplateData { input_signals, output_signals, is_parallel, + is_custom_gate, } } pub fn get_file_id(&self) -> FileID { @@ -96,6 +99,9 @@ impl TemplateData { pub fn is_parallel(&self) -> bool { self.is_parallel } + pub fn is_custom_gate(&self) -> bool { + self.is_custom_gate + } } fn fill_inputs_and_outputs( diff --git a/program_structure/src/program_library/template_library.rs b/program_structure/src/program_library/template_library.rs index d405819..23b198a 100644 --- a/program_structure/src/program_library/template_library.rs +++ b/program_structure/src/program_library/template_library.rs @@ -36,7 +36,15 @@ impl TemplateLibrary { ), ); } - Definition::Template { name, args, arg_location, body, parallel, .. } => { + Definition::Template { + name, + args, + arg_location, + body, + parallel, + is_custom_gate, + .. + } => { templates.insert( name.clone(), TemplateData::new( @@ -48,6 +56,7 @@ impl TemplateLibrary { arg_location, &mut elem_id, parallel, + is_custom_gate, ), ); } diff --git a/program_structure/src/utils/constants.rs b/program_structure/src/utils/constants.rs index 2a74d2b..0fc9d0a 100644 --- a/program_structure/src/utils/constants.rs +++ b/program_structure/src/utils/constants.rs @@ -1,27 +1,85 @@ +use anyhow::{anyhow, Error}; use num_bigint::BigInt; +use std::fmt; +use std::str::FromStr; -const P_STR: &str = "21888242871839275222246405745257275088548364400416034343698204186575808495617"; +#[derive(Clone)] +pub enum Curve { + Bn128, + Bls12_381, + Goldilocks, +} -pub struct UsefulConstants { - p: BigInt, +// Used for testing. +impl Default for Curve { + fn default() -> Self { + Curve::Bn128 + } } -impl Clone for UsefulConstants { - fn clone(&self) -> Self { - UsefulConstants { p: self.p.clone() } +impl fmt::Display for Curve { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use Curve::*; + match self { + Bn128 => write!(f, "BN128"), + Bls12_381 => write!(f, "BLS12_381"), + Goldilocks => write!(f, "Goldilocks"), + } } } -impl Default for UsefulConstants { - fn default() -> Self { - UsefulConstants { p: BigInt::parse_bytes(P_STR.as_bytes(), 10).expect("can not parse p") } + +impl fmt::Debug for Curve { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) } } +impl Curve { + fn prime(&self) -> BigInt { + use Curve::*; + let prime = match self { + Bn128 => { + "21888242871839275222246405745257275088548364400416034343698204186575808495617" + } + Bls12_381 => { + "52435875175126190479447740508185965837690552500527637822603658699938581184513" + } + Goldilocks => "18446744069414584321", + }; + BigInt::parse_bytes(prime.as_bytes(), 10).expect("failed to parse prime") + } +} + +impl FromStr for Curve { + type Err = Error; + + fn from_str(prime: &str) -> Result { + match &prime.to_uppercase()[..] { + "BN128" => Ok(Curve::Bn128), + "BLS12_381" => Ok(Curve::Bls12_381), + "GOLDILOCKS" => Ok(Curve::Goldilocks), + _ => Err(anyhow!("failed to parse prime `{prime}`")), + } + } +} + +#[derive(Clone)] +pub struct UsefulConstants { + prime: BigInt, +} + impl UsefulConstants { - pub fn new() -> UsefulConstants { - UsefulConstants::default() + pub fn new(curve: &Curve) -> UsefulConstants { + UsefulConstants { prime: curve.prime() } + } + + /// Returns the used prime. + pub fn prime(&self) -> &BigInt { + &self.prime } - pub fn get_p(&self) -> &BigInt { - &self.p + + /// Returns the size in bits of the used prime. + pub fn prime_size(&self) -> usize { + self.prime.bits() } } diff --git a/program_structure/src/utils/memory_slice.rs b/program_structure/src/utils/memory_slice.rs deleted file mode 100644 index ab3a7d9..0000000 --- a/program_structure/src/utils/memory_slice.rs +++ /dev/null @@ -1,250 +0,0 @@ -use num_bigint_dig::BigInt; -use std::fmt::{Display, Formatter}; -pub enum MemoryError { - OutOfBoundsError, - AssignmentError, - InvalidAccess, - UnknownSizeDimension, -} -pub type SliceCapacity = usize; -pub type SimpleSlice = MemorySlice; -/* - Represents the value stored in a element of a circom program. - The attribute route stores the dimensions of the slice, used to navigate through them. - The length of values is equal to multiplying all the values in route. -*/ -#[derive(Eq, PartialEq)] -pub struct MemorySlice { - route: Vec, - values: Vec, -} - -impl Clone for MemorySlice { - fn clone(&self) -> Self { - MemorySlice { route: self.route.clone(), values: self.values.clone() } - } -} - -impl Display for MemorySlice { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - if self.values.is_empty() { - write!(f, "[]") - } else if self.values.len() == 1 { - write!(f, "{}", self.values[0]) - } else { - write!(f, "[{}", self.values[0])?; - for i in 1..self.values.len() { - write!(f, ", {}", self.values[i])?; - } - write!(f, "]") - } - } -} - -impl MemorySlice { - // Raw manipulations of the slice - fn get_initial_cell( - memory_slice: &MemorySlice, - access: &[SliceCapacity], - ) -> Result { - if access.len() > memory_slice.route.len() { - return Result::Err(MemoryError::OutOfBoundsError); - } - - let mut cell = 0; - let mut cell_jump = memory_slice.values.len(); - let mut i: SliceCapacity = 0; - while i < access.len() { - if access[i] >= memory_slice.route[i] { - return Result::Err(MemoryError::OutOfBoundsError); - } - cell_jump /= memory_slice.route[i]; - cell += cell_jump * access[i]; - i += 1; - } - Result::Ok(cell) - } - // Returns the new route and the total number of cells - // that a slice with such route will have - fn generate_new_route_from_access( - memory_slice: &MemorySlice, - access: &[SliceCapacity], - ) -> Result<(Vec, SliceCapacity), MemoryError> { - if access.len() > memory_slice.route.len() { - return Result::Err(MemoryError::OutOfBoundsError); - } - - let mut size = Vec::new(); - let mut number_of_cells = 1; - for i in access.len()..memory_slice.route.len() { - number_of_cells *= memory_slice.route[i]; - size.push(memory_slice.route[i]); - } - Result::Ok((size, number_of_cells)) - } - - fn generate_slice_from_access( - memory_slice: &MemorySlice, - access: &[SliceCapacity], - ) -> Result, MemoryError> { - if access.is_empty() { - return Result::Ok(memory_slice.clone()); - } - - let (size, number_of_cells) = - MemorySlice::generate_new_route_from_access(memory_slice, access)?; - let mut values = Vec::with_capacity(number_of_cells); - let initial_cell = MemorySlice::get_initial_cell(memory_slice, access)?; - let mut offset = 0; - while offset < number_of_cells { - let new_value = memory_slice.values[initial_cell + offset].clone(); - values.push(new_value); - offset += 1; - } - - Result::Ok(MemorySlice { route: size, values }) - } - - // User operations - pub fn new(initial_value: &C) -> MemorySlice { - MemorySlice::new_with_route(&[], initial_value) - } - pub fn new_array(route: Vec, values: Vec) -> MemorySlice { - MemorySlice { route, values } - } - pub fn new_with_route(route: &[SliceCapacity], initial_value: &C) -> MemorySlice { - let mut length = 1; - for i in route { - length *= *i; - } - - let mut values = Vec::with_capacity(length); - for _i in 0..length { - values.push(initial_value.clone()); - } - - MemorySlice { route: route.to_vec(), values } - } - pub fn insert_values( - memory_slice: &mut MemorySlice, - access: &[SliceCapacity], - new_values: &MemorySlice, - ) -> Result<(), MemoryError> { - let mut cell = MemorySlice::get_initial_cell(memory_slice, access)?; - if MemorySlice::get_number_of_cells(new_values) - > (MemorySlice::get_number_of_cells(memory_slice) - cell) - { - return Result::Err(MemoryError::OutOfBoundsError); - } - for value in new_values.values.iter() { - memory_slice.values[cell] = value.clone(); - cell += 1; - } - Result::Ok(()) - } - - pub fn access_values( - memory_slice: &MemorySlice, - access: &[SliceCapacity], - ) -> Result, MemoryError> { - MemorySlice::generate_slice_from_access(memory_slice, access) - } - pub fn get_reference_to_single_value<'a>( - memory_slice: &'a MemorySlice, - access: &[SliceCapacity], - ) -> Result<&'a C, MemoryError> { - assert_eq!(memory_slice.route.len(), access.len()); - let cell = MemorySlice::get_initial_cell(memory_slice, access)?; - Result::Ok(&memory_slice.values[cell]) - } - pub fn get_mut_reference_to_single_value<'a>( - memory_slice: &'a mut MemorySlice, - access: &[SliceCapacity], - ) -> Result<&'a mut C, MemoryError> { - assert_eq!(memory_slice.route.len(), access.len()); - let cell = MemorySlice::get_initial_cell(memory_slice, access)?; - Result::Ok(&mut memory_slice.values[cell]) - } - pub fn get_number_of_cells(memory_slice: &MemorySlice) -> SliceCapacity { - memory_slice.values.len() - } - pub fn route(&self) -> &[SliceCapacity] { - &self.route - } - pub fn is_single(&self) -> bool { - self.route.is_empty() - } - - /* - !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - ! Calling this function with a MemorySlice ! - ! that has more than one cell will cause ! - ! the compiler to panic. Use carefully ! - !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - */ - pub fn unwrap_to_single(memory_slice: MemorySlice) -> C { - assert!(memory_slice.is_single()); - let mut memory_slice = memory_slice; - memory_slice.values.pop().unwrap() - } - pub fn destruct(self) -> (Vec, Vec) { - (self.route, self.values) - } -} - -#[cfg(test)] -mod tests { - use super::*; - type U32Slice = MemorySlice; - - #[test] - fn memory_slice_vector_initialization() { - let route = vec![3, 4]; - let slice = U32Slice::new_with_route(&route, &0); - assert_eq!(U32Slice::get_number_of_cells(&slice), 12); - for (dim_0, dim_1) in slice.route.iter().zip(&route) { - assert_eq!(*dim_0, *dim_1); - } - for f in 0..3 { - for c in 0..4 { - let result = U32Slice::get_reference_to_single_value(&slice, &[f, c]); - if let Result::Ok(v) = result { - assert_eq!(*v, 0); - } else { - assert!(false); - } - } - } - } - #[test] - fn memory_slice_single_initialization() { - let slice = U32Slice::new(&4); - assert_eq!(U32Slice::get_number_of_cells(&slice), 1); - let memory_response = U32Slice::get_reference_to_single_value(&slice, &[]); - if let Result::Ok(val) = memory_response { - assert_eq!(*val, 4); - } else { - assert!(false); - } - } - #[test] - fn memory_slice_multiple_insertion() { - let route = vec![3, 4]; - let mut slice = U32Slice::new_with_route(&route, &0); - let new_row = U32Slice::new_with_route(&[4], &4); - - let res = U32Slice::insert_values(&mut slice, &[2], &new_row); - if let Result::Ok(_) = res { - for c in 0..4 { - let memory_result = U32Slice::get_reference_to_single_value(&slice, &[2, c]); - if let Result::Ok(val) = memory_result { - assert_eq!(*val, 4); - } else { - assert!(false); - } - } - } else { - assert!(false); - } - } -} diff --git a/program_structure/src/utils/mod.rs b/program_structure/src/utils/mod.rs index d33c6ff..2d95c5f 100644 --- a/program_structure/src/utils/mod.rs +++ b/program_structure/src/utils/mod.rs @@ -1,6 +1,5 @@ pub mod constants; pub mod environment; -pub mod memory_slice; pub mod nonempty_vec; pub mod report_writer; pub mod sarif_conversion; diff --git a/program_structure_tests/Cargo.toml b/program_structure_tests/Cargo.toml index abdc0d8..a460a29 100644 --- a/program_structure_tests/Cargo.toml +++ b/program_structure_tests/Cargo.toml @@ -4,9 +4,9 @@ version = "0.5.1" edition = "2021" [dependencies] -parser = { package = "circomspect-parser", version = "2.0.1", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure"} +parser = { package = "circomspect-parser", version = "2.0.8", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure"} [dev-dependencies] -parser = { package = "circomspect-parser", version = "2.0.1", path = "../parser" } -program_structure = { package = "circomspect-program-structure", version = "2.0.2", path = "../program_structure"} +parser = { package = "circomspect-parser", version = "2.0.8", path = "../parser" } +program_structure = { package = "circomspect-program-structure", version = "2.0.8", path = "../program_structure"} diff --git a/program_structure_tests/src/control_flow_graph.rs b/program_structure_tests/src/control_flow_graph.rs index 95d90c4..64b5cad 100644 --- a/program_structure_tests/src/control_flow_graph.rs +++ b/program_structure_tests/src/control_flow_graph.rs @@ -2,6 +2,7 @@ use std::collections::{HashMap, HashSet}; use parser::parse_definition; use program_structure::cfg::*; +use program_structure::constants::Curve; use program_structure::report::ReportCollection; use program_structure::ir::VariableName; @@ -403,7 +404,7 @@ fn validate_cfg( ) { // 1. Generate CFG from source. let mut reports = ReportCollection::new(); - let cfg = parse_definition(src).unwrap().into_cfg(&mut reports).unwrap(); + let cfg = parse_definition(src).unwrap().into_cfg(&Curve::default(), &mut reports).unwrap(); assert!(reports.is_empty()); // 2. Verify declared variables. @@ -459,7 +460,7 @@ fn validate_dominance( ) { // 1. Generate CFG from source. let mut reports = ReportCollection::new(); - let cfg = parse_definition(src).unwrap().into_cfg(&mut reports).unwrap(); + let cfg = parse_definition(src).unwrap().into_cfg(&Curve::default(), &mut reports).unwrap(); assert!(reports.is_empty()); // 2. Validate immediate dominators. @@ -489,7 +490,7 @@ fn validate_branches( ) { // 1. Generate CFG from source. let mut reports = ReportCollection::new(); - let cfg = parse_definition(src).unwrap().into_cfg(&mut reports).unwrap(); + let cfg = parse_definition(src).unwrap().into_cfg(&Curve::default(), &mut reports).unwrap(); assert!(reports.is_empty()); // 2. Validate the set of true branches. diff --git a/program_structure_tests/src/static_single_assignment.rs b/program_structure_tests/src/static_single_assignment.rs index 14ab946..4960eb3 100644 --- a/program_structure_tests/src/static_single_assignment.rs +++ b/program_structure_tests/src/static_single_assignment.rs @@ -2,6 +2,7 @@ use std::collections::HashSet; use parser::parse_definition; use program_structure::cfg::{BasicBlock, Cfg, IntoCfg}; +use program_structure::constants::Curve; use program_structure::report::ReportCollection; use program_structure::ir::variable_meta::VariableMeta; use program_structure::ir::{AssignOp, Statement, VariableName}; @@ -150,7 +151,12 @@ fn test_ssa_with_non_unique_variables() { fn validate_ssa(src: &str, variables: &[&str]) { // 1. Generate CFG and convert to SSA. let mut reports = ReportCollection::new(); - let cfg = parse_definition(src).unwrap().into_cfg(&mut reports).unwrap().into_ssa().unwrap(); + let cfg = parse_definition(src) + .unwrap() + .into_cfg(&Curve::default(), &mut reports) + .unwrap() + .into_ssa() + .unwrap(); assert!(reports.is_empty()); // 2. Check that each variable is assigned at most once.