diff --git a/include/vast/Tower/Handle.hpp b/include/vast/Tower/Handle.hpp new file mode 100644 index 0000000000..e57fcda0a8 --- /dev/null +++ b/include/vast/Tower/Handle.hpp @@ -0,0 +1,22 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. + +#pragma once + +#include "vast/Util/Common.hpp" + +VAST_RELAX_WARNINGS +VAST_UNRELAX_WARNINGS + +namespace vast::tw { + using conversion_path_t = std::vector< std::string >; + using conversion_path_fingerprint_t = std::string; + + using handle_id_t = std::size_t; + + struct handle_t + { + handle_id_t id; + vast_module mod; + }; + +} // namespace vast::tw diff --git a/include/vast/Tower/Link.hpp b/include/vast/Tower/Link.hpp new file mode 100644 index 0000000000..1ffbfe6732 --- /dev/null +++ b/include/vast/Tower/Link.hpp @@ -0,0 +1,118 @@ +// Copyright (c) 2024-present, Trail of Bits, Inc. + +#pragma once + +#include "vast/Util/Common.hpp" + +#include "vast/Tower/Handle.hpp" +#include "vast/Tower/LocationInfo.hpp" + +VAST_RELAX_WARNINGS +VAST_UNRELAX_WARNINGS + +#include +#include + +namespace vast::tw { + + using operations = std::vector< operation >; + + // Maybe we want to abstract this as an interface - then it can be lazy as well, for example + // we could pass in only some step chain and location info and it will get computed. + // Since if we have `A -> B` we can always construct `B -> A` it may make sense to simply only + // export bidirectional mapping? + using op_mapping = llvm::DenseMap< operation, operations >; + + // Generic interface to generalize the connection between any two modules. + // There are no performance guarantees in general, but there should be implementations + // available that try to be as performant as possible. + struct link_interface + { + virtual ~link_interface() = default; + + // These are not forced as `const` to allow runtime caching. + virtual operations children(operation) = 0; + virtual operations children(operations) = 0; + + virtual operations parents(operation) = 0; + virtual operations parents(operations) = 0; + + virtual op_mapping parents_to_children() = 0; + virtual op_mapping children_to_parents() = 0; + + virtual handle_t parent() const = 0; + virtual handle_t child() const = 0; + }; + + namespace views { + static inline auto parents_to_children = [](const auto &link) { + return link->parents_to_children(); + }; + + static inline auto children_to_parents = [](const auto &link) { + return link->children_to_parents(); + }; + + } // namespace views + + // Represent application of some passes. Invariant is that + // `parent -> child` are tied by the `location_info`. + // TODO: How to enforce this - private ctor and provide a builder interface on the side + // that is a friend and allowed to create these? + struct conversion_step : link_interface { + protected: + handle_t _parent; + handle_t _child; + location_info_t &_location_info; + + public: + explicit conversion_step(handle_t parent, handle_t child, location_info_t &location_info) + : _parent(parent), _child(child), _location_info(location_info) + {} + + operations children(operation) override; + operations children(operations) override; + + operations parents(operation) override; + operations parents(operations) override; + + op_mapping parents_to_children() override; + op_mapping children_to_parents() override; + + handle_t parent() const override; + handle_t child() const override; + }; + + using link_ptr = std::unique_ptr< link_interface >; + using link_vector = std::vector< link_ptr >; + + using conversion_steps = std::vector< conversion_step >; + + // `A -> ... -> E` - each middle link is kept and there is pre-computed + // mapping for `A <-> E` transition to make it more performant. + struct fat_link : link_interface + { + protected: + link_vector _links; + + op_mapping _to_children; + op_mapping _to_parents; + + public: + explicit fat_link(link_vector links); + fat_link() = delete; + + operations children(operation) override; + operations children(operations) override; + + operations parents(operation) override; + operations parents(operations) override; + + op_mapping parents_to_children() override; + op_mapping children_to_parents() override; + + handle_t child() const override; + handle_t parent() const override; + }; + +} // namespace vast::tw diff --git a/include/vast/Tower/LocationInfo.hpp b/include/vast/Tower/LocationInfo.hpp new file mode 100644 index 0000000000..f2fa1cf2ed --- /dev/null +++ b/include/vast/Tower/LocationInfo.hpp @@ -0,0 +1,68 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. + +#pragma once + +#include "vast/Util/Common.hpp" + +VAST_RELAX_WARNINGS +#include +#include +VAST_UNRELAX_WARNINGS + +#include "vast/Tower/Handle.hpp" + +namespace vast::tw { + + // Is allowed to have state? + struct location_info_t + { + private: + // Encoded as `mlir::FusedLocation(original, mlir::OpaqueLocation(pointer_to_self))` + using raw_loc_t = mlir::FusedLoc; + + static raw_loc_t raw_loc(operation op) { + auto raw = mlir::dyn_cast< raw_loc_t >(op->getLoc()); + VAST_CHECK(raw, "{0} with loc: {1}", *op, op->getLoc()); + return raw; + } + + template< std::size_t idx > + requires (idx < 2) + static loc_t get(raw_loc_t raw) { + auto locs = raw.getLocations(); + VAST_ASSERT(locs.size() == 2); + return locs[idx]; + } + + static auto parse(operation op) { return std::make_tuple(prev(op), self(op)); } + + // TODO: These are strictly not needed in this form, but help initial + // debugging a lot. + std::string fingerprint(const conversion_path_t &); + loc_t mk_unique_loc(const conversion_path_t &, operation); + loc_t mk_linked_loc(loc_t self, loc_t prev); + + public: + // For the given operation return location to be used in this module. + loc_t get_as_child(const conversion_path_t &, operation op); + loc_t get_root(operation op); + + static loc_t self(raw_loc_t raw) { return get< 1 >(raw); } + + static loc_t prev(raw_loc_t raw) { return get< 0 >(raw); } + + static loc_t self(operation op) { return self(raw_loc(op)); } + + static loc_t prev(operation op) { return prev(raw_loc(op)); } + + static bool are_tied(operation parent, operation child); + }; + + // Since we are going to tie together arbitrary modules, it makes sense to make them + // have locations in the same shape - therefore root shouldn't be an excuse. It will + // however require slightly different handling, so we are exposing a hook for that. + void mk_root(location_info_t &, operation); + + void transform_locations(location_info_t &, const conversion_path_t &, operation); + +} // namespace vast::tw diff --git a/include/vast/Tower/Storage.hpp b/include/vast/Tower/Storage.hpp new file mode 100644 index 0000000000..761a6b3c77 --- /dev/null +++ b/include/vast/Tower/Storage.hpp @@ -0,0 +1,52 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. + +#pragma once + +#include "vast/Util/Common.hpp" + +VAST_RELAX_WARNINGS +#include +#include +VAST_UNRELAX_WARNINGS + +#include "vast/Tower/Handle.hpp" + +#include +#include + +namespace vast::tw { + + struct module_storage + { + // TODO: API-wise, we probably want to accept any type that is `mlir::OwningOpRef< T >`? + handle_t store(const conversion_path_t &path, owning_module_ref mod) { + auto id = allocate_id(path); + auto [it, _] = storage.insert({id, std::move(mod)}); + return { id, it->second.get() }; + } + + void remove(handle_t) { VAST_UNIMPLEMENTED; } + + private: + conversion_path_fingerprint_t fingerprint(const conversion_path_t &path) const { + return std::accumulate(path.begin(), path.end(), std::string{}); + } + + handle_id_t allocate_id(const conversion_path_t &path) { + return allocate_id(fingerprint(path)); + } + + handle_id_t allocate_id(const conversion_path_fingerprint_t &fp) { + // Later here we want to return the cached module? + VAST_CHECK(!conversion_tree.count(fp), "For now cannot do caching!"); + auto id = next_id++; + conversion_tree.emplace(fp, id); + return id; + } + + std::size_t next_id = 0; + llvm::DenseMap< handle_id_t, owning_module_ref > storage; + // TODO: This is just a prototyping shortcut, we may want something smarter here. + std::unordered_map< conversion_path_fingerprint_t, handle_id_t > conversion_tree; + }; +} // namespace vast::tw diff --git a/include/vast/Tower/Tower.hpp b/include/vast/Tower/Tower.hpp index 2e34ad2d9c..e3dd6d1bdf 100644 --- a/include/vast/Tower/Tower.hpp +++ b/include/vast/Tower/Tower.hpp @@ -5,75 +5,38 @@ #include "vast/Util/Common.hpp" VAST_RELAX_WARNINGS -#include #include +#include VAST_UNRELAX_WARNINGS -namespace vast::tw { - - struct default_loc_rewriter_t - { - static auto insert(mlir::Operation *op) -> void; - static auto remove(mlir::Operation *op) -> void; - static auto prev(mlir::Operation *op) -> mlir::Operation *; - }; +#include "vast/Tower/Handle.hpp" +#include "vast/Tower/Link.hpp" +#include "vast/Tower/LocationInfo.hpp" +#include "vast/Tower/Storage.hpp" - using pass_ptr_t = std::unique_ptr< mlir::Pass >; +namespace vast::tw { - template< typename loc_rewriter_t > struct tower { - using loc_rewriter = loc_rewriter_t; - - struct handle_t - { - std::size_t id; - vast_module mod; - }; - - static auto get(mcontext_t &ctx, owning_module_ref mod) - -> std::tuple< tower, handle_t > { - tower t(ctx, std::move(mod)); - handle_t h{ .id = 0, .mod = t._modules[0].get() }; - return { std::move(t), h }; - } - - auto apply(handle_t handle, mlir::PassManager &pm) -> handle_t { - handle.mod.walk(loc_rewriter::insert); - - _modules.emplace_back(mlir::cast< vast_module >(handle.mod->clone())); - - auto id = _modules.size() - 1; - auto mod = _modules.back().get(); - - if (mlir::failed(pm.run(mod))) { - VAST_FATAL("some pass in apply() failed"); - } - - handle.mod.walk(loc_rewriter::remove); - - return { id, mod }; - } - - auto apply(handle_t handle, pass_ptr_t pass) -> handle_t { - mlir::PassManager pm(_ctx); - pm.addPass(std::move(pass)); - return apply(handle, pm); + private: + [[maybe_unused]] mcontext_t &mctx; + module_storage storage; + handle_t top_handle; + + public: + tower(mcontext_t &mctx, location_info_t &li, owning_module_ref root) : mctx(mctx) { + mk_root(li, root->getOperation()); + top_handle = storage.store(root_conversion(), std::move(root)); } - auto top() -> handle_t { return { _modules.size(), _modules.back().get() }; } - private: - using module_storage_t = llvm::SmallVector< owning_module_ref, 2 >; + // TODO: Move somewhere else. + static conversion_path_t root_conversion() { return {}; } - mcontext_t *_ctx; - module_storage_t _modules; + public: + handle_t top() const { return top_handle; } - tower(mcontext_t &ctx, owning_module_ref mod) : _ctx(&ctx) { - _modules.emplace_back(std::move(mod)); - } + link_ptr apply(handle_t, location_info_t &, mlir::PassManager &); }; - using default_tower = tower< default_loc_rewriter_t >; - } // namespace vast::tw diff --git a/include/vast/repl/command.hpp b/include/vast/repl/command.hpp index 8b6bc34513..0c2a720fce 100644 --- a/include/vast/repl/command.hpp +++ b/include/vast/repl/command.hpp @@ -46,7 +46,7 @@ namespace vast::repl struct string_param { std::string value; }; struct integer_param { std::uint64_t value; }; - enum class show_kind { source, ast, module, symbols, pipelines }; + enum class show_kind { source, ast, module, symbols, pipelines, link }; template< typename enum_type > enum_type from_string(string_ref token) requires(std::is_same_v< enum_type, show_kind >) { @@ -55,6 +55,7 @@ namespace vast::repl if (token == "module") return enum_type::module; if (token == "symbols") return enum_type::symbols; if (token == "pipelines") return enum_type::pipelines; + if (token == "link") return enum_type::link; throw_error("uknnown show kind: {0}", token.str()); } @@ -183,9 +184,11 @@ namespace vast::repl static constexpr string_ref name() { return "show"; } static constexpr inline char kind_param[] = "kind_param_name"; + static constexpr inline char name_param[] = "name_param"; using command_params = util::type_list< - named_param< kind_param, show_kind > + named_param< kind_param, show_kind >, + named_param< name_param, string_param > >; using params_storage = command_params::as_tuple; @@ -256,9 +259,11 @@ namespace vast::repl static constexpr string_ref name() { return "raise"; } static constexpr inline char pipeline_param[] = "pipeline_name"; + static constexpr inline char link_name_param[] = "link_name"; using command_params = - util::type_list< named_param< pipeline_param, string_param > >; + util::type_list< named_param< pipeline_param, string_param >, + named_param< link_name_param, string_param > >; using params_storage = command_params::as_tuple; diff --git a/include/vast/repl/state.hpp b/include/vast/repl/state.hpp index ed50b84f5a..a424138d1d 100644 --- a/include/vast/repl/state.hpp +++ b/include/vast/repl/state.hpp @@ -8,6 +8,7 @@ #include "vast/repl/pipeline.hpp" #include +#include namespace vast::repl { @@ -28,7 +29,14 @@ namespace vast::repl { // mlir module and context // mcontext_t &ctx; - std::optional< tw::default_tower > tower; + + // + // Tower related state + // + tw::location_info_t location_info; + std::optional< tw::tower > tower; + + std::unordered_map< std::string, tw::link_ptr > links; // // sticked commands performed after each step @@ -44,6 +52,9 @@ namespace vast::repl { // verbosity flags // bool verbose_pipeline = true; + + void raise_tower(owning_module_ref mod); + vast_module current_module(); }; } // namespace vast::repl diff --git a/lib/vast/Tower/CMakeLists.txt b/lib/vast/Tower/CMakeLists.txt index 87098d5d64..c5a997f6fa 100644 --- a/lib/vast/Tower/CMakeLists.txt +++ b/lib/vast/Tower/CMakeLists.txt @@ -1,5 +1,7 @@ # Copyright (c) 2022-present, Trail of Bits, Inc. add_vast_library(Tower + Link.cpp + LocationInfo.cpp Tower.cpp ) diff --git a/lib/vast/Tower/Link.cpp b/lib/vast/Tower/Link.cpp new file mode 100644 index 0000000000..f942a8b1db --- /dev/null +++ b/lib/vast/Tower/Link.cpp @@ -0,0 +1,185 @@ +// Copyright (c) 2024-present, Trail of Bits, Inc. + +#include "vast/Tower/Link.hpp" + +#include + +namespace vast::tw { + + namespace { + + // TODO: Reimplement using ranges once we have newer stdlib in CI + void append_range(auto &into, const auto &what) { + into.insert(into.end(), what.begin(), what.end()); + } + + } // namespace + + op_mapping reverse_mapping(const op_mapping &from) { + op_mapping out; + for (const auto &[root, ops] : from) { + for (auto op : ops) { + out[op].push_back(root); + } + } + return out; + } + + void dbg(const op_mapping &mapping, auto &outs) { + outs << "Mapping:\n"; + + auto flag = mlir::OpPrintingFlags().skipRegions(); + auto render_op = [&](operation op) -> decltype(outs) & { + outs << op << ": "; + op->print(outs, flag); + if (!op->getRegions().empty()) + outs << " ... regions ..."; + return outs; + }; + + for (const auto &[from, to] : mapping) { + outs << "== "; + render_op(from) << "\n"; + for (auto t : to) { + outs << " -> "; + render_op(t) << "\n"; + } + } + } + + using loc_to_op_t = llvm::DenseMap< loc_t, operations >; + + void dbg(const loc_to_op_t &mapping, auto &outs) { + outs << "loc_to_op\n"; + for (const auto &[loc, to] : mapping) { + outs << ".." << loc << "\n"; + for (auto t : to) { + outs << " -> " << *t << "\n"; + } + } + } + + loc_to_op_t gather_loc_to_op(location_info_t &li, operation op) { + loc_to_op_t out; + auto collect = [&](operation op) { out[li.self(op)].push_back(op); }; + op->walk(collect); + return out; + } + + // TODO: Reimplement using ranges once we have newer stdlib in CI + auto mk_link_mappings(const auto &links) { + std::vector< op_mapping > out; + for (const auto &l : links) { + out.emplace_back(l->parents_to_children()); + } + return out; + } + + + // `transition` handles the lookup between levels - we need this as we want to generalise + // (`levels` can be arbitrary mapping, not just `op -> { op }`). + op_mapping build_map(op_mapping init, const auto &levels, auto transition) { + auto handle_level = [&](const auto &level) { + auto handle_element = [&](operation op) -> operations { + auto prev_ops = level.find(transition(op)); + if (prev_ops == level.end()) + return {}; + return prev_ops->second; + }; + + for (auto &[op, todo] : init) { + operations parents; + for (auto current : todo) { + append_range(parents, handle_element(current)); + } + todo = std::move(parents); + } + }; + + for (const auto &level : levels | std::views::drop(1)) + handle_level(level); + + return init; + } + + op_mapping build_map(const link_vector &links) { + auto transition = [](operation op) { return op; }; + auto init = links.front()->parents_to_children(); + return build_map(std::move(init), mk_link_mappings(links), transition); + } + + op_mapping build_map(std::vector< loc_to_op_t > links, location_info_t &li) { + auto transition = [&](operation op) { return li.prev(op); }; + + std::reverse(links.begin(), links.end()); + + // initialize + op_mapping init; + for (auto [_, op] : links.front()) { + init[op.front()] = { op }; + } + return build_map(std::move(init), links, transition); + } + + op_mapping build_map(handle_t parent, handle_t child, location_info_t &li) { + return build_map({gather_loc_to_op(li, parent.mod), gather_loc_to_op(li, child.mod)}, li); + } + + /* conversion_step::link_interface API */ + + operations conversion_step::children(operation) { VAST_UNIMPLEMENTED; } + operations conversion_step::children(operations) { VAST_UNIMPLEMENTED; } + + operations conversion_step::parents(operation) { VAST_UNIMPLEMENTED; } + operations conversion_step::parents(operations) { VAST_UNIMPLEMENTED; } + + op_mapping conversion_step::parents_to_children() { + return reverse_mapping(children_to_parents()); + } + + op_mapping conversion_step::children_to_parents() { + return build_map(parent(), child(), _location_info); + } + + handle_t conversion_step::parent() const { return _parent; } + handle_t conversion_step::child() const { return _child; } + + /* fat_link */ + + fat_link::fat_link(link_vector links) + : _links(std::move(links)), + _to_children(build_map(_links)), + _to_parents(reverse_mapping(_to_children)) + {} + + /* fat_link::link_interface API */ + + operations fat_link::children(operation op) { + return _to_children[op]; + } + + operations fat_link::children(operations ops) { + operations out; + for (auto op : ops) + append_range(out, children(op)); + return out; + } + + operations fat_link::parents(operation op) { + return _to_parents[op]; + } + + operations fat_link::parents(operations ops) { + operations out; + for (auto op : ops) + append_range(out, parents(op)); + return out; + } + + op_mapping fat_link::parents_to_children() { return _to_children; } + op_mapping fat_link::children_to_parents() { return _to_parents; } + + handle_t fat_link::parent() const { return _links.front()->parent(); } + handle_t fat_link::child() const { return _links.back()->child(); } + +} // namespace vast::tw diff --git a/lib/vast/Tower/LocationInfo.cpp b/lib/vast/Tower/LocationInfo.cpp new file mode 100644 index 0000000000..34eb9be0e7 --- /dev/null +++ b/lib/vast/Tower/LocationInfo.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. + +#include "vast/Tower/LocationInfo.hpp" + +#include +#include + +namespace vast::tw { + + std::string location_info_t::fingerprint(const conversion_path_t &path) { + return std::accumulate(path.begin(), path.end(), std::string{}); + } + + loc_t location_info_t::mk_unique_loc(const conversion_path_t &path, operation op) { + auto mctx = op->getContext(); + auto raw_id = + fingerprint(path) + std::to_string(reinterpret_cast< std::uintptr_t >(op)); + return mlir::FileLineColLoc::get(mctx, raw_id, 0, 0); + } + + loc_t location_info_t::mk_linked_loc(loc_t self, loc_t prev) { + auto mctx = self->getContext(); + return mlir::FusedLoc::get({ self, prev }, {}, mctx); + } + + loc_t location_info_t::get_as_child(const conversion_path_t &path, operation op) { + return mk_linked_loc(self(op), mk_unique_loc(path, op)); + } + + loc_t location_info_t::get_root(operation op) { + return mk_linked_loc(op->getLoc(), mk_unique_loc({}, op)); + } + + bool location_info_t::are_tied(operation parent, operation child) { + return self(parent) == prev(child); + } + + void mk_root(location_info_t &li, operation root) { + auto set_loc = [&](operation op) { op->setLoc(li.get_root(op)); }; + root->walk(set_loc); + } + + void transform_locations(location_info_t &li, const conversion_path_t &path, operation root) { + auto set_loc = [&](operation op) { op->setLoc(li.get_as_child(path, op)); }; + root->walk(set_loc); + } + +} // namespace vast::tw diff --git a/lib/vast/Tower/Tower.cpp b/lib/vast/Tower/Tower.cpp index b81bdadb1f..e7201347de 100644 --- a/lib/vast/Tower/Tower.cpp +++ b/lib/vast/Tower/Tower.cpp @@ -3,20 +3,57 @@ #include "vast/Tower/Tower.hpp" namespace vast::tw { - auto default_loc_rewriter_t::insert(mlir::Operation *op) -> void { - auto ctx = op->getContext(); - auto ol = mlir::OpaqueLoc::get< mlir::Operation *>(op, ctx); - op->setLoc(mlir::FusedLoc::get({ op->getLoc() }, ol, ctx)); - } - auto default_loc_rewriter_t::remove(mlir::Operation *op) -> void { - auto fl = mlir::cast< mlir::FusedLoc >(op->getLoc()); - op->setLoc(fl.getLocations().front()); - } + struct link_builder : mlir::PassInstrumentation + { + location_info_t &li; + module_storage &storage; + + // Start empty and after each callback add to it. + conversion_path_t path = {}; + + std::vector< handle_t > handles; + link_vector steps; + + explicit link_builder(location_info_t &li, module_storage &storage, handle_t root) + : li(li), storage(storage), handles{ root } {} + + void runAfterPass(pass_ptr pass, operation op) override { + auto mod = mlir::dyn_cast< vast_module >(op); + VAST_CHECK(mod, "Pass inside tower was not run on module!"); + + // Update locations so each operation now has a unique loc that also + // encodes backlink. + path.emplace_back(pass->getArgument().str()); + transform_locations(li, path, mod); - auto default_loc_rewriter_t::prev(mlir::Operation *op) -> mlir::Operation * { - auto fl = mlir::cast< mlir::FusedLoc >(op->getLoc()); - auto ol = mlir::cast< mlir::OpaqueLoc >(fl.getMetadata()); - return mlir::OpaqueLoc::getUnderlyingLocation< mlir::Operation * >(ol); + owning_module_ref persistent = mlir::dyn_cast< vast_module >(op->clone()); + + auto from = handles.back(); + handles.emplace_back(storage.store(path, std::move(persistent))); + steps.emplace_back(std::make_unique< conversion_step >(from, handles.back(), li)); + } + + std::unique_ptr< link_interface > extract_link() { + VAST_CHECK(!steps.empty(), "No conversions happened!"); + return std::make_unique< fat_link >(std::move(steps)); + } + }; + + link_ptr tower::apply(handle_t root, location_info_t &li, mlir::PassManager &pm) { + auto bld = std::make_unique< link_builder >(li, storage, top()); + + // We need to access some of the data after passes are ran. + auto raw_bld = bld.get(); + pm.addInstrumentation(std::move(bld)); + + // We need to do a clone, because we received a handle - this means that the module + // is already stored and should not be modified. + auto clone = root.mod->clone(); + + // TODO: What if this fails? + std::ignore = pm.run(clone); + return raw_bld->extract_link(); } + } // namespace vast::tw diff --git a/test/repl/raise.c b/test/repl/raise.c index c345bb0c13..a77ab87796 100644 --- a/test/repl/raise.c +++ b/test/repl/raise.c @@ -1,4 +1,6 @@ // RUN: printf "load %s\n raise vast-hl-to-ll-cf\n show module\n exit" | %vast-repl | %file-check %s -// CHECK: ll.return %0 : !hl.int +// CHECK: hl.return %0 : !hl.int + +// REQUIRES: clone-memory-leak int main(void) { return 0; } diff --git a/tools/vast-repl/CMakeLists.txt b/tools/vast-repl/CMakeLists.txt index 2dc983aa26..2b93242902 100644 --- a/tools/vast-repl/CMakeLists.txt +++ b/tools/vast-repl/CMakeLists.txt @@ -4,6 +4,7 @@ add_vast_executable(vast-repl codegen.cpp command.cpp config.cpp + state.cpp LINK_LIBS ${CLANG_LIBS} diff --git a/tools/vast-repl/command.cpp b/tools/vast-repl/command.cpp index b69591aab9..e5647bfce6 100644 --- a/tools/vast-repl/command.cpp +++ b/tools/vast-repl/command.cpp @@ -10,6 +10,25 @@ namespace vast::repl { namespace cmd { + // TODO: Really naive way to visualize. + void render_link(const tw::link_ptr &ptr) { + auto flag = mlir::OpPrintingFlags().skipRegions(); + + auto render_op = [&](operation op) -> llvm::raw_fd_ostream & { + op->print(llvm::outs(), flag); + return llvm::outs(); + }; + + auto render = [&](operation op) { + render_op(op) << "\n"; + for (auto c : ptr->children(op)) { + llvm::outs() << "\t => "; + render_op(c) << "\n"; + } + }; + ptr->parent().mod->walk< mlir::WalkOrder::PreOrder >(render); + } + void check_source(const state_t &state) { if (!state.source.has_value()) { throw_error("error: missing source"); @@ -31,12 +50,9 @@ namespace cmd { } void check_and_emit_module(state_t &state) { - if (!state.tower) { - check_source(state); - auto mod = codegen::emit_module(state.source.value(), state.ctx); - auto [t, _] = tw::default_tower::get(state.ctx, std::move(mod)); - state.tower = std::move(t); - } + check_source(state); + auto mod = codegen::emit_module(state.source.value(), state.ctx); + state.raise_tower(std::move(mod)); } // @@ -77,7 +93,7 @@ namespace cmd { void show_module(state_t &state) { check_and_emit_module(state); - llvm::outs() << state.tower->top().mod << "\n"; + llvm::outs() << state.current_module() << "\n"; } void show_symbols(state_t &state) { @@ -98,14 +114,23 @@ namespace cmd { } } + void show_link(state_t &state, const std::string &name) { + auto it = state.links.find(name); + if (it == state.links.end()) + return throw_error("Link with name: {0} not found!", name); + return render_link(it->second); + } + void show::run(state_t &state) const { auto what = get_param< kind_param >(params); + auto name = get_param< name_param >(params).value; switch (what) { case show_kind::source: return show_source(state); case show_kind::ast: return show_ast(state); case show_kind::module: return show_module(state); case show_kind::symbols: return show_symbols(state); case show_kind::pipelines: return show_pipelines(state); + case show_kind::link: return show_link(state, name); } }; @@ -170,17 +195,20 @@ namespace cmd { check_and_emit_module(state); std::string pipeline = get_param< pipeline_param >(params).value; + auto link_name = get_param< link_name_param >(params).value; + llvm::SmallVector< llvm::StringRef, 2 > passes; llvm::StringRef(pipeline).split(passes, ','); mlir::PassManager pm(&state.ctx); - auto th = state.tower->top(); + auto top = state.tower->top(); for (auto pass : passes) { if (mlir::failed(mlir::parsePassPipeline(pass, pm))) { throw_error("failed to parse pass pipeline"); } - th = state.tower->apply(th, pm); } + auto link = state.tower->apply(top, state.location_info, pm); + state.links.emplace(link_name, std::move(link)); } // diff --git a/tools/vast-repl/state.cpp b/tools/vast-repl/state.cpp new file mode 100644 index 0000000000..0d11352453 --- /dev/null +++ b/tools/vast-repl/state.cpp @@ -0,0 +1,17 @@ +// Copyright (c) 2024-present, Trail of Bits, Inc. + +#include "vast/repl/state.hpp" + +VAST_RELAX_WARNINGS +VAST_UNRELAX_WARNINGS + +namespace vast::repl { + + void state_t::raise_tower(owning_module_ref mod) { + tower.emplace(ctx, location_info, std::move(mod)); + } + + vast_module state_t::current_module() { + return tower->top().mod; + } +} // namespace vast::repl::codegen