Skip to content

Commit

Permalink
zkevm-framework: enable setup in assigner
Browse files Browse the repository at this point in the history
  • Loading branch information
akokoshn authored and akokoshn committed Sep 20, 2024
1 parent 49eb312 commit dc8911c
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 3 deletions.
110 changes: 107 additions & 3 deletions zkevm-framework/bin/assigner/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <map>
#include <optional>
#include <string>
#include <chrono>

#ifndef BOOST_FILESYSTEM_NO_DEPRECATED
#define BOOST_FILESYSTEM_NO_DEPRECATED
Expand All @@ -27,14 +28,95 @@
#include "checks.hpp"
#include "zkevm_framework/assigner_runner/runner.hpp"
#include "zkevm_framework/preset/preset.hpp"
#include "zkevm_framework/assigner_runner/write_assignments.hpp"
#include "zkevm_framework/assigner_runner/write_circuits.hpp"

template<typename Endianness, typename ArithmetizationType, typename BlueprintFieldType>
std::optional<std::string> write_circuit(nil::evm_assigner::zkevm_circuit idx,
const std::unordered_map<nil::evm_assigner::zkevm_circuit, nil::blueprint::assignment<ArithmetizationType>>& assignments,
const nil::blueprint::circuit<ArithmetizationType> circuit,
const std::string& concrete_circuit_file_name) {
const auto find_it = assignments.find(idx);
if (find_it == assignments.end()) {
return "Can't find assignment table";
}
std::vector<std::size_t> public_input_column_sizes;
const auto public_input_size = find_it->second.public_inputs_amount();
for (std::uint32_t i = 0; i < public_input_size; i++) {
public_input_column_sizes.push_back(find_it->second.public_input_column_size(i));
}
return write_binary_circuit<Endianness, ArithmetizationType, BlueprintFieldType>(circuit, public_input_column_sizes, concrete_circuit_file_name);
}

template<typename BlueprintFieldType, typename ArithmetizationType>
int setup_prover(const std::string& assignment_table_file_name,
const std::string& circuit_file_name,
zkevm_circuits<ArithmetizationType>& circuits) {
auto start = std::chrono::high_resolution_clock::now();

using Endianness = nil::marshalling::option::big_endian;

std::unordered_map<nil::evm_assigner::zkevm_circuit,
nil::blueprint::assignment<ArithmetizationType>>
assignments;

auto init_start = std::chrono::high_resolution_clock::now();
auto err = initialize_circuits<BlueprintFieldType>(circuits, assignments);
if (err) {
std::cerr << "Preset step failed: " << err.value() << std::endl;
return 1;
}
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - init_start);
std::cout << "INITIALIZE: " << duration.count() << " ms\n";

auto write_assignments_start = std::chrono::high_resolution_clock::now();
err = write_binary_assignments<Endianness, ArithmetizationType, BlueprintFieldType>(
assignments, assignment_table_file_name);
if (err) {
std::cerr << "Write assignments failed: " << err.value() << std::endl;
return 1;
}
duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - write_assignments_start);
std::cout << "WRITE ASSIGNMENT TABLES: " << duration.count() << " ms\n";

auto write_circuits_start = std::chrono::high_resolution_clock::now();
const auto& circuit_names = circuits.get_circuit_names();
for (const auto& circuit_name : circuit_names) {
std::string concrete_circuit_file_name = circuit_file_name + "_" + circuit_name;
if (circuit_name == "bytecode") {
err = write_circuit<Endianness, ArithmetizationType, BlueprintFieldType>(nil::evm_assigner::zkevm_circuit::BYTECODE,
assignments,
circuits.m_bytecode_circuit,
concrete_circuit_file_name);
} else if (circuit_name == "sha256") {
err = write_circuit<Endianness, ArithmetizationType, BlueprintFieldType>(nil::evm_assigner::zkevm_circuit::BYTECODE,
assignments,
circuits.m_sha256_circuit,
concrete_circuit_file_name);
}
if (err) {
std::cerr << "Write circuits failed: " << err.value() << std::endl;
return 1;
}
}
duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - write_circuits_start);
std::cout << "WRITE CIRCUITS: " << duration.count() << " ms\n";

duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - start);
std::cout << "SETUP: " << duration.count() << " ms\n";

return 0;
}

template<typename BlueprintFieldType>
int curve_dependent_main(uint64_t shardId, const std::string& blockHash,
const std::string& block_file_name,
const std::string& account_storage_file_name,
const std::string& assignment_table_file_name,
const std::string& circuit_file_name,
const std::optional<OutputArtifacts>& artifacts,
const std::vector<std::string>& target_circuits,
bool is_setup,
boost::log::trivial::severity_level log_level) {
using ArithmetizationType =
nil::crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType>;
Expand All @@ -44,6 +126,11 @@ int curve_dependent_main(uint64_t shardId, const std::string& blockHash,
zkevm_circuits<ArithmetizationType> circuits;
circuits.m_names = target_circuits;


if (is_setup) {
BOOST_LOG_TRIVIAL(debug) << "SetUp prover\n";
return setup_prover<BlueprintFieldType, ArithmetizationType>(assignment_table_file_name, circuit_file_name, circuits);
}
std::unordered_map<nil::evm_assigner::zkevm_circuit,
nil::blueprint::assignment<ArithmetizationType>>
assignments;
Expand Down Expand Up @@ -95,7 +182,9 @@ int main(int argc, char* argv[]) {
// clang-format off
options_desc.add_options()("help,h", "Display help message")
("version,v", "Display version")
("assignment-tables,t", boost::program_options::value<std::string>(), "Assignment table output files")
("setup", "Run prover setup")
("assignment-tables,t", boost::program_options::value<std::string>(), "Assignment tables output files")
("circuits,c", boost::program_options::value<std::string>(), "Circuits output files")
("output-text", boost::program_options::value<std::string>(), "Output assignment table in readable format. "
"Filename or `-` for stdout. "
"Using this enables options --tables, --rows, --columns")
Expand Down Expand Up @@ -148,6 +237,7 @@ int main(int argc, char* argv[]) {

uint64_t shardId = 0;
std::string assignment_table_file_name;
std::string circuit_file_name;
std::string blockHash;
std::string block_file_name;
std::string account_storage_file_name;
Expand All @@ -164,6 +254,15 @@ int main(int argc, char* argv[]) {
return 1;
}

if (vm.count("circuits")) {
circuit_file_name = vm["circuits"].as<std::string>();
} else {
std::cerr << "Invalid command line argument - circuits file name is not specified"
<< std::endl;
std::cout << options_desc << std::endl;
return 1;
}

if (vm.count("block-file")) {
block_file_name = vm["block-file"].as<std::string>();
} else {
Expand Down Expand Up @@ -215,6 +314,11 @@ int main(int argc, char* argv[]) {
target_circuits = vm["target-circuits"].as<std::vector<std::string>>();
}

bool is_setup = false;
if (vm.count("setup")) {
is_setup = true;
}

if (vm.count("log-level")) {
log_level = vm["log-level"].as<std::string>();
} else {
Expand Down Expand Up @@ -251,7 +355,7 @@ int main(int argc, char* argv[]) {
return curve_dependent_main<
typename nil::crypto3::algebra::curves::pallas::base_field_type>(
shardId, blockHash, block_file_name, account_storage_file_name,
assignment_table_file_name, artifacts, target_circuits, log_options[log_level]);
assignment_table_file_name, circuit_file_name, artifacts, target_circuits, is_setup, log_options[log_level]);
break;
}
case 1: {
Expand All @@ -266,7 +370,7 @@ int main(int argc, char* argv[]) {
return curve_dependent_main<
typename nil::crypto3::algebra::fields::bls12_base_field<381>>(
shardId, blockHash, block_file_name, account_storage_file_name,
assignment_table_file_name, artifacts, target_circuits, log_options[log_level]);
assignment_table_file_name, circuit_file_name, artifacts, target_circuits, is_setup, log_options[log_level]);
break;
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* @file write_circuits.hpp
*
* @brief This file defines functions for writing circuits in binary mode.
*/

#ifndef ZKEMV_FRAMEWORK_LIBS_ASSIGNER_RUNNER_INCLUDE_ZKEVM_FRAMEWORK_ASSIGNER_RUNNER_WRITE_CIRCUITS_HPP_
#define ZKEMV_FRAMEWORK_LIBS_ASSIGNER_RUNNER_INCLUDE_ZKEVM_FRAMEWORK_ASSIGNER_RUNNER_WRITE_CIRCUITS_HPP_

#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "nil/blueprint/blueprint/plonk/assignment.hpp"
#include "nil/blueprint/blueprint/plonk/circuit.hpp"
#include "nil/crypto3/marshalling/algebra/types/field_element.hpp"
#include "nil/crypto3/marshalling/zk/types/plonk/constraint_system.hpp"
#include "nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp"
#include "nil/marshalling/types/integral.hpp"
#include "output_artifacts.hpp"

/**
* @brief Write circuit serialized into binary to output file.
*/
template<typename Endianness, typename ArithmetizationType, typename BlueprintFieldType>
std::optional<std::string> write_binary_circuit(const nil::blueprint::circuit<ArithmetizationType>& circuit,
const std::vector<std::size_t> public_input_column_sizes,
const std::string& filename) {
std::ofstream fout(filename, std::ios_base::binary | std::ios_base::out);
if (!fout.is_open()) {
return "Cannot open " + filename;
}
BOOST_LOG_TRIVIAL(debug) << "writing circuit into file "
<< filename;

using TTypeBase = nil::marshalling::field_type<Endianness>;
using ConstraintSystemType = nil::crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType>;
using value_marshalling_type =
nil::crypto3::marshalling::types::plonk_constraint_system<TTypeBase, ConstraintSystemType>;

// fill public input sizes
nil::crypto3::marshalling::types::public_input_sizes_type<TTypeBase> public_input_sizes;
using public_input_size_type = typename nil::crypto3::marshalling::types::public_input_sizes_type<TTypeBase>::element_type;
const auto public_input_size = public_input_column_sizes.size();
for (auto i : public_input_column_sizes) {
public_input_sizes.value().push_back(public_input_size_type(i));
}

auto filled_val =
value_marshalling_type(std::make_tuple(
nil::crypto3::marshalling::types::fill_plonk_gates<Endianness, typename ConstraintSystemType::gates_container_type::value_type>(circuit.gates()),
nil::crypto3::marshalling::types::fill_plonk_copy_constraints<Endianness, typename ConstraintSystemType::field_type>(circuit.copy_constraints()),
nil::crypto3::marshalling::types::fill_plonk_lookup_gates<Endianness, typename ConstraintSystemType::lookup_gates_container_type::value_type>(circuit.lookup_gates()),
nil::crypto3::marshalling::types::fill_plonk_lookup_tables<Endianness, typename ConstraintSystemType::lookup_tables_type::value_type>(circuit.lookup_tables()),
public_input_sizes
));

std::vector<std::uint8_t> cv;
cv.resize(filled_val.length(), 0x00);
auto cv_iter = cv.begin();
nil::marshalling::status_type status = filled_val.write(cv_iter, cv.size());
fout.write(reinterpret_cast<char*>(cv.data()), cv.size());

fout.close();
return {};
}

#endif // ZKEMV_FRAMEWORK_LIBS_ASSIGNER_RUNNER_INCLUDE_ZKEVM_FRAMEWORK_ASSIGNER_RUNNER_WRITE_CIRCUITS_HPP_
2 changes: 2 additions & 0 deletions zkevm-framework/libs/nil_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/in

target_link_libraries(${LIBRARY_NAME} PUBLIC intx::intx sszpp::sszpp)

install(TARGETS ${LIBRARY_NAME}
DESTINATION ${CMAKE_INSTALL_LIBDIR})
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
#include <unordered_map>

#include "zkevm_framework/preset/bytecode.hpp"
#include "zkevm_framework/preset/sha256.hpp"

template<typename ArithmetizationType>
struct zkevm_circuits {
std::vector<std::string> m_default_names = {"bytecode"};
std::vector<std::string> m_names;
nil::blueprint::circuit<ArithmetizationType> m_bytecode_circuit;
nil::blueprint::circuit<ArithmetizationType> m_sha256_circuit;
const std::vector<std::string>& get_circuit_names() {
return m_names.size() > 0 ? m_names : m_default_names;
}
Expand All @@ -38,6 +40,11 @@ std::optional<std::string> initialize_circuits(
if (err) {
return err;
}
} else if (circuit_name == "sha256") {
auto err = initialize_sha256_circuit(circuits.m_sha256_circuit, assignments);
if (err) {
return err;
}
}
}
return {};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#ifndef ZKEMV_FRAMEWORK_LIBS_PRESET_SHA256_HPP_
#define ZKEMV_FRAMEWORK_LIBS_PRESET_SHA256_HPP_

#include <assigner.hpp>
#include <boost/log/core.hpp>
#include <boost/log/expressions.hpp>
#include <boost/log/trivial.hpp>
#include <nil/blueprint/blueprint/plonk/assignment.hpp>
#include <nil/blueprint/components/hashes/sha2/plonk/sha256.hpp>
#include <optional>
#include <string>

template<typename BlueprintFieldType>
std::optional<std::string> initialize_sha256_circuit(
nil::blueprint::circuit<nil::crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType>>&
sha256_circuit,
std::unordered_map<nil::evm_assigner::zkevm_circuit,
nil::blueprint::assignment<
nil::crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType>>>&
assignments) {
// initialize assignment table
nil::crypto3::zk::snark::plonk_table_description<BlueprintFieldType> desc(65, // witness
1, // public
35, // constants
56 // selectors
);
BOOST_LOG_TRIVIAL(debug) << "sha256 table:\n"
<< "witnesses = " << desc.witness_columns
<< " public inputs = " << desc.public_input_columns
<< " constants = " << desc.constant_columns
<< " selectors = " << desc.selector_columns << "\n";
using ArithmetizationType =
nil::crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType>;

auto insert_it = assignments.insert(std::pair<nil::evm_assigner::zkevm_circuit,
nil::blueprint::assignment<ArithmetizationType>>(
nil::evm_assigner::zkevm_circuit::BYTECODE,// index = 0, just for experiment with sha256
nil::blueprint::assignment<ArithmetizationType>(desc)));
auto& sha256_table = insert_it.first->second;

using component_type =
nil::blueprint::components::sha256<nil::crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType>>;

// Prepare witness container to make an instance of the component
typename component_type::manifest_type m = component_type::get_manifest();
size_t witness_amount = *(m.witness_amount->begin());
std::vector<std::uint32_t> witnesses(witness_amount);
std::iota(witnesses.begin(), witnesses.end(), 0); // fill 0, 1, ...

component_type component_instance = component_type(
witnesses, std::array<std::uint32_t, 1>{0}, std::array<std::uint32_t, 1>{0});

auto lookup_tables = component_instance.component_lookup_tables();
for (auto& [k, v] : lookup_tables) {
sha256_circuit.reserve_table(k);
}

constexpr const std::int32_t block_size = 2;
constexpr const std::int32_t input_blocks_amount = 2;

const auto& row_idx = sha256_table.public_input_column_size(0);
std::array<typename component_type::var, input_blocks_amount * block_size> input_block_vars = {
typename component_type::var(0, row_idx, false, component_type::var::column_type::public_input),
typename component_type::var(0, row_idx + 1, false, component_type::var::column_type::public_input),
typename component_type::var(0, row_idx + 2, false, component_type::var::column_type::public_input),
typename component_type::var(0, row_idx + 3, false, component_type::var::column_type::public_input)
};
typename component_type::input_type input = {input_block_vars};

nil::blueprint::components::generate_circuit(component_instance, sha256_circuit,
sha256_table, input, 0);
std::vector<size_t> lookup_columns_indices;
for (std::size_t i = 1; i < sha256_table.constants_amount(); i++) {
lookup_columns_indices.push_back(i);
}

std::size_t cur_selector_id = 0;
for (const auto& gate : sha256_circuit.gates()) {
cur_selector_id = std::max(cur_selector_id, gate.selector_index);
}
for (const auto& lookup_gate : sha256_circuit.lookup_gates()) {
cur_selector_id = std::max(cur_selector_id, lookup_gate.tag_index);
}
cur_selector_id++;
nil::crypto3::zk::snark::pack_lookup_tables_horizontal(
sha256_circuit.get_reserved_indices(), sha256_circuit.get_reserved_tables(),
sha256_circuit.get_reserved_dynamic_tables(), sha256_circuit, sha256_table,
lookup_columns_indices, cur_selector_id, sha256_table.rows_amount(), 500000);
BOOST_LOG_TRIVIAL(debug) << "rows amount = " << sha256_table.rows_amount() << "\n";
return {};
}

#endif // ZKEMV_FRAMEWORK_LIBS_PRESET_SHA256_HPP_

0 comments on commit dc8911c

Please sign in to comment.