Skip to content

Commit

Permalink
[CPU]fix sdpa test with shapeof
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangYiIntel committed Jan 7, 2025
1 parent bd217a3 commit b307de7
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
result << "Prc=" << inType << "_";
result << "HasShapeOf=" << hasShapeof << "_";
result << "quantKeyByChannel=" << quantKeyByChannel << "_";
result << "groupSize= " << groupSize << "_";
result << "groupSize=" << groupSize << "_";
result << "TransposeOrder=";
result << "(";
for (const auto& itr : transposeOrder) {
Expand All @@ -92,7 +92,6 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
void SetUp() override {
ElementType inType;
InputShapeAndTransposeOrder inputShapeAndOrders;
bool hasShapeOf;
std::tie(inType, inputShapeAndOrders, hasShapeOf, quantKeyByChannel, keyGroupSize) = this->GetParam();
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
transposeOrder = inputShapeAndOrders.second;
Expand Down Expand Up @@ -372,6 +371,8 @@ TEST_P(ConcatSDPTransposeTest, CompareWithRefs) {
// Transformation TSShapeOfForward will change:
// ?->transpose->shapeof ==> ?-->shapeof->gather
// |->transpose
size_t expectedGatherCount = hasShapeOf ? 1 : 0;
std::cout << "ConcatSDPTEST|" << expectedGatherCount << std::endl;
CheckNumberOfNodesWithType(compiledModel, "Gather", hasShapeOf ? 1 : 0);
auto expectedOutputs = run_test(functionRefs);
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
Expand Down Expand Up @@ -467,7 +468,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeByChannelTest,
ConcatSDPTransposeTest,
::testing::Combine(::testing::Values(ElementType::f32),
::testing::ValuesIn(shapesWithGreedySearch),
::testing::Values(true),
::testing::Values(false),
::testing::Values(true),
::testing::Values(8)),
ConcatSDPTransposeTest::getTestCaseName);
Expand Down

0 comments on commit b307de7

Please sign in to comment.