Skip to content

Commit

Permalink
Update pybind/gil release scope/stress test (#234)
Browse files Browse the repository at this point in the history
Co-authored-by: Arvid Norberg <[email protected]>
  • Loading branch information
wjblanke and arvidn authored Feb 10, 2025
1 parent f31b096 commit a248af2
Show file tree
Hide file tree
Showing 5 changed files with 834 additions and 24 deletions.
9 changes: 8 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ if(BUILD_PYTHON)
FetchContent_Declare(
pybind11-src
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG v2.11.1
GIT_TAG v2.13.6
)
FetchContent_MakeAvailable(pybind11-src)

Expand All @@ -81,10 +81,17 @@ add_executable(verifier_test
${CMAKE_CURRENT_SOURCE_DIR}/verifier_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/refcode/lzcnt.c
)
add_executable(stress_test
${CMAKE_CURRENT_SOURCE_DIR}/stress_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/refcode/lzcnt.c
)

target_link_libraries(verifier_test PRIVATE ${GMP_LIBRARIES} ${GMPXX_LIBRARIES})
target_link_libraries(stress_test PRIVATE ${GMP_LIBRARIES} ${GMPXX_LIBRARIES})

if(UNIX)
target_link_libraries(verifier_test PRIVATE -pthread)
target_link_libraries(stress_test PRIVATE -pthread)
endif()

if(BUILD_CHIAVDFC)
Expand Down
76 changes: 53 additions & 23 deletions src/python_bindings/fastvdf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ PYBIND11_MODULE(chiavdf, m) {
// Creates discriminant.
m.def("create_discriminant", [] (const py::bytes& challenge_hash, int discriminant_size_bits) {
std::string challenge_hash_str(challenge_hash);
py::gil_scoped_release release;
auto challenge_hash_bits = std::vector<uint8_t>(challenge_hash_str.begin(), challenge_hash_str.end());
integer D = CreateDiscriminant(
challenge_hash_bits,
discriminant_size_bits
);
integer D;
{
py::gil_scoped_release release;
auto challenge_hash_bits = std::vector<uint8_t>(challenge_hash_str.begin(), challenge_hash_str.end());
D = CreateDiscriminant(
challenge_hash_bits,
discriminant_size_bits
);
}
return D.to_string();
});

Expand All @@ -29,13 +32,14 @@ PYBIND11_MODULE(chiavdf, m) {
std::string x_s_copy(x_s);
std::string y_s_copy(y_s);
std::string proof_s_copy(proof_s);
py::gil_scoped_release release;
form x = DeserializeForm(D, (const uint8_t *)x_s_copy.data(), x_s_copy.size());
form y = DeserializeForm(D, (const uint8_t *)y_s_copy.data(), y_s_copy.size());
form proof = DeserializeForm(D, (const uint8_t *)proof_s_copy.data(), proof_s_copy.size());

bool is_valid = false;
VerifyWesolowskiProof(D, x, y, proof, num_iterations, is_valid);
{
py::gil_scoped_release release;
form x = DeserializeForm(D, (const uint8_t *)x_s_copy.data(), x_s_copy.size());
form y = DeserializeForm(D, (const uint8_t *)y_s_copy.data(), y_s_copy.size());
form proof = DeserializeForm(D, (const uint8_t *)proof_s_copy.data(), proof_s_copy.size());
VerifyWesolowskiProof(D, x, y, proof, num_iterations, is_valid);
}
return is_valid;
});

Expand All @@ -47,17 +51,40 @@ PYBIND11_MODULE(chiavdf, m) {
std::string discriminant_copy(discriminant);
std::string x_s_copy(x_s);
std::string proof_blob_copy(proof_blob);
py::gil_scoped_release release;
uint8_t *proof_blob_ptr = reinterpret_cast<uint8_t *>(proof_blob_copy.data());
int proof_blob_size = proof_blob_copy.size();
bool is_valid = false;
{
py::gil_scoped_release release;
is_valid=CheckProofOfTimeNWesolowski(integer(discriminant_copy), (const uint8_t *)x_s_copy.data(), proof_blob_ptr, proof_blob_size, num_iterations, disc_size_bits, recursion);
}
return is_valid;
});

return CheckProofOfTimeNWesolowski(integer(discriminant_copy), (const uint8_t *)x_s_copy.data(), proof_blob_ptr, proof_blob_size, num_iterations, disc_size_bits, recursion);
// Checks an N wesolowski proof.
m.def("create_discriminant_and_verify_n_wesolowski", [] (const py::bytes& challenge_hash,
const int discriminant_size_bits,
const string& x_s,
const string& proof_blob,
const uint64_t num_iterations,
const uint64_t recursion) {
std::string challenge_hash_str(challenge_hash);
std::vector<uint8_t> challenge_hash_bits = std::vector<uint8_t>(challenge_hash_str.begin(), challenge_hash_str.end());
std::string x_s_copy(x_s);
std::string proof_blob_copy(proof_blob);
bool is_valid = false;
{
py::gil_scoped_release release;
is_valid=CreateDiscriminantAndCheckProofOfTimeNWesolowski(challenge_hash_bits, discriminant_size_bits,(const uint8_t *)x_s_copy.data(), (const uint8_t *)proof_blob_copy.data(), proof_blob_copy.size(), num_iterations, recursion);
}
return is_valid;
});

m.def("prove", [] (const py::bytes& challenge_hash, const string& x_s, int discriminant_size_bits, uint64_t num_iterations, const string& shutdown_file_path) {
std::string challenge_hash_str(challenge_hash);
std::string x_s_copy(x_s);
std::vector<uint8_t> result;
std::string shutdown_file_path_copy(shutdown_file_path);
{
py::gil_scoped_release release;
std::vector<uint8_t> challenge_hash_bytes(challenge_hash_str.begin(), challenge_hash_str.end());
Expand All @@ -66,7 +93,7 @@ PYBIND11_MODULE(chiavdf, m) {
discriminant_size_bits
);
form x = DeserializeForm(D, (const uint8_t *) x_s_copy.data(), x_s_copy.size());
result = ProveSlow(D, x, num_iterations, shutdown_file_path);
result = ProveSlow(D, x, num_iterations, shutdown_file_path_copy);
}
py::bytes ret = py::bytes(reinterpret_cast<char*>(result.data()), result.size());
return ret;
Expand All @@ -78,12 +105,12 @@ PYBIND11_MODULE(chiavdf, m) {
const string& x_s,
const string& proof_blob,
const uint64_t num_iterations, const uint64_t recursion) {
std::string discriminant_copy(discriminant);
std::string B_copy(B);
std::string x_s_copy(x_s);
std::string proof_blob_copy(proof_blob);
std::pair<bool, std::vector<uint8_t>> result;
{
std::string discriminant_copy(discriminant);
std::string B_copy(B);
std::string x_s_copy(x_s);
std::string proof_blob_copy(proof_blob);
py::gil_scoped_release release;
uint8_t *proof_blob_ptr = reinterpret_cast<uint8_t *>(proof_blob_copy.data());
int proof_blob_size = proof_blob_copy.size();
Expand All @@ -101,10 +128,13 @@ PYBIND11_MODULE(chiavdf, m) {
std::string discriminant_copy(discriminant);
std::string x_s_copy(x_s);
std::string proof_blob_copy(proof_blob);
py::gil_scoped_release release;
uint8_t *proof_blob_ptr = reinterpret_cast<uint8_t *>(proof_blob_copy.data());
int proof_blob_size = proof_blob_copy.size();
integer B = GetBFromProof(integer(discriminant_copy), (const uint8_t *)x_s_copy.data(), proof_blob_ptr, proof_blob_size, num_iterations, recursion);
integer B;
{
py::gil_scoped_release release;
uint8_t *proof_blob_ptr = reinterpret_cast<uint8_t *>(proof_blob_copy.data());
int proof_blob_size = proof_blob_copy.size();
B = GetBFromProof(integer(discriminant_copy), (const uint8_t *)x_s_copy.data(), proof_blob_ptr, proof_blob_size, num_iterations, recursion);
}
return B.to_string();
});
}
103 changes: 103 additions & 0 deletions src/stress_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include "verifier.h"
#include <sstream>
#include <string>
#include <fstream>
#include <thread>

std::vector<uint8_t> HexToBytes(const char *hex_proof) {
int len = strlen(hex_proof);
assert(len % 2 == 0);
std::vector<uint8_t> result;
for (int i = 0; i < len; i += 2)
{
int hex1 = hex_proof[i] >= 'a' ? (hex_proof[i] - 'a' + 10) : (hex_proof[i] - '0');
int hex2 = hex_proof[i + 1] >= 'a' ? (hex_proof[i + 1] - 'a' + 10) : (hex_proof[i + 1] - '0');
result.push_back(hex1 * 16 + hex2);
}
return result;
}

struct job
{
std::vector<uint8_t> challengebytes;
std::vector<uint8_t> inputbytes;
std::vector<uint8_t> outputbytes;
uint64 number_of_iterations;
uint32 discriminant_size;
uint8 witness_type;
};

void doit(int thread, std::vector<job> const& jobs)
{
int cnt = 0;
for (job const& j : jobs)
{
bool const is_valid = CreateDiscriminantAndCheckProofOfTimeNWesolowski(
j.challengebytes,
j.discriminant_size,
j.inputbytes.data(),
j.outputbytes.data(),
j.outputbytes.size(),
j.number_of_iterations,
j.witness_type);
if (!is_valid) {
printf("thread %d cnt %d is valid %d %llu %d\n",
thread,
cnt,
is_valid,
j.number_of_iterations,
j.witness_type);
std::terminate();
}
cnt++;
}
}

int main()
{
std::ifstream infile("vdf.txt");

std::string challenge;
std::string discriminant_size;
std::string input_el;
std::string output;
std::string number_of_iterations;
std::string witness_type;

std::vector<job> jobs;

while (true) {
std::getline(infile, challenge);
if (infile.eof())
break;
std::getline(infile, discriminant_size);
std::getline(infile, input_el);
std::getline(infile, output);
std::getline(infile, number_of_iterations);
std::getline(infile, witness_type);

std::vector<uint8_t> challengebytes=HexToBytes(challenge.c_str());
std::vector<uint8_t> inputbytes=HexToBytes(input_el.c_str());
std::vector<uint8_t> outputbytes=HexToBytes(output.c_str());

char *endptr;

uint64 noi=strtoll(number_of_iterations.c_str(),&endptr,10);
if (errno == ERANGE) std::terminate();
uint32 ds=strtoll(discriminant_size.c_str(),&endptr,10);
if (errno == ERANGE) std::terminate();
uint8 wt=strtoll(witness_type.c_str(),&endptr,10);
if (errno == ERANGE) std::terminate();

jobs.push_back({challengebytes, inputbytes, outputbytes, noi, ds, wt});
}

std::vector<std::thread> threads;
for (int i = 0; i < 20; ++i)
threads.emplace_back(doit, i, std::ref(jobs));

for (auto& t : threads)
t.join();

return 0;
}
Loading

0 comments on commit a248af2

Please sign in to comment.