Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Tobias Ribizel <[email protected]>
Co-authored-by: Yu-Hsiang Tasi <[email protected]>
  • Loading branch information
3 people committed Oct 10, 2023
1 parent 0078777 commit e9e1905
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 41 deletions.
91 changes: 80 additions & 11 deletions core/test/base/batch_lin_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,36 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/log/logger.hpp>


namespace {


struct DummyLogger : gko::log::Logger {
DummyLogger()
: gko::log::Logger(gko::log::Logger::batch_linop_factory_events_mask)
{}

void on_batch_linop_factory_generate_started(
const gko::batch::BatchLinOpFactory*,
const gko::batch::BatchLinOp*) const override
{
batch_linop_factory_generate_started++;
}

void on_batch_linop_factory_generate_completed(
const gko::batch::BatchLinOpFactory*, const gko::batch::BatchLinOp*,
const gko::batch::BatchLinOp*) const override
{
batch_linop_factory_generate_completed++;
}

int mutable batch_linop_factory_generate_started = 0;
int mutable batch_linop_factory_generate_completed = 0;
};


class DummyBatchLinOp : public gko::batch::EnableBatchLinOp<DummyBatchLinOp>,
public gko::EnableCreateMethod<DummyBatchLinOp> {
public:
Expand All @@ -63,33 +88,25 @@ class EnableBatchLinOp : public ::testing::Test {
protected:
EnableBatchLinOp()
: ref{gko::ReferenceExecutor::create()},
ref2{gko::ReferenceExecutor::create()},
op{DummyBatchLinOp::create(ref2,
gko::batch_dim<2>(1, gko::dim<2>{3, 5}))},
op2{DummyBatchLinOp::create(ref2,
gko::batch_dim<2>(2, gko::dim<2>{3, 5}))}
op{DummyBatchLinOp::create(ref,
gko::batch_dim<2>(1, gko::dim<2>{3, 5}))}
{}

std::shared_ptr<const gko::ReferenceExecutor> ref;
std::shared_ptr<const gko::ReferenceExecutor> ref2;
std::unique_ptr<DummyBatchLinOp> op;
std::unique_ptr<DummyBatchLinOp> op2;
};


TEST_F(EnableBatchLinOp, KnowsNumBatchItems)
{
ASSERT_EQ(op->get_num_batch_items(), 1);
ASSERT_EQ(op2->get_num_batch_items(), 2);
}


TEST_F(EnableBatchLinOp, KnowsItsSizes)
{
auto op1_sizes = gko::batch_dim<2>(1, gko::dim<2>{3, 5});
auto op2_sizes = gko::batch_dim<2>(2, gko::dim<2>{3, 5});
ASSERT_EQ(op->get_size(), op1_sizes);
ASSERT_EQ(op2->get_size(), op2_sizes);
}


Expand Down Expand Up @@ -123,9 +140,14 @@ class DummyBatchLinOpWithFactory

class EnableBatchLinOpFactory : public ::testing::Test {
protected:
EnableBatchLinOpFactory() : ref{gko::ReferenceExecutor::create()} {}
EnableBatchLinOpFactory()
: ref{gko::ReferenceExecutor::create()},
logger{std::make_shared<DummyLogger>()}

{}

std::shared_ptr<const gko::ReferenceExecutor> ref;
std::shared_ptr<DummyLogger> logger;
};


Expand Down Expand Up @@ -161,4 +183,51 @@ TEST_F(EnableBatchLinOpFactory, PassesParametersToBatchLinOp)
}


TEST_F(EnableBatchLinOpFactory, FactoryGenerateIsLogged)
{
auto before_logger = *logger;
auto factory = DummyBatchLinOpWithFactory<>::build().on(ref);
factory->add_logger(logger);
factory->generate(
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 5})));

ASSERT_EQ(logger->batch_linop_factory_generate_started,
before_logger.batch_linop_factory_generate_started + 1);
ASSERT_EQ(logger->batch_linop_factory_generate_completed,
before_logger.batch_linop_factory_generate_completed + 1);
}


TEST_F(EnableBatchLinOpFactory, WithLoggersWorksAndPropagates)
{
auto before_logger = *logger;
auto factory =
DummyBatchLinOpWithFactory<>::build().with_loggers(logger).on(ref);
auto op = factory->generate(
DummyBatchLinOp::create(ref, gko::batch_dim<2>(1, gko::dim<2>{3, 5})));

ASSERT_EQ(logger->batch_linop_factory_generate_started,
before_logger.batch_linop_factory_generate_started + 1);
ASSERT_EQ(logger->batch_linop_factory_generate_completed,
before_logger.batch_linop_factory_generate_completed + 1);
}


TEST_F(EnableBatchLinOpFactory, CopiesLinOpToOtherExecutor)
{
auto ref2 = gko::ReferenceExecutor::create();
auto dummy = gko::share(
DummyBatchLinOp::create(ref2, gko::batch_dim<2>(1, gko::dim<2>{3, 5})));
auto factory = DummyBatchLinOpWithFactory<>::build().with_value(6).on(ref);

auto op = factory->generate(dummy);

ASSERT_EQ(op->get_executor(), ref);
ASSERT_EQ(op->get_parameters().value, 6);
ASSERT_EQ(op->op_->get_executor(), ref);
ASSERT_NE(op->op_.get(), dummy.get());
ASSERT_TRUE(dynamic_cast<const DummyBatchLinOp*>(op->op_.get()));
}


} // namespace
46 changes: 24 additions & 22 deletions include/ginkgo/core/base/batch_lin_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,13 @@ namespace batch {
*
* A key difference between the LinOp and the BatchLinOp class is that the apply
* between BatchLinOps is no longer supported. The user can apply a BatchLinOp
* to a batch::MultiVector but not to any general BatchLinOp. Therefore, the
* BatchLinOp serves only as a base class providing necessary core functionality
* from Polymorphic object and store the dimensions of the batched object.
* to a batch::MultiVector but not to any general BatchLinOp. This apply to a
* batch::MultiVector is handled by the concrete LinOp and may be moved to tbe

Check warning on line 77 in include/ginkgo/core/base/batch_lin_op.hpp

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"tbe" should be "the".
* base BatchLinOp class in the future.
*
* Therefore, the BatchLinOp serves only as a base class providing necessary
* core functionality from Polymorphic object and store the dimensions of the
* batched object.
*
* @ref BatchLinOp
*/
Expand All @@ -84,24 +88,24 @@ class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
/**
* Returns the number of items in the batch operator.
*
* @return number of items in the batch operator
* @return number of items in the batch operator
*/
size_type get_num_batch_items() const noexcept
{
return size_.get_num_batch_items();
return get_size().get_num_batch_items();
}

/**
* Returns the common size of the batch items.
*
* @return the common size stored
* @return the common size stored
*/
dim<2> get_common_size() const { return size_.get_common_size(); }
dim<2> get_common_size() const { return get_size().get_common_size(); }

/**
* Returns the size of the batch operator.
*
* @return size of the batch operator, a batch_dim object
* @return size of the batch operator, a batch_dim object
*/
const batch_dim<2>& get_size() const noexcept { return size_; }

Expand All @@ -117,27 +121,28 @@ class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
* Creates a batch operator storing items of uniform sizes.
*
* @param exec the executor where all the operations are performed
* @param num_batch_items the number of batch items to be stored in the
* operator
* @param size the common size of the items in the batched operator
* @param batch_size the size the batched operator, as a batch_dim object
*/
explicit BatchLinOp(std::shared_ptr<const Executor> exec,
const size_type num_batch_items = 0,
const dim<2>& common_size = dim<2>{})
: EnableAbstractPolymorphicObject<BatchLinOp>(exec),
size_{num_batch_items > 0 ? batch_dim<2>(num_batch_items, common_size)
: batch_dim<2>{}}
const batch_dim<2>& batch_size)
: EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
{}

/**
* Creates a batch operator storing items of uniform sizes.
*
* @param exec the executor where all the operations are performed
* @param batch_size the size the batched operator, as a batch_dim object
* @param num_batch_items the number of batch items to be stored in the
* operator
* @param size the common size of the items in the batched operator
*/
explicit BatchLinOp(std::shared_ptr<const Executor> exec,
const batch_dim<2>& batch_size)
: EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
const size_type num_batch_items = 0,
const dim<2>& common_size = dim<2>{})
: BatchLinOp{std::move(exec),
num_batch_items > 0
? batch_dim<2>(num_batch_items, common_size)
: batch_dim<2>{}}
{}

private:
Expand Down Expand Up @@ -234,9 +239,6 @@ class EnableBatchLinOp
public:
using EnablePolymorphicObject<ConcreteBatchLinOp,
PolymorphicBase>::EnablePolymorphicObject;

protected:
GKO_ENABLE_SELF(ConcreteBatchLinOp);
};


Expand Down
20 changes: 12 additions & 8 deletions include/ginkgo/core/log/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ class stopping_status;


namespace batch {


class BatchLinOp;
class BatchLinOpFactory;


} // namespace batch


Expand Down Expand Up @@ -455,9 +459,9 @@ public: \
* @warning This on_iteration_complete function that this macro declares is
* deprecated. Please use the version with the stopping information.
*/
[[deprecated(
"Please use the version with the additional stopping "
"information.")]] virtual void
[
[deprecated("Please use the version with the additional stopping "
"information.")]] virtual void
on_iteration_complete(const LinOp* solver, const size_type& it,
const LinOp* r, const LinOp* x = nullptr,
const LinOp* tau = nullptr) const
Expand All @@ -476,9 +480,9 @@ public: \
* @warning This on_iteration_complete function that this macro declares is
* deprecated. Please use the version with the stopping information.
*/
[[deprecated(
"Please use the version with the additional stopping "
"information.")]] virtual void
[
[deprecated("Please use the version with the additional stopping "
"information.")]] virtual void
on_iteration_complete(const LinOp* solver, const size_type& it,
const LinOp* r, const LinOp* x, const LinOp* tau,
const LinOp* implicit_tau_sq) const
Expand Down Expand Up @@ -810,8 +814,8 @@ class EnableLogging : public PolymorphicBase {
template <size_type Event, typename ConcreteLoggableT>
struct propagate_log_helper<
Event, ConcreteLoggableT,
xstd::void_t<
decltype(std::declval<ConcreteLoggableT>().get_executor())>> {
xstd::void_t<decltype(
std::declval<ConcreteLoggableT>().get_executor())>> {
template <typename... Args>
static void propagate_log(const ConcreteLoggableT* loggable,
Args&&... args)
Expand Down

0 comments on commit e9e1905

Please sign in to comment.