Skip to content

Commit

Permalink
add tests for dynvm
Browse files Browse the repository at this point in the history
Signed-off-by: Protryon <[email protected]>
  • Loading branch information
Protryon committed Mar 26, 2024
1 parent 70b1002 commit 0076a3a
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 41 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ jobs:
arch: x86_64
action: test
flags: --config=gcc
- name: 'DynVM on Linux/x86_64'
engine: 'null'
os: ubuntu-20.04
arch: x86_64
action: test
flags: --config=gcc
- name: 'NullVM on Linux/x86_64 with ASan'
engine: 'null'
os: ubuntu-20.04
Expand Down
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ cc_library(
"src/dyn/dyn_vm_plugin.cc",
],
hdrs = [
"include/proxy-wasm/dyn.h",
"include/proxy-wasm/dyn_vm.h",
"include/proxy-wasm/dyn_vm_plugin.h",
"include/proxy-wasm/wasm_api_impl.h",
Expand Down
39 changes: 39 additions & 0 deletions bazel/wasm.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

load("@rules_rust//rust:defs.bzl", "rust_binary")
load("@rules_rust//rust:defs.bzl", "rust_shared_library")

def _wasm_rust_transition_impl(settings, attr):
return {
Expand Down Expand Up @@ -59,6 +60,17 @@ def _wasm_binary_impl(ctx):

return [DefaultInfo(files = depset([out]), runfiles = ctx.runfiles([out]))]

def _dyn_binary_impl(ctx):
out = ctx.actions.declare_file(ctx.label.name)
ctx.actions.run(
executable = "cp",
arguments = [ctx.files.binary[0].path, out.path],
outputs = [out],
inputs = ctx.files.binary,
)

return [DefaultInfo(files = depset([out]), runfiles = ctx.runfiles([out]))]

def _wasm_attrs(transition):
return {
"binary": attr.label(mandatory = True, cfg = transition),
Expand All @@ -67,6 +79,11 @@ def _wasm_attrs(transition):
"_whitelist_function_transition": attr.label(default = "@bazel_tools//tools/whitelists/function_transition_whitelist"),
}

def _dyn_attrs():
return {
"binary": attr.label(mandatory = True),
}

wasm_rust_binary_rule = rule(
implementation = _wasm_binary_impl,
attrs = _wasm_attrs(wasm_rust_transition),
Expand All @@ -77,6 +94,11 @@ wasi_rust_binary_rule = rule(
attrs = _wasm_attrs(wasi_rust_transition),
)

dyn_rust_binary_rule = rule(
implementation = _dyn_binary_impl,
attrs = _dyn_attrs(),
)

def wasm_rust_binary(name, tags = [], wasi = False, signing_key = [], **kwargs):
wasm_name = "_wasm_" + name.replace(".", "_")
kwargs.setdefault("visibility", ["//visibility:public"])
Expand All @@ -100,3 +122,20 @@ def wasm_rust_binary(name, tags = [], wasi = False, signing_key = [], **kwargs):
signing_key = signing_key,
tags = tags + ["manual"],
)

def dyn_rust_library(name, tags = [], **kwargs):
dyn_name = "_dyn_" + name.replace(".", "_")
kwargs.setdefault("visibility", ["//visibility:public"])

rust_shared_library(
name = dyn_name,
edition = "2018",
tags = ["manual"],
**kwargs
)

dyn_rust_binary_rule(
name = name,
binary = ":" + dyn_name,
tags = tags + ["manual"],
)
3 changes: 0 additions & 3 deletions include/proxy-wasm/dyn_vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@

namespace proxy_wasm {

class WasmVm;
std::unique_ptr<WasmVm> createDynVm();

// The DynVm wraps a C++ Wasm plugin which has been compiled with the Wasm API
// and dynamically linked into the proxy.
struct DynVm : public WasmVm {
Expand Down
1 change: 1 addition & 0 deletions src/dyn/dyn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "include/proxy-wasm/dyn.h"
#include "include/proxy-wasm/dyn_vm.h"

namespace proxy_wasm {
Expand Down
20 changes: 10 additions & 10 deletions src/dyn/dyn_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
#include "include/proxy-wasm/dyn_vm.h"

#include <cstring>

#include <dlfcn.h>
#include <limits>
#include <memory>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <unistd.h>
#include <unordered_map>
#include <utility>
#include <vector>
#include <sys/mman.h>
#include <dlfcn.h>
#include <unistd.h>
#include <sys/syscall.h>

namespace proxy_wasm {

Expand All @@ -42,7 +41,7 @@ std::unique_ptr<WasmVm> DynVm::clone() {
}

// "Load" the plugin by obtaining a pointer to it from the factory.
bool DynVm::load(std::string_view shared_lib, std::string_view /*precompiled*/,
bool DynVm::load(std::string_view plugin_name, std::string_view /*precompiled*/,
const std::unordered_map<uint32_t, std::string> & /*function_names*/) {
plugin_ = std::make_unique<DynVmPlugin>();
plugin_->source = std::make_shared<DynVmPluginSource>();
Expand All @@ -59,13 +58,14 @@ bool DynVm::load(std::string_view shared_lib, std::string_view /*precompiled*/,

size_t written = 0;
ssize_t wrote;
const char *data = shared_lib.data();
while (written < shared_lib.size()) {
wrote = write(plugin_->source->memfd, data + written, shared_lib.length() - written);
const char *data = plugin_name.data();
while (written < plugin_name.size()) {
wrote = write(plugin_->source->memfd, data + written, plugin_name.length() - written);
if (wrote < 0) {
integration()->error("failed to write to memfd: " + std::string(strerror(errno)));
return false;
} else if (wrote == 0) {
}
if (wrote == 0) {
integration()->error("failed to write to memfd, EOF on write");
return false;
}
Expand Down
56 changes: 28 additions & 28 deletions src/dyn/dyn_vm_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
#include "include/proxy-wasm/wasm_vm.h"

#include "include/proxy-wasm/dyn_vm_plugin.h"
#include <iostream>
#include <cstdarg>
#include <dlfcn.h>
#include <thread>
#include <iostream>
#include <sstream>
#include <string>
#include <stdexcept>
#include <string>
#include <thread>
#include <unistd.h>
#include <cstdarg>

namespace proxy_wasm {

Expand Down Expand Up @@ -61,16 +61,16 @@ DynVmPluginSource::~DynVmPluginSource() {
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallVoid<0> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
void (*target_func)() = reinterpret_cast<void (*)()>(target);
auto target_func = reinterpret_cast<void (*)()>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context) {
proxy_wasm::SaveRestoreContext saved_context(context);
wasm_vm_->integration()->trace(call_format(function_name, 0));
Expand All @@ -79,16 +79,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallVoid<1> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
void (*target_func)(uint64_t) = reinterpret_cast<void (*)(uint64_t)>(target);
auto target_func = reinterpret_cast<void (*)(uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1) {
proxy_wasm::SaveRestoreContext saved_context(context);
wasm_vm_->integration()->trace(call_format(function_name, 1, w1));
Expand All @@ -97,16 +97,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallVoid<2> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
void (*target_func)(uint64_t, uint64_t) = reinterpret_cast<void (*)(uint64_t, uint64_t)>(target);
auto target_func = reinterpret_cast<void (*)(uint64_t, uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1,
proxy_wasm::Word w2) {
proxy_wasm::SaveRestoreContext saved_context(context);
Expand All @@ -116,16 +116,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallVoid<3> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
void (*target_func)(uint64_t, uint64_t, uint64_t) =
auto target_func =
reinterpret_cast<void (*)(uint64_t, uint64_t, uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1,
proxy_wasm::Word w2, proxy_wasm::Word w3) {
Expand All @@ -136,16 +136,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallVoid<5> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
void (*target_func)(uint64_t, uint64_t, uint64_t, uint64_t, uint64_t) =
auto target_func =
reinterpret_cast<void (*)(uint64_t, uint64_t, uint64_t, uint64_t, uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1,
proxy_wasm::Word w2, proxy_wasm::Word w3,
Expand All @@ -157,16 +157,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallWord<1> *f) {
if (source->dl_handle == NULL || function_name == "malloc") {
if (source->dl_handle == nullptr || function_name == "malloc") {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
uint64_t (*target_func)(uint64_t) = reinterpret_cast<uint64_t (*)(uint64_t)>(target);
auto target_func = reinterpret_cast<uint64_t (*)(uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1) {
proxy_wasm::SaveRestoreContext saved_context(context);
wasm_vm_->integration()->trace(call_format(function_name, 1, w1));
Expand All @@ -175,16 +175,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallWord<2> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
uint64_t (*target_func)(uint64_t, uint64_t) =
auto target_func =
reinterpret_cast<uint64_t (*)(uint64_t, uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1,
proxy_wasm::Word w2) {
Expand All @@ -195,16 +195,16 @@ void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCa
}

void DynVmPlugin::getFunction(std::string_view function_name, proxy_wasm::WasmCallWord<3> *f) {
if (source->dl_handle == NULL) {
if (source->dl_handle == nullptr) {
*f = nullptr;
return;
}
void *target = dlsym(source->dl_handle, std::string(function_name).c_str());
if (target == NULL) {
if (target == nullptr) {
*f = nullptr;
return;
}
uint64_t (*target_func)(uint64_t, uint64_t, uint64_t) =
auto target_func =
reinterpret_cast<uint64_t (*)(uint64_t, uint64_t, uint64_t)>(target);
*f = [this, target_func, function_name](proxy_wasm::ContextBase *context, proxy_wasm::Word w1,
proxy_wasm::Word w2, proxy_wasm::Word w3) {
Expand Down
15 changes: 15 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,21 @@ cc_test(
],
)

cc_test(
name = "dyn_vm_test",
srcs = ["dyn_vm_test.cc"],
data = [
"//test/test_data:abi_export.so",
],
linkstatic = 1,
deps = [
":utility_lib",
"//:lib",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "wasm_vm_test",
timeout = "long",
Expand Down
6 changes: 6 additions & 0 deletions test/test_data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

load("@proxy_wasm_cpp_host//bazel:wasm.bzl", "wasm_rust_binary")
load("@proxy_wasm_cpp_host//bazel:wasm.bzl", "dyn_rust_library")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")

licenses(["notice"]) # Apache 2
Expand All @@ -24,6 +25,11 @@ wasm_rust_binary(
srcs = ["abi_export.rs"],
)

dyn_rust_library(
name = "abi_export.so",
srcs = ["abi_export.rs"],
)

wasm_rust_binary(
name = "abi_export.signed.with.key1.wasm",
srcs = ["abi_export.rs"],
Expand Down
11 changes: 11 additions & 0 deletions test/utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ std::vector<std::string> getWasmEngines() {
return engines;
}

std::vector<std::string> getDynEngines() {
std::vector<std::string> engines = {
#if defined(PROXY_WASM_HOST_ENGINE_DYN)
"dyn",
#endif
""
};
engines.pop_back();
return engines;
}

std::string readTestWasmFile(const std::string &filename) {
auto path = "test/test_data/" + filename;
std::ifstream file(path, std::ios::binary);
Expand Down
Loading

0 comments on commit 0076a3a

Please sign in to comment.