updated ScaleDownMultipleLayers
e-ddykim committed Nov 11, 2024
1 parent 5a65569 commit 6491951
Showing 1 changed file with 28 additions and 15 deletions.
@@ -60,9 +60,11 @@ using ov::pass::pattern::op::Or;
 ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float scale_factor) {
     MATCHER_SCOPE(ScaleDownSingleLayer);
 
-    auto is_single_matmul = [](const Output<Node>& output) {
+    auto are_not_mutiple_matmuls_connected = [](const Output<Node>& output) {
         auto matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(output.get_node_shared_ptr());
         auto input = matmul->get_input_node_shared_ptr(0);
+        if (input->get_output_size() > 1)
+            return true;
         size_t user_matmul_count = 0;
         for (const auto& u : input->get_users()) {
             auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u);
@@ -76,7 +78,7 @@ ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float
     auto activation_m = any_input();
     auto weights_m = any_input();
     auto convolution_m = wrap_type<ov::op::v1::Convolution>({activation_m, weights_m});
-    auto matmul_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m}, is_single_matmul);
+    auto matmul_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m}, are_not_mutiple_matmuls_connected);
     auto scaled_op_m = std::make_shared<Or>(OutputVector{convolution_m, matmul_m});
 
     ov::Shape scale_const_shape = {1};
@@ -153,27 +155,35 @@ ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float
     this->register_matcher(m, callback);
 }
 
+//      input                      input
+//     /     \                       |
+//  MatMul_a  MatMul_b   ==>    scale_down
+//                               /       \
+//                          MatMul_a   MatMul_b
+//                              |          |
+//                          scale_up   scale_up
 ov::pass::activations_scaling::ScaleDownMultipleLayers::ScaleDownMultipleLayers(float scale_factor) {
     MATCHER_SCOPE(ScaleDownMultipleLayers);
 
-    auto is_mutiple_matmuls = [](const Output<Node>& output) {
+    auto are_mutiple_matmuls_connected = [](const Output<Node>& output) {
         auto matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(output.get_node_shared_ptr());
         auto input = matmul->get_input_node_shared_ptr(0);
+        if (input->get_output_size() > 1)
+            return false;
         size_t user_matmul_count = 0;
         for (const auto& u : input->get_users()) {
             auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u);
             if (!matmul_user)
                 continue;
             user_matmul_count++;
         }
-        return !ov::is_type<ov::op::v1::Multiply>(input) &&
-               input->get_users().size() > 1 &&
+        return !ov::is_type<ov::op::v1::Multiply>(input) && input->get_users().size() > 1 &&
                input->get_users().size() == user_matmul_count;
     };
 
     auto activation_m = any_input();
     auto weights_m = any_input();
-    auto scaled_op_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m}, is_mutiple_matmuls);
+    auto scaled_op_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m}, are_mutiple_matmuls_connected);
 
     ov::Shape scale_const_shape = {1};
     std::vector<float> scale_down_value = {1.f / scale_factor};
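
Restated outside the matcher for readability: the renamed guard above admits a MatMul only when its activation comes from a single-output node whose users are all MatMuls, with more than one of them. Judging by the names, ScaleDownSingleLayer's are_not_mutiple_matmuls_connected accepts the complementary case, so each MatMul is claimed by at most one of the two passes. A minimal sketch using only calls visible in this diff (illustrative, not the repository's exact code):

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"

// True when `input` belongs to the ScaleDownMultipleLayers case: one output,
// not already a Multiply (i.e. not already scaled), and every user a MatMul.
bool feeds_only_multiple_matmuls(const std::shared_ptr<ov::Node>& input) {
    if (input->get_output_size() > 1)  // the early-out added by this commit
        return false;
    size_t matmul_users = 0;
    for (const auto& user : input->get_users()) {
        if (std::dynamic_pointer_cast<ov::op::v0::MatMul>(user))
            ++matmul_users;
    }
    return !ov::is_type<ov::op::v1::Multiply>(input) && input->get_users().size() > 1 &&
           input->get_users().size() == matmul_users;
}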
@@ -190,18 +200,20 @@ ov::pass::activations_scaling::ScaleDownMultipleLayers::ScaleDownMultipleLayers(
     ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
 
-        auto scaled_op = std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(scaled_op_m).get_node_shared_ptr());
+        auto scaled_op
+            = std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(scaled_op_m).get_node_shared_ptr());
         if (!scaled_op || transformation_callback(scaled_op))
             return false;
 
-        auto input_node = scaled_op->get_input_node_shared_ptr(0);
+        auto input_node = scaled_op->input(0).get_source_output();
         auto scale_down = std::make_shared<ov::op::v1::Multiply>(
-            input_node, (input_node->get_element_type() == ov::element::f32) ? scale_down_const_f32 : scale_down_const_f16);
+            input_node,
+            (input_node.get_element_type() == ov::element::f32) ? scale_down_const_f32 : scale_down_const_f16);
         scale_down->set_friendly_name(scaled_op->get_friendly_name() + "_scale_down");
         ov::copy_runtime_info(scaled_op, scale_down);
 
-        for (const auto& u : input_node->get_users()) {
-            auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u);
+        for (const auto& u : input_node.get_target_inputs()) {
+            auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u.get_node()->shared_from_this());
             if (matmul_user) {
                 matmul_user->input(0).replace_source_output(scale_down);
                 auto child = matmul_user->get_output_target_inputs(0).begin()->get_node();
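
The other change in this hunk is worth calling out: input_node is now an ov::Output<Node> (via input(0).get_source_output()) rather than a shared_ptr<Node>, so consumers are found with get_target_inputs() instead of get_users(). For a multi-output producer, get_users() returns consumers of every output, while get_target_inputs() walks only the tensor actually being scaled, which matches the new get_output_size() guards. A minimal sketch of the Output-level traversal (illustrative only):

#include "openvino/core/node.hpp"

// Visit the consumers of one specific output port, not of the whole node.
void visit_consumers(const ov::Output<ov::Node>& out) {
    for (const auto& target : out.get_target_inputs()) {
        // Each entry is an ov::Input<ov::Node>; the consuming node is
        // recovered the same way the commit does.
        auto consumer = target.get_node()->shared_from_this();
        (void)consumer;  // e.g. dynamic_pointer_cast to the expected op type
    }
}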
@@ -210,14 +222,16 @@ ov::pass::activations_scaling::ScaleDownMultipleLayers::ScaleDownMultipleLayers(
                     auto target_inputs = add->get_output_target_inputs(0);
                     auto scale_down_bias = std::make_shared<ov::op::v1::Multiply>(
                         add->input(1).get_source_output(),
-                        (add->input(1).get_element_type() == ov::element::f32) ? scale_down_const_f32 : scale_down_const_f16);
+                        (add->input(1).get_element_type() == ov::element::f32) ? scale_down_const_f32
+                                                                               : scale_down_const_f16);
                     scale_down_bias->set_friendly_name(add->get_friendly_name() + "_scale_down");
                     ov::copy_runtime_info(add, scale_down_bias);
                     add->input(1).replace_source_output(scale_down_bias->output(0));
 
                     auto scale_up = register_new_node<ov::op::v1::Multiply>(
                         add->output(0),
-                        (add->output(0).get_element_type() == ov::element::f32) ? scale_up_const_f32 : scale_up_const_f16);
+                        (add->output(0).get_element_type() == ov::element::f32) ? scale_up_const_f32
+                                                                                : scale_up_const_f16);
                     scale_up->set_friendly_name(matmul_user->get_friendly_name() + "_scale_up");
                     ov::copy_runtime_info(matmul_user, scale_up);
                     for (auto& in : target_inputs) {
@@ -228,7 +242,7 @@ ov::pass::activations_scaling::ScaleDownMultipleLayers::ScaleDownMultipleLayers(
                 auto scale_up = register_new_node<ov::op::v1::Multiply>(
                     matmul_user->output(0),
                     (matmul_user->output(0).get_element_type() == ov::element::f32) ? scale_up_const_f32
-                                                                                    : scale_up_const_f16);
+                                                                                     : scale_up_const_f16);
                 scale_up->set_friendly_name(matmul_user->get_friendly_name() + "_scale_up");
                 ov::copy_runtime_info(matmul_user, scale_up);
                 for (auto& in : target_inputs) {
@@ -244,7 +258,6 @@ ov::pass::activations_scaling::ScaleDownMultipleLayers::ScaleDownMultipleLayers(
     this->register_matcher(m, callback);
 }
 
-
 // MulMulAddTransformation makes the target pattern easier to merge with following nodes.
 //
 // input_a  const_a  input_b  const_b          input_a  (const_a/const_b)
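
For context, the updated pass is registered like any other matcher pass. A minimal, hypothetical usage sketch follows; the header path and the scale factor are assumptions for illustration, not taken from this commit:

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
// Assumed header location for the activations-scaling passes:
#include "transformations/common_optimizations/activations_scaling.hpp"

// Wrap groups of MatMuls that read one single-output tensor with a shared
// scale_down and per-MatMul scale_up, as diagrammed above.
void run_scale_down_multiple_layers(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    const float scale_factor = 8.0f;  // assumed value; the pass scales by 1/scale_factor
    manager.register_pass<ov::pass::activations_scaling::ScaleDownMultipleLayers>(scale_factor);
    manager.run_passes(model);
}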