Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Graph] Implement Graph and node queries #348

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 91 additions & 4 deletions sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,19 @@ enum class graph_support_level {
emulated
};

enum class node_type {
empty,
subgraph,
kernel,
memcpy,
memset,
memfill,
prefetch,
memadvise,
ext_oneapi_barrier,
host_task,
Bensuo marked this conversation as resolved.
Show resolved Hide resolved
};

namespace property {

namespace graph {
Expand Down Expand Up @@ -353,7 +366,18 @@ struct graphs_support;
} // namespace device
} // namespace info

class node {};
class node {
public:
node() = delete;

node_type get_type() const;

std::vector<node> get_predecessors() const;

std::vector<node> get_successors() const;

static node get_node_from_event(event nodeEvent);
};

// State of a graph
enum class graph_state {
Expand Down Expand Up @@ -389,6 +413,9 @@ public:
void make_edge(node& src, node& dest);

void print_graph(std::string path, bool verbose = false) const;

std::vector<node> get_nodes() const;
std::vector<node> get_root_nodes() const;
};

template<>
Expand Down Expand Up @@ -459,12 +486,56 @@ edges.

The `node` class provides the {crs}[common reference semantics].

==== Node Member Functions

Table {counter: tableNumber}. Member functions of the `node` class.
[cols="2a,a"]
|===
|Member Function|Description

|
[source,c++]
----
namespace sycl::ext::oneapi::experimental {
class node {};
}
node_type get_type() const;
----
|Returns a value representing the type of command this node represents.

|
[source,c++]
----
std::vector<node> get_predecessors() const;
----
|Returns a list of the predecessor nodes which this node directly depends on.

|
[source,c++]
----
std::vector<node> get_successors() const;
Bensuo marked this conversation as resolved.
Show resolved Hide resolved
----
|Returns a list of the successor nodes which directly depend on this node.

|
[source,c++]
----
static node get_node_from_event(event nodeEvent);
----
|Finds the node associated with an event created from a submission to a queue
in the recording state.

Parameters:

* `nodeEvent` - Event returned from a submission to a queue in the recording
state.

Returns: Graph node that was created when the command that returned
`nodeEvent` was submitted.

Exceptions:

* Throws with error code `invalid` if `nodeEvent` is not associated with a
graph node.

|===

==== Depends-On Property

Expand Down Expand Up @@ -775,6 +846,22 @@ Exceptions:
* Throws synchronously with error code `invalid` if the path is invalid or
the file extension is not supported or if the write operation failed.

|
[source,c++]
----
std::vector<node> get_nodes() const;
----
|Returns a list of all the nodes present in the graph in the order that they
were added.

|
[source,c++]
----
std::vector<node> get_root_nodes() const;
----
|Returns a list of all nodes in the graph which have no dependencies in the
mfrancepillois marked this conversation as resolved.
Show resolved Hide resolved
order they were added to the graph.

|===

Table {counter: tableNumber}. Member functions of the `command_graph` class for queue recording.
Expand Down
35 changes: 35 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,37 @@ enum class graph_state {
executable, ///< In executable state, the graph is ready to execute.
};

enum class node_type {
empty = 0,
subgraph,
kernel,
memcpy,
memset,
memfill,
prefetch,
memadvise,
ext_oneapi_barrier,
host_task
};

/// Class representing a node in the graph, returned by command_graph::add().
class __SYCL_EXPORT node {
public:
node() = delete;

/// Get the type of command associated with this node.
node_type get_type() const;

/// Get a list of all the node dependencies of this node.
std::vector<node> get_predecessors() const;

/// Get a list of all nodes which depend on this node.
std::vector<node> get_successors() const;

/// Get the node associated with a SYCL event returned from a queue recording
/// submission.
static node get_node_from_event(event nodeEvent);

private:
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}

Expand Down Expand Up @@ -253,6 +282,12 @@ class __SYCL_EXPORT modifiable_command_graph {
/// as kernel args or memory access where applicable.
void print_graph(const std::string path, bool verbose = false) const;

/// Get a list of all nodes contained in this graph.
std::vector<node> get_nodes() const;

/// Get a list of all root nodes (nodes without dependencies) in this graph.
std::vector<node> get_root_nodes() const;

protected:
/// Constructor used internally by the runtime.
/// @param Impl Detail implementation class to construct object with.
Expand Down
10 changes: 10 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,14 @@ class __SYCL_EXPORT handler {
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
getCommandGraph() const;

/// Sets the user facing node type of this operation, used for operations
/// which are recorded to a graph. Since some operations may actually be a
/// different type than the user submitted, e.g. a fill() which is performed
/// as a kernel submission.
/// @param Type The actual type based on what handler functions the user
/// called.
void setUserFacingNodeType(ext::oneapi::experimental::node_type Type);

public:
handler(const handler &) = delete;
handler(handler &&) = delete;
Expand Down Expand Up @@ -2722,6 +2730,7 @@ class __SYCL_EXPORT handler {
checkIfPlaceholderIsBoundToHandler(Dst);

throwIfActionIsCreated();
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
// TODO add check:T must be an integral scalar value or a SYCL vector type
static_assert(isValidTargetForExplicitOp(AccessTarget),
"Invalid accessor target for the fill method.");
Expand Down Expand Up @@ -2760,6 +2769,7 @@ class __SYCL_EXPORT handler {
/// \param Count is the number of times to fill Pattern into Ptr.
template <typename T> void fill(void *Ptr, const T &Pattern, size_t Count) {
throwIfActionIsCreated();
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
static_assert(is_device_copyable<T>::value,
"Pattern must be device copyable");
parallel_for<__usmfill<T>>(range<1>(Count), [=](id<1> Index) {
Expand Down
Loading
Loading