Skip to content

Commit 904541d

Browse files
authored
[SYCL][Graph] Add test for linear graph optimizations (#20340)
Follow-up to #20291 - Adds unit test to ensure linear graphs are not tracking sync points and non-linear graphs are tracking sync points. - Forward declares a test friend class in `exec_graph_impl` to be able to inspect the optimizations performed on private members. --------- Signed-off-by: Matthew Michel <[email protected]>
1 parent ddac33d commit 904541d

File tree

4 files changed

+143
-0
lines changed

4 files changed

+143
-0
lines changed

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
#include <shared_mutex> // for shared_mutex
2424
#include <vector> // for vector
2525

26+
// For testing of graph internals
27+
class GraphImplTest;
28+
2629
namespace sycl {
2730
inline namespace _V1 {
2831
// Forward declarations
@@ -732,6 +735,10 @@ class exec_graph_impl {
732735
}
733736

734737
private:
738+
// Test helper class for inspecting private graph internals to validate
739+
// under-the-hood behavior and optimizations.
740+
friend class ::GraphImplTest;
741+
735742
/// Create a command-group for the node and add it to command-buffer by going
736743
/// through the scheduler.
737744
/// @param CommandBuffer Command-buffer to add node to as a command.

sycl/unittests/Extensions/CommandGraph/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ add_sycl_unittest(CommandGraphExtensionTests OBJECT
1313
TopologicalSort.cpp
1414
Update.cpp
1515
Properties.cpp
16+
LinearGraphOptimization.cpp
1617
)

sycl/unittests/Extensions/CommandGraph/Common.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ using namespace sycl::ext::oneapi;
2727

2828
using sycl::detail::getSyclObjImpl;
2929

30+
// Implement the test friend class forward declared in graph_impl.hpp so tests
31+
// can access private members to analyze internal optimizations (partitions,
32+
// sync points).
33+
class GraphImplTest {
34+
using exec_graph_impl = experimental::detail::exec_graph_impl;
35+
using partition = experimental::detail::partition;
36+
37+
public:
38+
static int NumPartitionsInOrder(const exec_graph_impl &Impl) {
39+
int NumInOrder = 0;
40+
for (const auto &P : Impl.MPartitions) {
41+
if (P && P->MIsInOrderGraph)
42+
++NumInOrder;
43+
}
44+
return NumInOrder;
45+
}
46+
static int NumSyncPoints(const exec_graph_impl &Impl) {
47+
return Impl.MSyncPoints.size();
48+
}
49+
};
50+
3051
// Common Test fixture
3152
class CommandGraphTest : public ::testing::Test {
3253
public:
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//==--------------------- LinearGraphOptimization.cpp ----------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Test for linear graph optimization which skips creating and tracking UR sync
10+
// points. Optimization is an internal implementation detail, validated through
11+
// inspecting private members of exec_graph_impl. Test achieves two goals: 1)
12+
// Validates that linear partitions in graphs are optimized to avoid using UR
13+
// sync points 2) Validates that non-linear partitions contain the expected
14+
// number of sync points
15+
16+
#include "Common.hpp"
17+
#include <optional>
18+
19+
using namespace sycl;
20+
using namespace sycl::ext::oneapi::experimental;
21+
using namespace sycl::ext::oneapi::experimental::detail;
22+
23+
// Helper to build a linear chain of N kernels on a queue inside graph capture.
24+
static void BuildLinearChain(queue &Queue, bool IsInOrderQueue, int N) {
25+
std::optional<sycl::event> Event;
26+
for (int I = 0; I < N; ++I) {
27+
if (IsInOrderQueue) {
28+
experimental::single_task<TestKernel>(Queue, []() {});
29+
} else {
30+
Event = Queue.submit([&](handler &h) {
31+
if (Event) {
32+
h.depends_on(*Event);
33+
}
34+
h.single_task<TestKernel>([]() {});
35+
});
36+
}
37+
}
38+
}
39+
40+
// Validate linear optimization invariants on an executable graph.
41+
static void ValidateLinearExec(exec_graph_impl &Impl, int NumLinearChains) {
42+
EXPECT_EQ(GraphImplTest::NumPartitionsInOrder(Impl), NumLinearChains);
43+
EXPECT_EQ(GraphImplTest::NumSyncPoints(Impl), 0);
44+
}
45+
46+
TEST_F(CommandGraphTest, LinearInOrderQueue) {
47+
sycl::property_list Props{sycl::property::queue::in_order{}};
48+
queue InOrderQ{Dev, Props};
49+
50+
experimental::command_graph<graph_state::modifiable> G{InOrderQ.get_context(),
51+
InOrderQ.get_device()};
52+
G.begin_recording(InOrderQ);
53+
BuildLinearChain(InOrderQ, /*IsInOrderQueue=*/true, /*N=*/3);
54+
InOrderQ.submit([&](sycl::handler &cgh) { cgh.host_task([]() {}); });
55+
BuildLinearChain(InOrderQ, /*IsInOrderQueue=*/true, /*N=*/4);
56+
G.end_recording(InOrderQ);
57+
58+
auto Exec = G.finalize();
59+
auto &Impl = *getSyclObjImpl(Exec);
60+
ValidateLinearExec(Impl, /*InOrderPartitions=*/3);
61+
}
62+
63+
TEST_F(CommandGraphTest, LinearOutOfOrderQueue) {
64+
// Out-of-order queue but we submit a strict linear dependency chain by
65+
// adding explicit depends_on between each node to achieve linearity.
66+
queue OOOQ{Dev};
67+
experimental::command_graph<graph_state::modifiable> G{OOOQ.get_context(),
68+
OOOQ.get_device()};
69+
G.begin_recording(OOOQ);
70+
BuildLinearChain(OOOQ, /*IsInOrderQueue=*/false, /*N=*/6);
71+
G.end_recording(OOOQ);
72+
73+
auto Exec = G.finalize();
74+
auto &Impl = *getSyclObjImpl(Exec);
75+
ValidateLinearExec(Impl, /*InOrderPartitions=*/1);
76+
}
77+
78+
// Ensures non-linear graphs are creating and tracking sync points internally
79+
// for proper scheduling and that the linear optimization is not improperly
80+
// applied.
81+
TEST_F(CommandGraphTest, NonLinearOutOfOrderQueue) {
82+
queue Q{Dev};
83+
experimental::command_graph<graph_state::modifiable> G{Q.get_context(),
84+
Q.get_device()};
85+
G.begin_recording(Q);
86+
// Root node
87+
event Root = Q.submit([&](handler &h) { h.single_task<TestKernel>([] {}); });
88+
// Two parallel branches depending on Root
89+
event A = Q.submit([&](handler &h) {
90+
h.depends_on(Root);
91+
h.single_task<TestKernel>([] {});
92+
});
93+
event B = Q.submit([&](handler &h) {
94+
h.depends_on(Root);
95+
h.single_task<TestKernel>([] {});
96+
});
97+
// Join node depends on both A and B
98+
Q.submit([&](handler &h) {
99+
h.depends_on(A);
100+
h.depends_on(B);
101+
h.single_task<TestKernel>([] {});
102+
});
103+
G.end_recording(Q);
104+
105+
auto Exec = G.finalize();
106+
auto &Impl = *getSyclObjImpl(Exec);
107+
108+
const int NumLinear = GraphImplTest::NumPartitionsInOrder(Impl);
109+
const int NumSyncPoints = GraphImplTest::NumSyncPoints(Impl);
110+
111+
// We should track a sync point per node for a total of 4
112+
EXPECT_EQ(NumSyncPoints, 4);
113+
EXPECT_EQ(NumLinear, 0);
114+
}

0 commit comments

Comments
 (0)