Skip to content

Commit

Permalink
add layout update of reshape node in handle reshape pass
Browse files Browse the repository at this point in the history
  • Loading branch information
wilson-seok committed Dec 21, 2023
1 parent f6cd215 commit bc1c2a6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,19 @@ void handle_reshape::run(program& p) {
auto reshape_users = node->get_users();
for (const auto& user : reshape_users) {
auto reshape_output = std::make_shared<reorder>("reorder:_reshape_output_" + node->id(),
user->id(),
node->id(),
reshape_layout.format,
reshape_layout.data_type);
GPU_DEBUG_LOG << "reshape_handler: " << reshape_output->id
<< " input_info : " << reshape_output->dependencies().front().to_string() << std::endl;
auto& reshape_output_node = p.get_or_create(reshape_output);
p.add_intermediate(reshape_output_node,
*user,
*node,
reshape_output_node.get_dependencies().empty());
reshape_output_node.recalc_output_layout();
}
node->recalc_output_layout();
}
}
}
Expand Down
49 changes: 49 additions & 0 deletions src/plugins/intel_gpu/tests/unit/passes/handle_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,52 @@ TEST(handle_reshape, reshape_input_reorder) {
ASSERT_EQ(expected, actual) << " i = " << i;
}
}

// Verifies that the handle_reshape pass updates the layout of a reshape node
// that is optimized out: after program build, both the reshape's input and
// output layouts are expected to be plain bfyx (the reorder inserted before
// the reshape converts the tensor to the default format).
// Cleanup vs. original: removed commented-out eltwise nodes and the unused
// const1/const2 allocations that only fed them.
TEST(handle_reshape, reshape_opt_out_layout_update) {
    tests::random_generator rg(GET_SUITE_NAME);
    auto& engine = get_test_engine();
    // Scenario targets onednn-enabled devices; skip when immad is unsupported.
    if (!engine.get_device_info().supports_immad)
        return;

    auto input = engine.allocate_memory({ data_types::f16, format::b_fs_yx_fsv16, { 1, 512, 30, 4 } });
    auto weights1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 512, 512, 3, 3 } });
    auto weights2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 512, 512, 2, 2 } });

    // conv -> relu -> pool -> conv -> relu -> reshape -> permute -> reorder:
    // the reshape between blocked-format producers/consumers is the node under test.
    topology topology;
    topology.add(data("weights1", weights1));
    topology.add(data("weights2", weights2));
    topology.add(input_layout("input", input->get_layout()));
    topology.add(convolution("conv1", input_info("input"), "weights1", "", 1, {1, 1}, {1, 1}, {0, 0}, {0, 0}, false, ov::op::PadType::SAME_UPPER));
    topology.add(activation("relu1", input_info("conv1"), activation_func::relu));
    topology.add(pooling("pool", input_info("relu1"), pooling_mode::max, { 2, 1 }, { 2, 1 }));
    topology.add(convolution("conv2", input_info("pool"), "weights2", "", 1, {2, 1}, {1, 1}, {0, 0}, {0, 0}, false, ov::op::PadType::SAME_UPPER));
    topology.add(activation("relu2", input_info("conv2"), activation_func::relu));
    topology.add(reshape("reshape", input_info("relu2"), false, {1,512,30}, {1,512,30}));
    topology.add(permute("permute", input_info("reshape"), { 2, 0, 1 }));
    topology.add(reorder("reorder", input_info("permute"), format::b_fs_yx_fsv16, data_types::f32));

    ExecutionConfig config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::optimize_data(true));
    auto prog = program::build_program(engine, topology, config);

    ASSERT_NE(prog, nullptr);
    ASSERT_TRUE(has_node_with_type<reshape>(*prog));

    // The reshape must be optimized out, and its layouts updated by the pass.
    ASSERT_TRUE(prog->get_node("reshape").can_be_optimized());
    auto reshape_layout_in = prog->get_node("reshape").get_input_layouts()[0];
    auto reshape_layout_out = prog->get_node("reshape").get_output_layout();

    // At this moment transformations insert a reorder before the reshape which
    // converts the tensor to the default format with rank = reshape_out_rank.
    // Likely in the future we'll update that reorder so it uses reshape_input_rank;
    // after that the expected input format will be bfzyx.
    // [Updated] get_preferred_format() now uses 'in_lay_rank' instead of
    // 'out_lay_rank' for the preferred input format.
    ASSERT_EQ(reshape_layout_in.format, format::bfyx);
    ASSERT_EQ(reshape_layout_out.format, format::bfyx);
}

0 comments on commit bc1c2a6

Please sign in to comment.