From 48050c2552d22686f77db30c53cc544d3218d1a3 Mon Sep 17 00:00:00 2001 From: Marcos Pernambuco Motta <1091485+mpernambuco@users.noreply.github.commits> Date: Mon, 13 Nov 2023 16:50:51 -0300 Subject: [PATCH] feat: Replace proof with sibling_hashes --- lib/grpc-interfaces | 2 +- src/Makefile | 7 +- src/access-log.h | 64 ++--- src/cartesi-machine-tests.lua | 1 - src/cartesi/util.lua | 1 - src/clua-jsonrpc-machine.cpp | 6 +- src/clua-machine-util.cpp | 24 +- src/json-util.cpp | 26 ++- src/machine-c-api.cpp | 67 +++--- src/machine-c-api.h | 22 +- src/machine.cpp | 95 ++++---- src/merkle-tree-proof.h | 7 +- src/protobuf-util.cpp | 128 +++++----- src/test-machine-c-api.cpp | 12 +- src/test-utils.h | 4 +- src/tests/machine-bind.lua | 106 ++++++++- src/tests/mcycle-overflow.lua | 2 +- src/tests/util.lua | 10 +- src/uarch-record-reset-state-access.h | 4 +- src/uarch-record-step-state-access.h | 13 +- src/uarch-replay-reset-state-access.h | 84 ++----- src/uarch-replay-step-state-access.h | 220 +++++++----------- src/uarch-riscv-tests.lua | 21 +- .../src/run-rv64i-arch-test.lua | 3 - 24 files changed, 485 insertions(+), 444 deletions(-) diff --git a/lib/grpc-interfaces b/lib/grpc-interfaces index d69a1f48a..a9241a715 160000 --- a/lib/grpc-interfaces +++ b/lib/grpc-interfaces @@ -1 +1 @@ -Subproject commit d69a1f48a9a0cb79a44db828479b4e316acd8cca +Subproject commit a9241a715bb113656a775c972f2497f50bba8f4b diff --git a/src/Makefile b/src/Makefile index 7dd5fa0f4..143c26cc7 100644 --- a/src/Makefile +++ b/src/Makefile @@ -257,7 +257,9 @@ PGO_WORKLOAD=\ whetstone 2500 # We ignore test-machine-c-api.cpp cause it takes too long. -LINTER_IGNORE_SOURCES=test-machine-c-api.cpp +# We ignore uarch-pristine-ram.c because it is generated by xxd. +# We ignore uarch-pristine-state-hash.cpp because it is generated by compute-uarch-pristine-hash. +LINTER_IGNORE_SOURCES=test-machine-c-api.cpp uarch-pristine-ram.c uarch-pristine-state-hash.cpp LINTER_IGNORE_HEADERS=%.pb.h LINTER_SOURCES=$(filter-out $(LINTER_IGNORE_SOURCES),$(strip $(wildcard *.cpp) $(wildcard *.c))) LINTER_HEADERS=$(filter-out $(LINTER_IGNORE_HEADERS),$(strip $(wildcard *.hpp) $(wildcard *.h))) @@ -269,7 +271,9 @@ CLANG_FORMAT=clang-format CLANG_FORMAT_UARCH_FILES:=$(wildcard ../uarch/*.cpp) CLANG_FORMAT_UARCH_FILES:=$(filter-out %uarch-printf%,$(strip $(CLANG_FORMAT_UARCH_FILES))) CLANG_FORMAT_FILES:=$(wildcard *.cpp) $(wildcard *.c) $(wildcard *.h) $(wildcard *.hpp) $(CLANG_FORMAT_UARCH_FILES) +CLANG_FORMAT_IGNORE_FILES:=uarch-pristine-ram.c uarch-pristine-state-hash.cpp CLANG_FORMAT_FILES:=$(filter-out %.pb.h,$(strip $(CLANG_FORMAT_FILES))) +CLANG_FORMAT_FILES:=$(filter-out $(CLANG_FORMAT_IGNORE_FILES),$(strip $(CLANG_FORMAT_FILES))) STYLUA=stylua STYLUA_FLAGS=--indent-type Spaces --collapse-simple-statement Always @@ -755,7 +759,6 @@ jsonrpc-discover.cpp: jsonrpc-discover.json uarch-pristine-state-hash.cpp: compute-uarch-pristine-hash @echo '// This file is auto-generated and should not be modified' > $@ - @echo '// clang-format off' >> $@ @echo '#include "uarch-pristine-state-hash.h"' >> $@ @echo 'namespace cartesi {' >> $@ @echo ' const machine_merkle_tree::hash_type uarch_pristine_state_hash{' >> $@ diff --git a/src/access-log.h b/src/access-log.h index 20d588ed0..f657782c0 100644 --- a/src/access-log.h +++ b/src/access-log.h @@ -57,8 +57,12 @@ static inline uint64_t get_word_access_data(const access_data &ad) { /// NOLINTNEXTLINE(bugprone-exception-escape) class access { - using proof_type = machine_merkle_tree::proof_type; + using hasher_type = machine_merkle_tree::hasher_type; + +public: using hash_type = machine_merkle_tree::hash_type; + using sibling_hashes_type = std::vector; + using proof_type = machine_merkle_tree::proof_type; public: void set_type(access_type type) { @@ -158,38 +162,46 @@ class access { return m_read_hash; } - /// \brief Sets proof that data read at address was in - /// Merkle tree before access. - /// \param proof Corresponding Merkle tree proof. - void set_proof(const proof_type &proof) { - m_proof = proof; - } - void set_proof(proof_type &&proof) { - m_proof = std::move(proof); + /// \brief Constructs a proof using this access' data and a given root hash. + /// \param root_hash Hash to be used as the root of the proof. + /// \return The corresponding proof + proof_type make_proof(const hash_type root_hash) const { + if (!m_sibling_hashes.has_value()) { + throw std::runtime_error("can't make proof if access doesn't have sibling hashes"); + } + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + const auto &sibiling_hashes = m_sibling_hashes.value(); + auto log2_root_size = m_log2_size + sibiling_hashes.size(); + proof_type proof(log2_root_size, m_log2_size); + proof.set_root_hash(root_hash); + proof.set_target_address(m_address); + proof.set_target_hash(m_read_hash); + for (size_t log2_size = m_log2_size; log2_size < log2_root_size; log2_size++) { + proof.set_sibling_hash(sibiling_hashes[log2_size - m_log2_size], log2_size); + } + return proof; } - /// \brief Gets proof that data read at address was in - /// Merkle tree before access. - /// \returns Proof, if one is available. - const std::optional &get_proof(void) const { - return m_proof; + std::optional &get_sibling_hashes() { + return m_sibling_hashes; + } + const std::optional &get_sibling_hashes() const { + return m_sibling_hashes; } - /// \brief Removes proof that data read at address was in - /// Merkle tree before access. - void clear_proof(void) { - m_proof = std::nullopt; + void set_sibling_hashes(const sibling_hashes_type &sibling_hashes) { + m_sibling_hashes = sibling_hashes; } private: - access_type m_type{0}; ///< Type of access - uint64_t m_address{0}; ///< Address of access - int m_log2_size{0}; ///< Log2 of size of access - std::optional m_read{}; ///< Data before access - hash_type m_read_hash; ///< Hash of data before access - std::optional m_written{}; ///< Written data - std::optional m_written_hash{}; ///< Hash of written data - std::optional m_proof{}; ///< Proof of data before access + access_type m_type{0}; ///< Type of access + uint64_t m_address{0}; ///< Address of access + int m_log2_size{0}; ///< Log2 of size of access + std::optional m_read{}; ///< Data before access + hash_type m_read_hash; ///< Hash of data before access + std::optional m_written{}; ///< Written data + std::optional m_written_hash{}; ///< Hash of written data + std::optional m_sibling_hashes{}; ///< Hashes of siblings in path from address to root }; /// \brief Log of state accesses diff --git a/src/cartesi-machine-tests.lua b/src/cartesi-machine-tests.lua index 921f28c48..b55441b34 100755 --- a/src/cartesi-machine-tests.lua +++ b/src/cartesi-machine-tests.lua @@ -673,7 +673,6 @@ local function print_machine(test_name, expected_cycles) --ram-length=32Mi\ --ram-image='%s'\ --no-bootargs\ - --uarch-ram-length=%d\ --uarch-ram-image=%s\ --max-mcycle=%d ", test_path .. "/" .. test_name, diff --git a/src/cartesi/util.lua b/src/cartesi/util.lua index 307d8f2db..1ecebe789 100644 --- a/src/cartesi/util.lua +++ b/src/cartesi/util.lua @@ -232,7 +232,6 @@ function _M.dump_log(log, out) j = j + 1 -- Otherwise, output access elseif ai then - if ai.proof then indentout(out, indent, "hash %s\n", hexhash8(ai.proof.root_hash)) end local read = accessdatastring(ai.read, ai.read_hash, ai.log2_size) if ai.type == "read" then indentout(out, indent, "%d: read %s@0x%x(%u): %s\n", i, notes[i] or "", ai.address, ai.address, read) diff --git a/src/clua-jsonrpc-machine.cpp b/src/clua-jsonrpc-machine.cpp index 6bc60411b..2ae47afec 100644 --- a/src/clua-jsonrpc-machine.cpp +++ b/src/clua-jsonrpc-machine.cpp @@ -97,9 +97,8 @@ static int jsonrpc_machine_class_verify_uarch_step_state_transition(lua_State *L clua_check_cm_hash(L, 3, &target_hash); auto &managed_runtime_config = clua_push_to(L, clua_managed_cm_ptr(clua_opt_cm_machine_runtime_config(L, 4, {}, ctxidx)), ctxidx); - const bool one_based = lua_toboolean(L, 5); TRY_EXECUTE(cm_jsonrpc_verify_uarch_step_state_transition(managed_jsonrpc_mg_mgr.get(), &root_hash, - managed_log.get(), &target_hash, managed_runtime_config.get(), one_based, err_msg)); + managed_log.get(), &target_hash, managed_runtime_config.get(), true, err_msg)); managed_log.reset(); managed_runtime_config.reset(); lua_pop(L, 2); @@ -122,9 +121,8 @@ static int jsonrpc_machine_class_verify_uarch_reset_state_transition(lua_State * clua_check_cm_hash(L, 3, &target_hash); auto &managed_runtime_config = clua_push_to(L, clua_managed_cm_ptr(clua_opt_cm_machine_runtime_config(L, 4, {}, ctxidx)), ctxidx); - const bool one_based = lua_toboolean(L, 5); TRY_EXECUTE(cm_jsonrpc_verify_uarch_reset_state_transition(managed_jsonrpc_mg_mgr.get(), &root_hash, - managed_log.get(), &target_hash, managed_runtime_config.get(), one_based, err_msg)); + managed_log.get(), &target_hash, managed_runtime_config.get(), true, err_msg)); managed_log.reset(); managed_runtime_config.reset(); lua_pop(L, 2); diff --git a/src/clua-machine-util.cpp b/src/clua-machine-util.cpp index 60a428295..46151ce35 100644 --- a/src/clua-machine-util.cpp +++ b/src/clua-machine-util.cpp @@ -352,9 +352,9 @@ static void check_sibling_cm_hashes(lua_State *L, int idx, size_t log2_target_si } sibling_hashes->count = sibling_hashes_count; sibling_hashes->entry = new cm_hash[sibling_hashes_count]{}; - for (; log2_target_size < log2_root_size; ++log2_target_size) { - lua_rawgeti(L, idx, static_cast(log2_root_size - log2_target_size)); - auto index = log2_root_size - 1 - log2_target_size; + for (size_t log2_size = log2_target_size; log2_size < log2_root_size; ++log2_size) { + lua_rawgeti(L, idx, static_cast(log2_size - log2_target_size) + 1); + auto index = log2_size - log2_target_size; clua_check_cm_hash(L, -1, &sibling_hashes->entry[index]); lua_pop(L, 1); } @@ -424,8 +424,7 @@ static unsigned char *opt_cm_access_data_field(lua_State *L, int tabidx, const c /// \param a Pointer to receive access /// \param ctxidx Index (or pseudo-index) of clua context static void check_cm_access(lua_State *L, int tabidx, bool proofs, cm_access *a, int ctxidx) { - ctxidx = lua_absindex(L, ctxidx); - tabidx = lua_absindex(L, tabidx); + (void) ctxidx; luaL_checktype(L, tabidx, LUA_TTABLE); a->type = check_cm_access_type_field(L, tabidx, "type"); a->address = check_uint_field(L, tabidx, "address"); @@ -435,8 +434,9 @@ static void check_cm_access(lua_State *L, int tabidx, bool proofs, cm_access *a, CM_TREE_LOG2_ROOT_SIZE); } if (proofs) { - lua_getfield(L, tabidx, "proof"); - a->proof = clua_check_cm_merkle_tree_proof(L, -1, ctxidx); + lua_getfield(L, tabidx, "sibling_hashes"); + a->sibling_hashes = new cm_hash_array{}; + check_sibling_cm_hashes(L, -1, a->log2_size, CM_TREE_LOG2_ROOT_SIZE, a->sibling_hashes); lua_pop(L, 1); } @@ -651,9 +651,13 @@ void clua_push_cm_access_log(lua_State *L, const cm_access_log *log) { lua_setfield(L, -2, "written"); } } - if (log->log_type.proofs && a->proof != nullptr) { - clua_push_cm_proof(L, a->proof); - lua_setfield(L, -2, "proof"); + if (log->log_type.proofs && a->sibling_hashes != nullptr) { + lua_newtable(L); + for (size_t log2_size = a->log2_size; log2_size < CM_TREE_LOG2_ROOT_SIZE; log2_size++) { + clua_push_cm_hash(L, &a->sibling_hashes->entry[log2_size - a->log2_size]); + lua_rawseti(L, -2, static_cast(log2_size - a->log2_size) + 1); + } + lua_setfield(L, -2, "sibling_hashes"); } lua_rawseti(L, -2, static_cast(i) + 1); } diff --git a/src/json-util.cpp b/src/json-util.cpp index efa38aeaa..81639f857 100644 --- a/src/json-util.cpp +++ b/src/json-util.cpp @@ -653,8 +653,15 @@ void ju_get_opt_field(const nlohmann::json &j, const K &key, access &access, con } not_default_constructible proof; ju_get_opt_field(jk, "proof"s, proof, new_path); - if (proof.has_value()) { - access.set_proof(std::move(proof).value()); + if (contains(jk, "sibling_hashes")) { + access.get_sibling_hashes().emplace(); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto &sibling_hashes = access.get_sibling_hashes().value(); + ju_get_vector_like_field(jk, "sibling_hashes"s, sibling_hashes, new_path); + auto expected_depth = static_cast(machine_merkle_tree::get_log2_root_size() - access.get_log2_size()); + if (sibling_hashes.size() != expected_depth) { + throw std::invalid_argument("field \""s + new_path + "sibling_hashes\" has wrong length"); + } } } @@ -754,8 +761,9 @@ void ju_get_opt_field(const nlohmann::json &j, const K &key, not_default_constru ju_get_vector_like_field(jk, "accesses"s, accesses, new_path); if (log_type.value().has_proofs()) { for (unsigned i = 0; i < accesses.size(); ++i) { - if (!accesses[i].get_proof().has_value()) { - throw std::invalid_argument("field \""s + new_path + "accesses/" + to_string(i) + "\" missing proof"); + if (!accesses[i].get_sibling_hashes().has_value()) { + throw std::invalid_argument( + "field \""s + new_path + "accesses/" + to_string(i) + "\" missing sibling hashes"); } } } @@ -1159,9 +1167,15 @@ void to_json(nlohmann::json &j, const access &a) { j["written"] = encode_base64(a.get_written().value()); } } - if (a.get_proof().has_value()) { + if (a.get_sibling_hashes().has_value()) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - j["proof"] = a.get_proof().value(); + const auto &sibling_hashes = a.get_sibling_hashes().value(); + auto depth = machine_merkle_tree::get_log2_root_size() - a.get_log2_size(); + nlohmann::json s = nlohmann::json::array(); + for (int i = 0; i < depth; i++) { + s.push_back(encode_base64(sibling_hashes[i])); + } + j["sibling_hashes"] = s; } } diff --git a/src/machine-c-api.cpp b/src/machine-c-api.cpp index 7c21424ef..d67922a13 100644 --- a/src/machine-c-api.cpp +++ b/src/machine-c-api.cpp @@ -439,6 +439,25 @@ cartesi::machine_merkle_tree::hash_type convert_from_c(const cm_hash *c_hash) { return cpp_hash; } +std::vector convert_from_c(const cm_hash_array *c_array) { + auto new_array = std::vector(c_array->count); + for (size_t i = 0; i < c_array->count; ++i) { + new_array[i] = convert_from_c(&c_array->entry[i]); + } + return new_array; +} + +static cm_hash_array *convert_to_c(const std::vector &cpp_array) { + auto *new_array = new cm_hash_array{}; + new_array->count = cpp_array.size(); + new_array->entry = new cm_hash[cpp_array.size()]; + memset(new_array->entry, 0, sizeof(cm_hash) * new_array->count); + for (size_t i = 0; i < new_array->count; ++i) { + memcpy(&new_array->entry[i], static_cast(cpp_array[i].data()), sizeof(cm_hash)); + } + return new_array; +} + // ---------------------------------------------- // Semantic version conversion functions // ---------------------------------------------- @@ -458,9 +477,11 @@ cm_semantic_version *convert_to_c(const cartesi::semantic_version &cpp_version) // ---------------------------------------------- /// \brief Converts log2_size to index into siblings array -static int cm_log2_size_to_index(int log2_size, int log2_root_size) { - // We know log2_root_size > 0, so log2_root_size-1 >= 0 - const int index = log2_root_size - 1 - log2_size; +static int cm_log2_size_to_index(int log2_size, int log2_target_size) { + const int index = log2_size - log2_target_size; + if (index < 0) { + throw std::invalid_argument("log2_size can't be smaller than log2_target_size"); + } return index; } @@ -484,8 +505,8 @@ static cm_merkle_tree_proof *convert_to_c(const cartesi::machine_merkle_tree::pr for (size_t log2_size = new_merkle_tree_proof->log2_target_size; log2_size < new_merkle_tree_proof->log2_root_size; ++log2_size) { - const int current_index = - cm_log2_size_to_index(static_cast(log2_size), static_cast(new_merkle_tree_proof->log2_root_size)); + const int current_index = cm_log2_size_to_index(static_cast(log2_size), + static_cast(new_merkle_tree_proof->log2_target_size)); const cartesi::machine_merkle_tree::hash_type sibling_hash = proof.get_sibling_hash(static_cast(log2_size)); memcpy(&(new_merkle_tree_proof->sibling_hashes.entry[current_index]), @@ -495,24 +516,6 @@ static cm_merkle_tree_proof *convert_to_c(const cartesi::machine_merkle_tree::pr return new_merkle_tree_proof; } -static cartesi::machine_merkle_tree::proof_type convert_from_c(const cm_merkle_tree_proof *c_proof) { - cartesi::machine_merkle_tree::proof_type cpp_proof(static_cast(c_proof->log2_root_size), - static_cast(c_proof->log2_target_size)); - cpp_proof.set_target_address(c_proof->target_address); - - cpp_proof.set_root_hash(convert_from_c(&c_proof->root_hash)); - cpp_proof.set_target_hash(convert_from_c(&c_proof->target_hash)); - - for (int log2_size = cpp_proof.get_log2_target_size(); log2_size < cpp_proof.get_log2_root_size(); ++log2_size) { - const int current_index = cm_log2_size_to_index(log2_size, cpp_proof.get_log2_root_size()); - const cartesi::machine_merkle_tree::hash_type cpp_sibling_hash = - convert_from_c(&c_proof->sibling_hashes.entry[current_index]); - cpp_proof.set_sibling_hash(cpp_sibling_hash, log2_size); - } - - return cpp_proof; -} - // ---------------------------------------------- // Access log conversion functions // ---------------------------------------------- @@ -574,24 +577,23 @@ static cm_access convert_to_c(const cartesi::access &cpp_access) { } } - if (cpp_access.get_proof().has_value()) { + if (cpp_access.get_sibling_hashes().has_value()) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - new_access.proof = convert_to_c(*cpp_access.get_proof()); + new_access.sibling_hashes = convert_to_c(*cpp_access.get_sibling_hashes()); } else { - new_access.proof = nullptr; + new_access.sibling_hashes = nullptr; } return new_access; } -static cartesi::access convert_from_c(const cm_access *c_access) { +cartesi::access convert_from_c(const cm_access *c_access) { cartesi::access cpp_access{}; cpp_access.set_type(convert_from_c(c_access->type)); cpp_access.set_log2_size(c_access->log2_size); cpp_access.set_address(c_access->address); - if (c_access->proof != nullptr) { - const cartesi::machine_merkle_tree::proof_type proof = convert_from_c(c_access->proof); - cpp_access.set_proof(proof); + if (c_access->sibling_hashes != nullptr) { + cpp_access.set_sibling_hashes(convert_from_c(c_access->sibling_hashes)); } cpp_access.set_read_hash(convert_from_c(&c_access->read_hash)); @@ -612,7 +614,10 @@ static void cm_cleanup_access(cm_access *access) { if (access == nullptr) { return; } - cm_delete_merkle_tree_proof(access->proof); + if (access->sibling_hashes != nullptr) { + delete[] access->sibling_hashes->entry; + delete access->sibling_hashes; + } delete[] access->written_data; delete[] access->read_data; } diff --git a/src/machine-c-api.h b/src/machine-c-api.h index 92cc270d7..eb55e91b3 100644 --- a/src/machine-c-api.h +++ b/src/machine-c-api.h @@ -313,17 +313,17 @@ typedef struct { // NOLINT(modernize-use-using) } cm_bracket_note; /// \brief Records an access to the machine state -typedef struct { // NOLINT(modernize-use-using) - CM_ACCESS_TYPE type; ///< Type of access - uint64_t address; ///< Address of access - int log2_size; ///< Log2 of size of access - cm_hash read_hash; ///< Hash of data before access - uint8_t *read_data; ///< Data before access - size_t read_data_size; ///< Size of data before access in bytes - cm_hash written_hash; ///< Hash of data after access (if writing) - uint8_t *written_data; ///< Data after access (if writing) - size_t written_data_size; ///< Size of data after access in bytes - cm_merkle_tree_proof *proof; ///< Proof of data before access +typedef struct { // NOLINT(modernize-use-using) + CM_ACCESS_TYPE type; ///< Type of access + uint64_t address; ///< Address of access + int log2_size; ///< Log2 of size of access + cm_hash read_hash; ///< Hash of data before access + uint8_t *read_data; ///< Data before access + size_t read_data_size; ///< Size of data before access in bytes + cm_hash written_hash; ///< Hash of data after access (if writing) + uint8_t *written_data; ///< Data after access (if writing) + size_t written_data_size; ///< Size of data after access in bytes + cm_hash_array *sibling_hashes; ///< Sibling hashes towards root } cm_access; /// \brief Array of accesses diff --git a/src/machine.cpp b/src/machine.cpp index 1445e0266..0165a3115 100644 --- a/src/machine.cpp +++ b/src/machine.cpp @@ -1843,7 +1843,6 @@ void machine::reset_uarch() { access_log machine::log_uarch_reset(const access_log::type &log_type, bool one_based) { hash_type root_hash_before; if (log_type.has_proofs()) { - update_merkle_tree(); get_root_hash(root_hash_before); } // Call uarch_reset_state with a uarch_record_reset_state_access object @@ -1863,20 +1862,13 @@ access_log machine::log_uarch_reset(const access_log::type &log_type, bool one_b return std::move(*a.get_log()); } -void machine::verify_uarch_step_log(const access_log &log, const machine_runtime_config &r, bool one_based) { +void machine::verify_uarch_reset_log(const access_log &log, const machine_runtime_config &r, bool one_based) { (void) r; // There must be at least one access in log if (log.get_accesses().empty()) { throw std::invalid_argument{"too few accesses in log"}; } - uarch_replay_step_state_access a(log, log.get_log_type().has_proofs(), one_based); - uarch_step(a); - a.finish(); -} - -void machine::verify_uarch_reset_log(const access_log &log, const machine_runtime_config &r, bool one_based) { - (void) r; - uarch_replay_reset_state_access a(log, log.get_log_type().has_proofs(), one_based); + uarch_replay_reset_state_access a(log, false /* verify_proofs */, {} /* initial_hash */, one_based); uarch_reset_state(a); a.finish(); } @@ -1892,17 +1884,8 @@ void machine::verify_uarch_reset_state_transition(const hash_type &root_hash_bef if (log.get_accesses().empty()) { throw std::invalid_argument{"too few accesses in log"}; } - // It must contain proofs - if (!log.get_accesses().front().get_proof().has_value()) { - throw std::invalid_argument{"access has no proof"}; - } - // Make sure the access log starts from the same root hash as the state - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - if (log.get_accesses().front().get_proof().value().get_root_hash() != root_hash_before) { - throw std::invalid_argument{"mismatch in root hash before replay"}; - } // Verify all intermediate state transitions - uarch_replay_reset_state_access a(log, log.get_log_type().has_proofs(), one_based); + uarch_replay_reset_state_access a(log, true /* verify_proofs */, root_hash_before, one_based); uarch_reset_state(a); a.finish(); // Make sure the access log ends at the same root hash as the state @@ -1913,6 +1896,42 @@ void machine::verify_uarch_reset_state_transition(const hash_type &root_hash_bef } } +access_log machine::log_uarch_step(const access_log::type &log_type, bool one_based) { + if (m_uarch.get_state().ram.get_istart_E()) { + throw std::runtime_error("microarchitecture RAM is not present"); + } + hash_type root_hash_before; + if (log_type.has_proofs()) { + update_merkle_tree(); + get_root_hash(root_hash_before); + } + // Call interpret with a logged state access object + uarch_record_step_state_access a(m_uarch.get_state(), *this, log_type); + a.push_bracket(bracket_type::begin, "step"); + uarch_step(a); + a.push_bracket(bracket_type::end, "step"); + // Verify access log before returning + if (log_type.has_proofs()) { + hash_type root_hash_after; + get_root_hash(root_hash_after); + verify_uarch_step_state_transition(root_hash_before, *a.get_log(), root_hash_after, m_r, one_based); + } else { + verify_uarch_step_log(*a.get_log(), m_r, one_based); + } + return std::move(*a.get_log()); +} + +void machine::verify_uarch_step_log(const access_log &log, const machine_runtime_config &r, bool one_based) { + (void) r; + // There must be at least one access in log + if (log.get_accesses().empty()) { + throw std::invalid_argument{"too few accesses in log"}; + } + uarch_replay_step_state_access a(log, false /* verify proofs */, {} /* initial hash */, one_based); + uarch_step(a); + a.finish(); +} + void machine::verify_uarch_step_state_transition(const hash_type &root_hash_before, const access_log &log, const hash_type &root_hash_after, const machine_runtime_config &r, bool one_based) { (void) r; @@ -1924,17 +1943,8 @@ void machine::verify_uarch_step_state_transition(const hash_type &root_hash_befo if (log.get_accesses().empty()) { throw std::invalid_argument{"too few accesses in log"}; } - // It must contain proofs - if (!log.get_accesses().front().get_proof().has_value()) { - throw std::invalid_argument{"access has no proof"}; - } - // Make sure the access log starts from the same root hash as the state - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - if (log.get_accesses().front().get_proof().value().get_root_hash() != root_hash_before) { - throw std::invalid_argument{"mismatch in root hash before replay"}; - } // Verify all intermediate state transitions - uarch_replay_step_state_access a(log, true /* verify proofs! */, one_based); + uarch_replay_step_state_access a(log, true /* verify proofs! */, root_hash_before, one_based); uarch_step(a); a.finish(); // Make sure the access log ends at the same root hash as the state @@ -1949,31 +1959,6 @@ machine_config machine::get_default_config(void) { return machine_config{}; } -access_log machine::log_uarch_step(const access_log::type &log_type, bool one_based) { - if (m_uarch.get_state().ram.get_istart_E()) { - throw std::runtime_error("microarchitecture RAM is not present"); - } - hash_type root_hash_before; - if (log_type.has_proofs()) { - update_merkle_tree(); - get_root_hash(root_hash_before); - } - // Call interpret with a logged state access object - uarch_record_step_state_access a(m_uarch.get_state(), *this, log_type); - a.push_bracket(bracket_type::begin, "step"); - uarch_step(a); - a.push_bracket(bracket_type::end, "step"); - // Verify access log before returning - if (log_type.has_proofs()) { - hash_type root_hash_after; - get_root_hash(root_hash_after); - verify_uarch_step_state_transition(root_hash_before, *a.get_log(), root_hash_after, m_r, one_based); - } else { - verify_uarch_step_log(*a.get_log(), m_r, one_based); - } - return std::move(*a.get_log()); -} - // NOLINTNEXTLINE(readability-convert-member-functions-to-static) uarch_interpreter_break_reason machine::run_uarch(uint64_t uarch_cycle_end) { if (m_uarch.get_state().ram.get_istart_E()) { diff --git a/src/merkle-tree-proof.h b/src/merkle-tree-proof.h index bba36749b..dc173432c 100644 --- a/src/merkle-tree-proof.h +++ b/src/merkle-tree-proof.h @@ -143,6 +143,10 @@ class merkle_tree_proof final { m_sibling_hashes[log2_size_to_index(log2_size)] = hash; } + const sibling_hashes_type &get_sibling_hashes() const { + return m_sibling_hashes; + } + /// \brief Checks if two Merkle proofs are equal bool operator==(const merkle_tree_proof &other) const { if (get_log2_target_size() != other.get_log2_target_size()) { @@ -256,8 +260,7 @@ class merkle_tree_proof final { /// \brief Converts log2_size to index into siblings array /// \return Index into siblings array, or throws exception if out of bouds int log2_size_to_index(int log2_size) const { - // We know log2_root_size > 0, so log2_root_size-1 >= 0 - const int index = m_log2_root_size - 1 - log2_size; + const int index = log2_size - m_log2_target_size; if (index < 0 || index >= static_cast(m_sibling_hashes.size())) { throw std::out_of_range{"log2_size is out of range"}; } diff --git a/src/protobuf-util.cpp b/src/protobuf-util.cpp index 52c30867a..e24b77433 100644 --- a/src/protobuf-util.cpp +++ b/src/protobuf-util.cpp @@ -253,44 +253,50 @@ void set_proto_merkle_tree_proof(const machine_merkle_tree::proof_type &p, Carte } } +static void set_proto_access(const access &a, CartesiMachine::Access *proto_a) { + switch (a.get_type()) { + case access_type::read: + proto_a->set_type(CartesiMachine::AccessType::READ); + break; + case access_type::write: + proto_a->set_type(CartesiMachine::AccessType::WRITE); + break; + default: + throw std::invalid_argument{"invalid AccessType"}; + break; + } + proto_a->set_log2_size(a.get_log2_size()); + proto_a->set_address(a.get_address()); + set_proto_hash(a.get_read_hash(), proto_a->mutable_read_hash()); + if (a.get_read().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + const auto &value_read = a.get_read().value(); + proto_a->set_read(value_read.data(), value_read.size()); + } + if (a.get_written_hash().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + set_proto_hash(a.get_written_hash().value(), proto_a->mutable_written_hash()); + } + if (a.get_written().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + const auto &value_written = a.get_written().value(); + proto_a->set_written(value_written.data(), value_written.size()); + } + if (a.get_sibling_hashes().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + const auto &sibling_hashes = a.get_sibling_hashes().value(); + for (const auto &s : sibling_hashes) { + set_proto_hash(s, proto_a->add_sibling_hashes()); + } + } +} + void set_proto_access_log(const access_log &al, CartesiMachine::AccessLog *proto_al) { proto_al->mutable_log_type()->set_annotations(al.get_log_type().has_annotations()); proto_al->mutable_log_type()->set_proofs(al.get_log_type().has_proofs()); proto_al->mutable_log_type()->set_large_data(al.get_log_type().has_large_data()); for (const auto &a : al.get_accesses()) { - auto *proto_a = proto_al->add_accesses(); - switch (a.get_type()) { - case access_type::read: - proto_a->set_type(CartesiMachine::AccessType::READ); - break; - case access_type::write: - proto_a->set_type(CartesiMachine::AccessType::WRITE); - break; - default: - throw std::invalid_argument{"invalid AccessType"}; - break; - } - proto_a->set_log2_size(a.get_log2_size()); - proto_a->set_address(a.get_address()); - set_proto_hash(a.get_read_hash(), proto_a->mutable_read_hash()); - if (a.get_read().has_value()) { - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - const auto &value_read = a.get_read().value(); - proto_a->set_read(value_read.data(), value_read.size()); - } - if (a.get_written_hash().has_value()) { - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - set_proto_hash(a.get_written_hash().value(), proto_a->mutable_written_hash()); - } - if (a.get_written().has_value()) { - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - const auto &value_written = a.get_written().value(); - proto_a->set_written(value_written.data(), value_written.size()); - } - if (al.get_log_type().has_proofs() && a.get_proof().has_value()) { - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - set_proto_merkle_tree_proof(a.get_proof().value(), proto_a->mutable_proof()); - } + set_proto_access(a, proto_al->add_accesses()); } if (al.get_log_type().has_annotations()) { for (const auto &bn : al.get_brackets()) { @@ -337,6 +343,38 @@ access_type get_proto_access_type(CartesiMachine::AccessType proto_at) { }; } +static access get_proto_access(const CartesiMachine::Access &pac) { + access a; + a.set_type(get_proto_access_type(pac.type())); + a.set_address(pac.address()); + a.set_log2_size(static_cast(pac.log2_size())); + a.set_read_hash(get_proto_hash(pac.read_hash())); + if (pac.has_read()) { + access_data read_value; + read_value.insert(read_value.end(), pac.read().begin(), pac.read().end()); + a.set_read(read_value); + } + if (pac.has_written_hash()) { + a.set_written_hash(get_proto_hash(pac.written_hash())); + } + if (pac.has_written()) { + access_data written_value; + written_value.insert(written_value.end(), pac.written().begin(), pac.written().end()); + a.set_written(written_value); + } + + if (!pac.sibling_hashes().empty()) { + const auto &proto_sibs = pac.sibling_hashes(); + a.get_sibling_hashes().emplace(); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto &sibling_hashes = a.get_sibling_hashes().value(); + for (const auto &s : proto_sibs) { + sibling_hashes.push_back(get_proto_hash(s)); + } + } + return a; +} + access_log get_proto_access_log(const CartesiMachine::AccessLog &proto_al) { if (proto_al.log_type().annotations() && proto_al.accesses().size() != proto_al.notes().size()) { throw std::invalid_argument("size of log accesses and notes differ"); @@ -361,32 +399,12 @@ access_log get_proto_access_log(const CartesiMachine::AccessLog &proto_al) { assert(pbr->where() == al.get_brackets().back().where); pbr++; } - access a; - a.set_type(get_proto_access_type(pac->type())); - a.set_address(pac->address()); - a.set_log2_size(static_cast(pac->log2_size())); - a.set_read_hash(get_proto_hash(pac->read_hash())); - if (pac->has_read()) { - access_data read_value; - read_value.insert(read_value.end(), pac->read().begin(), pac->read().end()); - a.set_read(read_value); - } - if (pac->has_written_hash()) { - a.set_written_hash(get_proto_hash(pac->written_hash())); - } - if (pac->has_written()) { - access_data written_value; - written_value.insert(written_value.end(), pac->written().begin(), pac->written().end()); - a.set_written(written_value); - } + const auto &proto_a = get_proto_access(*pac); std::string note; if (has_annotations) { note = *pnt++; } - if (has_proofs) { - a.set_proof(get_proto_merkle_tree_proof(pac->proof())); - } - al.push_access(a, note.c_str()); + al.push_access(proto_a, note.c_str()); pac++; iac++; } diff --git a/src/test-machine-c-api.cpp b/src/test-machine-c-api.cpp index c982f178e..1d8873ad6 100644 --- a/src/test-machine-c-api.cpp +++ b/src/test-machine-c-api.cpp @@ -1800,11 +1800,15 @@ class access_log_machine_fixture : public incomplete_machine_fixture { _log_type = {true, true, false}; _machine_dir_path = (std::filesystem::temp_directory_path() / "661b6096c377cdc07756df488059f4407c8f4").string(); + // Encodes: li t0, UARCH_HALT_FLAG_SHADDOW_ADDR + uint32_t li_t0_UARCH_SHADOW_START_ADDRESS = + ((UARCH_HALT_FLAG_SHADDOW_ADDR_DEF >> 12) << 12) | static_cast(0x02b7); + uint32_t test_uarch_ram[] = { - 0x07b00513, // li a0,123 - 0x004002b7, // li t0,UARCH_HALT_FLAG_SHADDOW_ADDR Address of uarch halt flag - 0x00100313, // li t1,1 - 0x0062b023, // sd t1,0(t0) Halt microarchitecture at uarch cycle 4 + 0x07b00513, // li a0,123 + li_t0_UARCH_SHADOW_START_ADDRESS, // li t0,UARCH_HALT_FLAG_SHADDOW_ADDR + 0x00100313, // li t1,1 + 0x0062b023, // sd t1,0(t0) Halt microarchitecture at uarch cycle 4 }; std::ofstream of(_uarch_ram_path, std::ios::binary); of.write(static_cast(static_cast(&test_uarch_ram)), sizeof(test_uarch_ram)); diff --git a/src/test-utils.h b/src/test-utils.h index ee871ed1c..2a799024c 100644 --- a/src/test-utils.h +++ b/src/test-utils.h @@ -75,11 +75,11 @@ static hash_type calculate_proof_root_hash(const cm_merkle_tree_proof *proof) { auto bit = (proof->target_address & (UINT64_C(1) << log2_size)); hash_type first, second; if (bit) { - memcpy(first.data(), proof->sibling_hashes.entry[proof->log2_root_size - log2_size - 1], sizeof(cm_hash)); + memcpy(first.data(), proof->sibling_hashes.entry[log2_size - proof->log2_target_size], sizeof(cm_hash)); second = hash; } else { first = hash; - memcpy(second.data(), proof->sibling_hashes.entry[proof->log2_root_size - log2_size - 1], sizeof(cm_hash)); + memcpy(second.data(), proof->sibling_hashes.entry[log2_size - proof->log2_target_size], sizeof(cm_hash)); } get_concat_hash(h, first, second, hash); } diff --git a/src/tests/machine-bind.lua b/src/tests/machine-bind.lua index 436fa1a00..66925c3e5 100755 --- a/src/tests/machine-bind.lua +++ b/src/tests/machine-bind.lua @@ -764,7 +764,7 @@ do_test("Step log must contain conssitent data hashes", function(machine) -- ensure that verification fails with wrong written hash write_access.written_hash = wrong_hash _, err = pcall(module.machine.verify_uarch_step_log, log, {}) - assert(err:match("logged written data of uarch.cycle does not hash to the logged written hash at access 8")) + assert(err:match("value being written to uarch.cycle does not hash to the logged written hash at access 8")) end) do_test("step when uarch cycle is max", function(machine) @@ -892,10 +892,9 @@ local function test_reset_uarch(machine, with_log, with_proofs, with_annotations assert(#log.accesses == 1) local access = log.accesses[1] if with_proofs then - assert(access.proof ~= nil) - assert(test_util.check_proof(access.proof), "uarch reset proof failed") + assert(access.sibling_hashes ~= nil) else - assert(access.proof == nil) + assert(access.sibling_hashes == nil) end assert(access.address == cartesi.UARCH_SHADOW_START_ADDRESS) assert(access.log2_size == cartesi.UARCH_STATE_LOG2_SIZE) @@ -950,7 +949,7 @@ test_util.make_do_test(build_machine, machine_type, { uarch = test_reset_uarch_c -- verifying incorrect initial hash local wrong_hash = string.rep("0", 32) local _, err = pcall(module.machine.verify_uarch_reset_state_transition, wrong_hash, log, final_hash, {}) - assert(err:match("mismatch in root hash before replay")) + assert(err:match("Mismatch in root hash of access 1")) -- verifying incorrect final hash _, err = pcall(module.machine.verify_uarch_reset_state_transition, initial_hash, log, wrong_hash, {}) assert(err:match("mismatch in root hash after replay")) @@ -972,7 +971,6 @@ test_util.make_do_test(build_machine, machine_type, { uarch = test_reset_uarch_c function(machine) local log = machine:log_uarch_reset({ proofs = true, annotations = true }) local expected_dump = "begin reset uarch state\n" - .. " hash ca8558be\n" .. ' 1: write uarch_state@0x400000(4194304): hash:"cddaea90"(2^22 bytes) -> hash:"b8fdcda1"(2^22 bytes)\n' .. "end reset uarch state\n" @@ -999,10 +997,7 @@ test_util.make_do_test(build_machine, machine_type, { uarch = test_reset_uarch_c "Log uarch reset with large_data option set must have consistent read and written data", function(machine) local module = cartesi - if machine_type ~= "local" then - module = remote - -- return -- TODO: fix grpc and jsorpc failing due to large data - end + if machine_type ~= "local" then module = remote end -- reset uarch and get log local log = machine:log_uarch_reset({ proofs = true, annotations = true, large_data = true }) assert(#log.accesses == 1, "log should have 1 access") @@ -1027,4 +1022,95 @@ test_util.make_do_test(build_machine, machine_type, { uarch = test_reset_uarch_c end ) +do_test("Test unhappy paths of verify_uarch_reset_state_transition", function(machine) + local module = cartesi + if machine_type ~= "local" then + if not remote then remote = connect() end + module = remote + end + local bad_hash = string.rep("\0", 32) + local function assert_error(expected_error, callback) + machine:reset_uarch() + local initial_hash = machine:get_root_hash() + local log = machine:log_uarch_reset({ proofs = true, annotations = false }) + local final_hash = machine:get_root_hash() + callback(log) + local _, err = pcall(module.machine.verify_uarch_reset_state_transition, initial_hash, log, final_hash, {}) + assert( + err:match(expected_error), + 'Error text "' .. err .. '" does not match expected "' .. expected_error .. '"' + ) + end + assert_error("too few accesses in log", function(log) log.accesses = {} end) + assert_error( + "expected address of access 1 to be the start address of the uarch state", + function(log) log.accesses[1].address = 0 end + ) + + assert_error( + "expected access 1 to write 2%^22 bytes to uarchState", + function(log) log.accesses[1].log2_size = 64 end + ) + + assert_error("hash length must be 32 bytes", function(log) log.accesses[#log.accesses].read_hash = nil end) + assert_error("Mismatch in root hash of access 1", function(log) log.accesses[1].read_hash = bad_hash end) + assert_error( + "access log was not fully consumed", + function(log) log.accesses[#log.accesses + 1] = log.accesses[1] end + ) + assert_error("hash length must be 32 bytes", function(log) log.accesses[#log.accesses].written_hash = nil end) + assert_error( + "invalid written %(expected% string with 2%^22 bytes%)", + function(log) log.accesses[#log.accesses].written = "\0" end + ) + assert_error( + "written hash and written data mismatch at access 1", + function(log) log.accesses[#log.accesses].written = string.rep("\0", 2 ^ 22) end + ) + assert_error("Mismatch in root hash of access 1", function(log) log.accesses[1].sibling_hashes[1] = bad_hash end) +end) + +do_test("Test unhappy paths of verify_uarch_step_state_transition", function(machine) + local module = cartesi + if machine_type ~= "local" then + if not remote then remote = connect() end + module = remote + end + local bad_hash = string.rep("\0", 32) + local function assert_error(expected_error, callback) + machine:reset_uarch() + local initial_hash = machine:get_root_hash() + local log = machine:log_uarch_step({ proofs = true, annotations = false }) + local final_hash = machine:get_root_hash() + callback(log) + local _, err = pcall(module.machine.verify_uarch_step_state_transition, initial_hash, log, final_hash, {}) + assert( + err:match(expected_error), + 'Error text "' .. err .. '" does not match expected "' .. expected_error .. '"' + ) + end + assert_error("too few accesses in log", function(log) log.accesses = {} end) + assert_error("expected access 1 to read uarch.uarch_cycle", function(log) log.accesses[1].address = 0 end) + assert_error("invalid log2_size", function(log) log.accesses[1].log2_size = 2 end) + assert_error("invalid log2_size", function(log) log.accesses[1].log2_size = 65 end) + assert_error("missing read uarch.uarch_cycle data at access 1", function(log) log.accesses[1].read = nil end) + assert_error("invalid read %(expected string with 2%^3 bytes%)", function(log) log.accesses[1].read = "\0" end) + assert_error( + "logged read data of uarch.uarch_cycle data does not hash to the logged read hash at access 1", + function(log) log.accesses[1].read_hash = bad_hash end + ) + assert_error("hash length must be 32 bytes", function(log) log.accesses[#log.accesses].read_hash = nil end) + assert_error("too many word accesses in log", function(log) log.accesses[#log.accesses + 1] = log.accesses[1] end) + assert_error("hash length must be 32 bytes", function(log) log.accesses[#log.accesses].written_hash = nil end) + assert_error( + "invalid written %(expected string with 2%^3 bytes%)", + function(log) log.accesses[#log.accesses].written = "\0" end + ) + assert_error( + "logged written data of uarch.cycle does not hash to the logged written hash at access 7", + function(log) log.accesses[#log.accesses].written = "\0\0\0\0\0\0\0\0" end + ) + assert_error("Mismatch in root hash of access 1", function(log) log.accesses[1].sibling_hashes[1] = bad_hash end) +end) + print("\n\nAll machine binding tests for type " .. machine_type .. " passed") diff --git a/src/tests/mcycle-overflow.lua b/src/tests/mcycle-overflow.lua index edd90be37..fc8d170fa 100755 --- a/src/tests/mcycle-overflow.lua +++ b/src/tests/mcycle-overflow.lua @@ -89,7 +89,7 @@ for _, proofs in ipairs({ true, false }) do assert(log.accesses[1].type == "read") assert(log.accesses[1].address == cartesi.UARCH_SHADOW_START_ADDRESS + 8) -- address of uarch_cycle assert(log.accesses[1].read == string.pack("J", MAX_UARCH_CYCLE)) - assert((log.accesses[1].proof ~= nil) == proofs) + assert((log.accesses[1].sibling_hashes ~= nil) == proofs) end) end diff --git a/src/tests/util.lua b/src/tests/util.lua index 25f2057fb..514ba5254 100644 --- a/src/tests/util.lua +++ b/src/tests/util.lua @@ -47,9 +47,13 @@ end local ZERO_PAGE = string.rep("\x00", PAGE_SIZE) +-- Encodes: li t0, UARCH_HALT_FLAG_SHADDOW_ADDR +-- The halt flag is located at the first dword starting from UARCH_SHADOW_START_ADDRESS +local li_t0_UARCH_SHADOW_START_ADDRESS = ((cartesi.UARCH_SHADOW_START_ADDRESS >> 12) << 12) | 0x02b7 + test_util.uarch_programs = { halt = { - 0x004002b7, -- li t0,UARCH_HALT_FLAG_SHADDOW_ADDR Address of uarch halt flag + li_t0_UARCH_SHADOW_START_ADDRESS, -- li t0,UARCH_HALT_FLAG_SHADDOW_ADDR 0x00100313, -- li t1,1 UARCH_MMIO_HALT_VALUE 0x0062b023, -- sd t1,0(t0) Halt uarch }, @@ -201,9 +205,9 @@ function test_util.check_proof(proof) local bit = (proof.target_address & (1 << log2_size)) ~= 0 local first, second if bit then - first, second = proof.sibling_hashes[proof.log2_root_size - log2_size], hash + first, second = proof.sibling_hashes[log2_size - proof.log2_target_size + 1], hash else - first, second = hash, proof.sibling_hashes[proof.log2_root_size - log2_size] + first, second = hash, proof.sibling_hashes[log2_size - proof.log2_target_size + 1] end hash = cartesi.keccak(first, second) end diff --git a/src/uarch-record-reset-state-access.h b/src/uarch-record-reset-state-access.h index cb4c4c449..b99fe1732 100644 --- a/src/uarch-record-reset-state-access.h +++ b/src/uarch-record-reset-state-access.h @@ -149,7 +149,9 @@ class uarch_record_reset_state_access : public i_uarch_reset_state_accessget_log_type().has_proofs()) { - a.set_proof(std::move(proof)); + // We just store the sibling hashes in the access because this is the only missing piece of data needed to + // reconstruct the proof + a.set_sibling_hashes(proof.get_sibling_hashes()); } a.set_written_hash(uarch_pristine_state_hash); diff --git a/src/uarch-record-step-state-access.h b/src/uarch-record-step-state-access.h index 54c31fdbf..ba905b1e5 100644 --- a/src/uarch-record-step-state-access.h +++ b/src/uarch-record-step-state-access.h @@ -161,7 +161,12 @@ class uarch_record_step_state_access : public i_uarch_step_state_access; // NOLINTNEXTLINE(readability-convert-member-functions-to-static) @@ -143,22 +123,21 @@ class uarch_replay_reset_state_access : public i_uarch_reset_state_access { + using tree_type = machine_merkle_tree; + using hash_type = tree_type::hash_type; + using hasher_type = tree_type::hasher_type; + using proof_type = tree_type::proof_type; + ///< Access log generated by step const std::vector &m_accesses; ///< Whether to verify proofs in access log @@ -42,30 +47,33 @@ class uarch_replay_step_state_access : public i_uarch_step_state_access 63) { - throw std::invalid_argument{"invalid access size"}; + if (m_next_access >= m_accesses.size()) { + throw std::invalid_argument{"too few accesses in log"}; } + const auto &access = m_accesses[m_next_access]; if ((paligned & ((UINT64_C(1) << log2_size) - 1)) != 0) { throw std::invalid_argument{"access address not aligned to size"}; } - if (m_next_access >= m_accesses.size()) { - throw std::invalid_argument{"too few accesses in log"}; + if (access.get_address() != paligned) { + std::ostringstream err; + err << "expected access " << access_to_report() << " to read " << text << " at address 0x" << std::hex + << paligned << "(" << std::dec << paligned << ")"; + throw std::invalid_argument{err.str()}; } - const auto &access = m_accesses[m_next_access]; - if (access.get_type() != access_type::read) { - throw std::invalid_argument{"expected access " + std::to_string(access_to_report()) + " to read " + text}; + if (log2_size < 3 || log2_size > 63) { + throw std::invalid_argument{"invalid access size"}; } if (access.get_log2_size() != log2_size) { throw std::invalid_argument{"expected access " + std::to_string(access_to_report()) + " to read 2^" + std::to_string(log2_size) + " bytes from " + text}; } + if (access.get_type() != access_type::read) { + throw std::invalid_argument{"expected access " + std::to_string(access_to_report()) + " to read " + text}; + } if (!access.get_read().has_value()) { throw std::invalid_argument{ "missing read " + std::string(text) + " data at access " + std::to_string(access_to_report())}; @@ -160,41 +156,16 @@ class uarch_replay_step_state_access : public i_uarch_step_state_access 63) { - throw std::invalid_argument{"invalid access size"}; + if (m_next_access >= m_accesses.size()) { + throw std::invalid_argument{"too few accesses in log"}; } + const auto &access = m_accesses[m_next_access]; if ((paligned & ((UINT64_C(1) << log2_size) - 1)) != 0) { throw std::invalid_argument{"access address not aligned to size"}; } - if (m_next_access >= m_accesses.size()) { - throw std::invalid_argument{"too few word accesses in log"}; + if (access.get_address() != paligned) { + std::ostringstream err; + err << "expected access " << access_to_report() << " to write " << text << " at address 0x" << std::hex + << paligned << "(" << std::dec << paligned << ")"; + throw std::invalid_argument{err.str()}; } - const auto &access = m_accesses[m_next_access]; - if (access.get_type() != access_type::write) { - throw std::invalid_argument{"expected access " + std::to_string(access_to_report()) + " to write " + text}; + if (log2_size < 3 || log2_size > 63) { + throw std::invalid_argument{"invalid access size"}; } if (access.get_log2_size() != log2_size) { throw std::invalid_argument{"expected access " + std::to_string(access_to_report()) + " to write 2^" + std::to_string(log2_size) + " bytes from " + text}; } - if (!access.get_read().has_value()) { - throw std::invalid_argument{ - "missing read data from " + std::string(text) + " at access " + std::to_string(access_to_report())}; - } - const auto &value_read = access.get_read().value(); // NOLINT(bugprone-unchecked-optional-access) - if (value_read.size() != UINT64_C(1) << log2_size) { - throw std::invalid_argument{"expected overwritten data from " + std::string(text) + " to contain 2^" + - std::to_string(log2_size) + " bytes at access " + std::to_string(access_to_report())}; - } - // check if logged read data hashes to the logged read hash - machine_merkle_tree::hash_type computed_hash{}; - get_hash(m_hasher, value_read, computed_hash); - if (access.get_read_hash() != computed_hash) { - throw std::invalid_argument{"logged read data of " + std::string(text) + - " does not hash to the logged read hash at access " + std::to_string(access_to_report())}; - } - if (!access.get_written().has_value()) { - throw std::invalid_argument{ - "missing written data from " + std::string(text) + " at access " + std::to_string(access_to_report())}; + if (access.get_type() != access_type::write) { + throw std::invalid_argument{"expected access " + std::to_string(access_to_report()) + " to write " + text}; } - const auto &value_written = access.get_written().value(); // NOLINT(bugprone-unchecked-optional-access) - if (value_written.size() != UINT64_C(1) << log2_size) { - throw std::invalid_argument{"expected written " + std::string(text) + " data to contain 2^" + - std::to_string(log2_size) + " bytes at access " + std::to_string(access_to_report())}; + if (access.get_read().has_value()) { + const auto &value_read = access.get_read().value(); // NOLINT(bugprone-unchecked-optional-access) + if (value_read.size() != UINT64_C(1) << log2_size) { + throw std::invalid_argument{"expected overwritten data from " + std::string(text) + " to contain 2^" + + std::to_string(log2_size) + " bytes at access " + std::to_string(access_to_report())}; + } + // check if read data hashes to the logged read hash + hash_type computed_hash{}; + get_hash(m_hasher, value_read, computed_hash); + if (access.get_read_hash() != computed_hash) { + throw std::invalid_argument{"logged read data of " + std::string(text) + + " does not hash to the logged read hash at access " + std::to_string(access_to_report())}; + } } if (!access.get_written_hash().has_value()) { throw std::invalid_argument{ "missing written " + std::string(text) + " hash at access " + std::to_string(access_to_report())}; } const auto &written_hash = access.get_written_hash().value(); // NOLINT(bugprone-unchecked-optional-access) - // check if logged written data hashes to the logged written hash - get_hash(m_hasher, value_written, computed_hash); + // check if value being written hashes to the logged written hash + hash_type computed_hash{}; + get_hash(m_hasher, val, computed_hash); if (written_hash != computed_hash) { - throw std::invalid_argument{"logged written data of " + std::string(text) + + throw std::invalid_argument{"value being written to " + std::string(text) + " does not hash to the logged written hash at access " + std::to_string(access_to_report())}; } - if (access.get_address() != paligned) { - std::ostringstream err; - err << "expected access " << access_to_report() << " to write " << text << " at address 0x" << std::hex - << paligned << "(" << std::dec << paligned << ")"; - throw std::invalid_argument{err.str()}; - } - if (m_verify_proofs) { - if (!access.get_proof().has_value()) { - throw std::invalid_argument{"write access " + std::to_string(access_to_report()) + " has no proof"}; + if (access.get_written().has_value()) { + const auto &value_written = access.get_written().value(); // NOLINT(bugprone-unchecked-optional-access) + if (value_written.size() != UINT64_C(1) << log2_size) { + throw std::invalid_argument{"expected written " + std::string(text) + " data to contain 2^" + + std::to_string(log2_size) + " bytes at access " + std::to_string(access_to_report())}; } - const auto &proof = access.get_proof().value(); // NOLINT(bugprone-unchecked-optional-access) - if (proof.get_target_address() != access.get_address()) { - throw std::invalid_argument{"mismatch in write access " + std::to_string(access_to_report()) + - " address and its proof address"}; + // check if written data hashes to the logged written hash + get_hash(m_hasher, value_written, computed_hash); + if (written_hash != computed_hash) { + throw std::invalid_argument{"logged written data of " + std::string(text) + + " does not hash to the logged written hash at access " + std::to_string(access_to_report())}; } - if (m_root_hash != proof.get_root_hash()) { - throw std::invalid_argument{ - "mismatch in write access " + std::to_string(access_to_report()) + " root hash"}; - } - machine_merkle_tree::hash_type rolling_hash; - get_hash(m_hasher, access.get_read().value(), rolling_hash); - if (rolling_hash != proof.get_target_hash()) { - throw std::invalid_argument{ - "value before write access " + std::to_string(access_to_report()) + " does not match target hash"}; - } - roll_hash_up_tree(m_hasher, proof, rolling_hash); - if (rolling_hash != proof.get_root_hash()) { - throw std::invalid_argument{ - "value before write access " + std::to_string(access_to_report()) + " fails proof"}; - } - if (access.get_written() != val) { - throw std::invalid_argument{ - "value written in access " + std::to_string(access_to_report()) + " does not match log"}; + } + if (m_verify_proofs) { + auto proof = access.make_proof(m_root_hash); + if (!proof.verify(m_hasher)) { + throw std::invalid_argument{"Mismatch in root hash of access " + std::to_string(access_to_report())}; } - auto value_written = access.get_written().value(); - get_hash(m_hasher, value_written, m_root_hash); - roll_hash_up_tree(m_hasher, proof, m_root_hash); + // Update root hash to reflect the data written by this access + m_root_hash = proof.bubble_up(m_hasher, written_hash); } m_next_access++; } diff --git a/src/uarch-riscv-tests.lua b/src/uarch-riscv-tests.lua index db39208d9..55558b18a 100755 --- a/src/uarch-riscv-tests.lua +++ b/src/uarch-riscv-tests.lua @@ -303,25 +303,16 @@ end local function open_steps_json_log(test_name) return create_json_log_file(test_name, "-steps") end local function write_sibling_hashes_to_log(sibling_hashes, out, indent) + util.indentout(out, indent, '"sibling_hashes": [\n') for i, h in ipairs(sibling_hashes) do - util.indentout(out, indent, '"%s"', util.hexhash(h)) + util.indentout(out, indent + 1, '"%s"', util.hexhash(h)) if sibling_hashes[i + 1] then out:write(",\n") else out:write("\n") end end -end - -local function write_proof_to_log(proof, out, indent) - util.indentout(out, indent, '"target_address": %u,\n', proof.target_address) - util.indentout(out, indent, '"log2_target_size": %u,\n', proof.log2_target_size) - util.indentout(out, indent, '"log2_root_size": %u,\n', proof.log2_root_size) - util.indentout(out, indent, '"target_hash": "%s",\n', util.hexhash(proof.target_hash)) - util.indentout(out, indent, '"sibling_hashes": [\n') - write_sibling_hashes_to_log(proof.sibling_hashes, out, indent + 1) - util.indentout(out, indent, "],\n") - util.indentout(out, indent, '"root_hash": "%s"\n', util.hexhash(proof.root_hash)) + util.indentout(out, indent, "]\n") end local function write_access_to_log(access, out, indent, last) @@ -340,11 +331,9 @@ local function write_access_to_log(access, out, indent, last) util.indentout(out, indent + 1, '"value": %s,', value) util.indentout(out, indent + 1, '"hash": "%s"', util.hexhash(access.read_hash)) end - if access.proof then + if access.sibling_hashes then out:write(",\n") - util.indentout(out, indent + 1, '"proof": {\n') - write_proof_to_log(access.proof, out, indent + 2) - util.indentout(out, indent + 1, "}\n") + write_sibling_hashes_to_log(access.sibling_hashes, out, indent + 2) else out:write("\n") end diff --git a/third-party/riscv-arch-tests/src/run-rv64i-arch-test.lua b/third-party/riscv-arch-tests/src/run-rv64i-arch-test.lua index d67014758..40fe068a9 100755 --- a/third-party/riscv-arch-tests/src/run-rv64i-arch-test.lua +++ b/third-party/riscv-arch-tests/src/run-rv64i-arch-test.lua @@ -40,10 +40,7 @@ end local uarch_ram_image_filename = arg[1] local output_signature_file = arg[2] local uarch_ram_start = cartesi.UARCH_RAM_START_ADDRESS -local dummy_rom_filename = os.tmpname() -io.open(dummy_rom_filename, "w"):close() local deleter = {} -setmetatable(deleter, { __gc = function() os.remove(dummy_rom_filename) end }) local config = { uarch = {