Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Named lambda operation #1667

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions core/test/base/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,4 +521,48 @@ TEST_F(ExecutorLogging, LogsOperation)
}


struct NameLogger : public gko::log::Logger {
protected:
void on_operation_launched(const gko::Executor* exec,
const gko::Operation* op) const override
{
op_name = op->get_name();
}

public:
mutable std::string op_name;
};


TEST(LambdaOperation, CanSetName)
{
auto name_logger = std::make_shared<NameLogger>();
auto exec = gko::ReferenceExecutor::create();
exec->add_logger(name_logger);

exec->run(
"name", [] {}, [] {}, [] {}, [] {}, [] {});

ASSERT_EQ("name", name_logger->op_name);
}


GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS


TEST(LambdaOperation, HasDefaultName)
{
auto name_logger = std::make_shared<NameLogger>();
auto exec = gko::ReferenceExecutor::create();
exec->add_logger(name_logger);

exec->run([] {}, [] {}, [] {}, [] {});

ASSERT_NE(nullptr, name_logger->op_name.c_str());
}


GKO_END_DISABLE_DEPRECATION_WARNINGS


} // namespace
74 changes: 63 additions & 11 deletions include/ginkgo/core/base/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,42 @@ class Executor : public log::EnableLogging<Executor> {
*/
template <typename ClosureOmp, typename ClosureCuda, typename ClosureHip,
typename ClosureDpcpp>
GKO_DEPRECATED(
"Please use the overload with std::string as first parameter.")
void run(const ClosureOmp& op_omp, const ClosureCuda& op_cuda,
const ClosureHip& op_hip, const ClosureDpcpp& op_dpcpp) const
{
LambdaOperation<ClosureOmp, ClosureCuda, ClosureHip, ClosureDpcpp> op(
op_omp, op_cuda, op_hip, op_dpcpp);
LambdaOperation<ClosureOmp, ClosureOmp, ClosureCuda, ClosureHip,
ClosureDpcpp>
op(op_omp, op_cuda, op_hip, op_dpcpp);
this->run(op);
}

/**
* Runs one of the passed in functors, depending on the Executor type.
*
* @tparam ClosureReference type of op_ref
* @tparam ClosureOmp type of op_omp
* @tparam ClosureCuda type of op_cuda
* @tparam ClosureHip type of op_hip
* @tparam ClosureDpcpp type of op_dpcpp
*
* @param name the name of the operation
* @param op_ref functor to run in case of a ReferenceExecutor
* @param op_omp functor to run in case of a OmpExecutor
* @param op_cuda functor to run in case of a CudaExecutor
* @param op_hip functor to run in case of a HipExecutor
* @param op_dpcpp functor to run in case of a DpcppExecutor
*/
template <typename ClosureReference, typename ClosureOmp,
typename ClosureCuda, typename ClosureHip, typename ClosureDpcpp>
void run(std::string name, const ClosureReference& op_ref,
const ClosureOmp& op_omp, const ClosureCuda& op_cuda,
const ClosureHip& op_hip, const ClosureDpcpp& op_dpcpp) const
{
LambdaOperation<ClosureReference, ClosureOmp, ClosureCuda, ClosureHip,
ClosureDpcpp>
op(std::move(name), op_ref, op_omp, op_cuda, op_hip, op_dpcpp);
this->run(op);
}

Expand Down Expand Up @@ -1105,10 +1136,21 @@ class Executor : public log::EnableLogging<Executor> {
* @tparam ClosureHip the type of the third functor
* @tparam ClosureDpcpp the type of the fourth functor
*/
template <typename ClosureOmp, typename ClosureCuda, typename ClosureHip,
typename ClosureDpcpp>
template <typename ClosureReference, typename ClosureOmp,
typename ClosureCuda, typename ClosureHip, typename ClosureDpcpp>
class LambdaOperation : public Operation {
public:
LambdaOperation(std::string name, const ClosureReference& op_ref,
const ClosureOmp& op_omp, const ClosureCuda& op_cuda,
const ClosureHip& op_hip, const ClosureDpcpp& op_dpcpp)
: name_(std::move(name)),
op_ref_(op_ref),
op_omp_(op_omp),
op_cuda_(op_cuda),
op_hip_(op_hip),
op_dpcpp_(op_dpcpp)
{}

/**
* Creates an LambdaOperation object from four functors.
*
Expand All @@ -1121,10 +1163,8 @@ class Executor : public log::EnableLogging<Executor> {
*/
LambdaOperation(const ClosureOmp& op_omp, const ClosureCuda& op_cuda,
const ClosureHip& op_hip, const ClosureDpcpp& op_dpcpp)
: op_omp_(op_omp),
op_cuda_(op_cuda),
op_hip_(op_hip),
op_dpcpp_(op_dpcpp)
: LambdaOperation("unnamed", op_omp, op_omp, op_cuda, op_hip,
op_dpcpp)
{}

void run(std::shared_ptr<const OmpExecutor>) const override
Expand All @@ -1134,7 +1174,7 @@ class Executor : public log::EnableLogging<Executor> {

void run(std::shared_ptr<const ReferenceExecutor>) const override
{
op_omp_();
op_ref_();
}

void run(std::shared_ptr<const CudaExecutor>) const override
Expand All @@ -1152,7 +1192,11 @@ class Executor : public log::EnableLogging<Executor> {
op_dpcpp_();
}

const char* get_name() const noexcept override { return name_.c_str(); }

private:
std::string name_;
ClosureReference op_ref_;
ClosureOmp op_omp_;
ClosureCuda op_cuda_;
ClosureHip op_hip_;
Expand Down Expand Up @@ -1230,8 +1274,6 @@ class ExecutorBase : public Executor {
friend class ReferenceExecutor;

public:
using Executor::run;

void run(const Operation& op) const override
{
this->template log<log::Logger::operation_launched>(this, &op);
Expand Down Expand Up @@ -1341,6 +1383,8 @@ class OmpExecutor : public detail::ExecutorBase<OmpExecutor>,
friend class detail::ExecutorBase<OmpExecutor>;

public:
using Executor::run;

/**
* Creates a new OmpExecutor.
*/
Expand Down Expand Up @@ -1418,6 +1462,8 @@ using DefaultExecutor = OmpExecutor;
*/
class ReferenceExecutor : public OmpExecutor {
public:
using Executor::run;

static std::shared_ptr<ReferenceExecutor> create(
std::shared_ptr<CpuAllocatorBase> alloc =
std::make_shared<CpuAllocator>())
Expand Down Expand Up @@ -1492,6 +1538,8 @@ class CudaExecutor : public detail::ExecutorBase<CudaExecutor>,
friend class detail::ExecutorBase<CudaExecutor>;

public:
using Executor::run;

/**
* Creates a new CudaExecutor.
*
Expand Down Expand Up @@ -1727,6 +1775,8 @@ class HipExecutor : public detail::ExecutorBase<HipExecutor>,
friend class detail::ExecutorBase<HipExecutor>;

public:
using Executor::run;

/**
* Creates a new HipExecutor.
*
Expand Down Expand Up @@ -1942,6 +1992,8 @@ class DpcppExecutor : public detail::ExecutorBase<DpcppExecutor>,
friend class detail::ExecutorBase<DpcppExecutor>;

public:
using Executor::run;

/**
* Creates a new DpcppExecutor.
*
Expand Down
22 changes: 22 additions & 0 deletions test/base/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,28 @@ TEST_F(Executor, RunsCorrectHostOperation)
}


TEST_F(Executor, RunsCorrectLambdaOperationWithReferenceExecutor)
{
int value = 0;
auto ref_lambda = [&value]() { value = reference::value; };
auto omp_lambda = [&value]() { value = omp::value; };
auto cuda_lambda = [&value]() { value = cuda::value; };
auto hip_lambda = [&value]() { value = hip::value; };
auto dpcpp_lambda = [&value]() { value = dpcpp::value; };

exec->run("test", ref_lambda, omp_lambda, cuda_lambda, hip_lambda,
dpcpp_lambda);

ASSERT_EQ(GKO_DEVICE_NAMESPACE::value, value);
}


#ifndef GKO_COMPILING_REFERENCE


GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS


TEST_F(Executor, RunsCorrectLambdaOperation)
{
int value = 0;
Expand All @@ -107,4 +126,7 @@ TEST_F(Executor, RunsCorrectLambdaOperation)
}


GKO_END_DISABLE_DEPRECATION_WARNINGS


#endif // GKO_COMPILING_REFERENCE
Loading