Skip to content

Commit

Permalink
Review comments 5
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Jan 6, 2025
1 parent db23c6d commit d5c02c7
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ namespace snippets {
namespace lowered {
namespace pass {

/**
* @brief Default function to enable RegSpill insertion
* @return True if RegSpill is required around a certain op, False otherwise.
*/
inline bool brgemm_reg_spill(const ExpressionPtr& expr) {
return ov::is_type<snippets::op::Brgemm>(expr->get_node());
}

/**
* @interface InsertRegSpills
* @brief Insert RegSpill and RegRestore operations for binary call emitters to comply with ABI conventions.
Expand All @@ -20,11 +28,14 @@ namespace pass {
class InsertRegSpills : public Pass {
public:
OPENVINO_RTTI("InsertRegSpills", "0", Pass)
explicit InsertRegSpills(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
explicit InsertRegSpills(RegManager& reg_manager,
std::function<bool(const ExpressionPtr&)> needs_reg_spill = brgemm_reg_spill) :
m_reg_manager(reg_manager), m_needs_reg_spill(std::move(needs_reg_spill)) {}
bool run(LinearIR& linear_ir) override;

private:
RegManager& m_reg_manager;
std::function<bool(const ExpressionPtr&)> m_needs_reg_spill;
};

} // namespace pass
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/lowered/reg_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class RegManager {
return res;
}

inline const LiveInterval& get_live_range(const Reg& reg) {
inline const LiveInterval& get_live_range(const Reg& reg) const {
OPENVINO_ASSERT(m_reg_live_range.count(reg), "Live range for this reg was not set");
return m_reg_live_range[reg];
return m_reg_live_range.at(reg);
}
inline const std::map<Reg, LiveInterval>& get_live_range_map() const {
return m_reg_live_range;
Expand Down
1 change: 0 additions & 1 deletion src/common/snippets/include/snippets/op/reg_spill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class RegSpillBase : public ov::op::Op {
RegSpillBase() = default;
virtual const std::set<Reg>& get_regs_to_spill() const = 0;
bool visit_attributes(AttributeVisitor& visitor) override;
protected:
};
class RegSpillEnd;
/**
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ AssignRegisters::RegMap AssignRegisters::assign_regs_manually(const LinearIR& li
OPENVINO_ASSERT(gpr_pool.size() > static_cast<size_t>(max_buffer_group),
"Not enough gp registers in the pool to perform manual assignment");
const auto& assigned = *std::next(gpr_pool.begin(), reg_group);
const auto& out_reg = expr->get_output_port_descriptor(0)->get_reg();
const auto& out_reg = buffer->get_output_port_descriptor(0)->get_reg();
manually_assigned[out_reg] = assigned;
// Buffer abstract registers validation:
bool all_equal = true;
Expand Down
5 changes: 1 addition & 4 deletions src/common/snippets/src/lowered/pass/insert_reg_spills.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ namespace pass {
bool InsertRegSpills::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::InsertRegSpills")

auto needs_reg_spill = [](const ExpressionPtr& expr) {
return ov::is_type<snippets::op::Brgemm>(expr->get_node());
};
bool modified = false;
for (auto it = linear_ir.begin(); it != linear_ir.end(); it++) {
const auto& expr = *it;
if (!needs_reg_spill(expr))
if (!m_needs_reg_spill(expr))
continue;
auto start_it = std::prev(it);
auto stop_it = std::next(it);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jit_reg_spill_begin_emitter::jit_reg_spill_begin_emitter(dnnl::impl::cpu::x64::j
: jit_emitter(h, isa) {
const auto& reg_spill_node = ov::as_type_ptr<snippets::op::RegSpillBegin>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(reg_spill_node, "expects RegSpillBegin expression");
const auto& rinfo = expr->get_reg_info();
const auto& rinfo = expr->get_reg_info();
m_regs_to_spill = std::set<snippets::Reg>(rinfo.second.begin(), rinfo.second.end());
m_abi_reg_spiller = std::make_shared<EmitABIRegSpills>(h);
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
Expand Down

0 comments on commit d5c02c7

Please sign in to comment.