Skip to content

Commit

Permalink
XXX:tower: Work in progress.
Browse files Browse the repository at this point in the history
  • Loading branch information
lkorenc committed Jun 20, 2024
1 parent 7fc2184 commit abff31e
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 21 deletions.
66 changes: 61 additions & 5 deletions include/vast/Tower/Link.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,47 @@
#pragma once

#include "vast/Util/Common.hpp"

#include "vast/Tower/Handle.hpp"
#include "vast/Tower/LocationInfo.hpp"

VAST_RELAX_WARNINGS
VAST_UNRELAX_WARNINGS

#include <memory>
#include <vector>

namespace vast::tw {

struct one_step_link_t {
handle_t from;
handle_t to;
struct one_step_link_interface {
virtual ~one_step_link_interface() = default;

virtual handle_t from() const = 0;
virtual handle_t to() const = 0;
};

struct light_one_step_link : one_step_link_interface {
protected:
handle_t _from;
handle_t _to;

location_info &li;

public:

explicit light_one_step_link(handle_t _from, handle_t _to, location_info &li)
: _from(_from), _to(_to), li(li)
{}

handle_t from() const override { return _from; }
handle_t to() const override { return _to; }
};

using operations = std::vector< operation >;

// TODO: Implement.
struct link_t {
virtual ~link_t() = default;
struct link_interface {
virtual ~link_interface() = default;

// These are not forced as `const` to allow runtime caching.
virtual operations children(operation) = 0;
Expand All @@ -33,4 +57,36 @@ namespace vast::tw {
virtual operations shared_parents(operations) = 0;
};

using unit_link_ptr = std::unique_ptr< one_step_link_interface >;
using unit_link_vector = std::vector< unit_link_ptr >;

// `A -> ... -> E` - each middle link is kept and there pre-computed
// mapping for `A -> E` transition.
struct fat_link : link_interface {
protected:
unit_link_vector steps;

// some mapping
using op_mapping = std::unordered_map< operation, operations >;

op_mapping down;
op_mapping up;

public:

explicit fat_link(unit_link_vector steps) : steps(std::move(steps)) {}

operations children(operation) override;
operations children(operations) override;

operations shared_children(operations) override;

operations parents(operation) override;
operations parents(operations) override;

operations shared_parents(operations) override;
};



} // namespace vast::tw
4 changes: 3 additions & 1 deletion include/vast/Tower/LocationInfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace vast::tw {

public:
// For the given operation return location to be used in this module.
loc_t get_next(operation high, operation low);
loc_t get_next(operation op);
loc_t get_root(operation op);
static bool are_tied(operation high, operation low);
};
Expand All @@ -52,4 +52,6 @@ namespace vast::tw {
// however require slightly different handling, so we are exposing a hook for that.
void make_root(location_info &, operation);

void transform_locations(location_info &, operation);

} // namespace vast::tw
10 changes: 2 additions & 8 deletions include/vast/Tower/Tower.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace vast::tw {
struct tower
{
private:
mcontext_t &mctx;
[[maybe_unused]] mcontext_t &mctx;
module_storage storage;
handle_t top_handle;

Expand All @@ -39,13 +39,7 @@ namespace vast::tw {

handle_t top() const { return top_handle; }

auto apply(handle_t, mlir::PassManager &) -> handle_t {
return {};
}

auto apply(handle_t, owning_pass_ptr) -> handle_t {
return {};
}
handle_t apply(handle_t, location_info &, mlir::PassManager &);
};

using default_tower = tower;
Expand Down
1 change: 1 addition & 0 deletions lib/vast/Tower/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2022-present, Trail of Bits, Inc.

add_vast_library(Tower
Link.cpp
LocationInfo.cpp
Tower.cpp
)
15 changes: 10 additions & 5 deletions lib/vast/Tower/LocationInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
#include "vast/Tower/LocationInfo.hpp"

namespace vast::tw {
loc_t location_info::get_next(operation high, operation low) {
auto mctx = low->getContext();
auto id = mlir::OpaqueLoc::get< operation >(low, mctx);
return mlir::FusedLoc::get({ self(high), id }, {}, mctx);
loc_t location_info::get_next(operation op) {
auto mctx = op->getContext();
auto id = mlir::OpaqueLoc::get< operation >(op, mctx);
return mlir::FusedLoc::get({ self(op), id }, {}, mctx);
}

loc_t location_info::get_root(operation op) {
auto mctx = op->getContext();
auto id = mlir::OpaqueLoc::get< operation >(op, mctx);
auto id = mlir::OpaqueLoc::get< std::size_t >(reinterpret_cast< std::uintptr_t >(op), mctx);
return mlir::FusedLoc::get({ op->getLoc(), id }, {}, mctx);
}

Expand All @@ -24,4 +24,9 @@ namespace vast::tw {
root->walk(set_loc);
}

void transform_locations(location_info &li, operation root) {
auto set_loc = [&](operation op) { op->setLoc(li.get_next(op)); };
root->walk(set_loc);
}

} // namespace vast::tw
60 changes: 60 additions & 0 deletions lib/vast/Tower/Tower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,65 @@

namespace vast::tw {

struct link_builder : mlir::PassInstrumentation {

location_info &li;
module_storage &storage;

// Start empty and after each callback add to it.
conversion_path_t path = {};

std::vector< handle_t > handles;

explicit link_builder(location_info &li, module_storage &storage, handle_t root)
: li(li), storage(storage), handles{ root }
{}

void runAfterPass(pass_ptr pass, operation op) override {
llvm::errs() << "PASS RAN!\n";
std::ignore = pass;

auto mod = mlir::dyn_cast< vast_module >(op);
VAST_CHECK(mod, "Pass inside tower was not run on module!");

// Update locations so each operation now has a unique loc that also
// encodes backlink.
transform_locations(li, mod);

mlir::OpBuilder bld(op->getContext());
owning_module_ref persistent = mlir::dyn_cast< vast_module >(bld.clone(*op));

mlir::OpPrintingFlags flags;
persistent->print(llvm::outs(), flags.enableDebugInfo(true, false));

path.emplace_back(pass->getArgument().str());
handles.emplace_back(storage.store(path, std::move(persistent)));
}

unit_link_vector link_vector() {
unit_link_vector out;
for (std::size_t i = 1; i < handles.size(); ++i) {
auto step = std::make_unique< light_one_step_link >(handles[i - 1], handles[i], li);
out.emplace_back(std::move(step));
}
return out;
}

std::unique_ptr< link_interface > extract_link() {
auto unit_links = link_vector();


return std::make_unique< fat_link >(std::move(unit_links));
}
};

handle_t tower::apply(handle_t root, location_info &li, mlir::PassManager &pm) {
auto bld = std::make_unique< link_builder >(li, storage, top());
pm.addInstrumentation(std::move(bld));

std::ignore = pm.run(root.mod);

return root;
}

} // namespace vast::tw
6 changes: 4 additions & 2 deletions tools/vast-repl/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ namespace cmd {

void show_module(state_t &state) {
check_and_emit_module(state);
llvm::outs() << state.current_module() << "\n";
mlir::OpPrintingFlags flags;
state.current_module()->print(llvm::outs(), flags.enableDebugInfo(true, false));
llvm::outs() << "\n";
}

void show_symbols(state_t &state) {
Expand Down Expand Up @@ -156,8 +158,8 @@ namespace cmd {
if (mlir::failed(mlir::parsePassPipeline(pass, pm))) {
throw_error("failed to parse pass pipeline");
}
th = state.tower->apply(th, pm);
}
th = state.tower->apply(th, state.li, pm);
}

//
Expand Down

0 comments on commit abff31e

Please sign in to comment.