Skip to content

Commit

Permalink
move to workspace_aliases to internal header
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
pratikvn and MarcelKoch committed May 7, 2024
1 parent 8ff343e commit d5a325c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_PUBLIC_CORE_BASE_WORKSPACE_ALIASES_HPP_
#define GKO_PUBLIC_CORE_BASE_WORKSPACE_ALIASES_HPP_
#ifndef GKO_CORE_BASE_WORKSPACE_ALIASES_HPP_
#define GKO_CORE_BASE_WORKSPACE_ALIASES_HPP_


#include <ginkgo/config.hpp>
Expand Down Expand Up @@ -252,4 +252,4 @@ class layout {
} // namespace gko


#endif // GKO_PUBLIC_CORE_BASE_WORKSPACE_ALIASES_HPP_
#endif // GKO_CORE_BASE_WORKSPACE_ALIASES_HPP_
57 changes: 57 additions & 0 deletions core/log/batch_logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,66 @@
#include <ginkgo/core/base/math.hpp>


#include "core/base/workspace_aliases.hpp"


namespace gko {
namespace batch {
namespace log {
namespace detail {


template <typename ValueType>
log_data<ValueType>::log_data(std::shared_ptr<const Executor> exec,
size_type num_batch_items)
: res_norms(exec), iter_counts(exec)
{
if (num_batch_items > 0) {
iter_counts.resize_and_reset(num_batch_items);
res_norms.resize_and_reset(num_batch_items);
} else {
GKO_INVALID_STATE("Invalid num batch items passed in");
}
}


template <typename ValueType>
log_data<ValueType>::log_data(std::shared_ptr<const Executor> exec,
size_type num_batch_items,
array<unsigned char>& workspace)
: res_norms(exec), iter_counts(exec)
{
const size_type reqd_workspace_size = num_batch_items * 32;

if (num_batch_items > 0 && !workspace.is_owning() &&
workspace.get_size() >= reqd_workspace_size) {
gko::detail::layout<2, 8> workspace_alias;
auto slot_1 = workspace_alias.get_slot(0);
auto slot_2 = workspace_alias.get_slot(1);
auto iter_alias = slot_1->create_alias<idx_type>(num_batch_items);
auto res_alias = slot_2->create_alias<real_type>(num_batch_items);

// Temporary storage mapping
auto err = workspace_alias.map_to_buffer(workspace.get_data(),
reqd_workspace_size);
GKO_ASSERT(err == 0);
iter_counts =
array<idx_type>::view(exec, num_batch_items, iter_alias.get());
res_norms =
array<real_type>::view(exec, num_batch_items, res_alias.get());
} else {
GKO_INVALID_STATE("invalid workspace or num batch items passed in");
}
}

#define GKO_DECLARE_LOG_DATA(_type) class log_data<_type>

GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(GKO_DECLARE_LOG_DATA);

#undef GKO_DECLARE_LOG_DATA


} // namespace detail


template <typename ValueType>
Expand Down
38 changes: 2 additions & 36 deletions include/ginkgo/core/log/batch_logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/base/workspace_aliases.hpp>
#include <ginkgo/core/log/logger.hpp>


Expand All @@ -36,43 +35,10 @@ struct log_data final {
using real_type = remove_complex<ValueType>;
using idx_type = int;

log_data(std::shared_ptr<const Executor> exec, size_type num_batch_items)
: res_norms(exec), iter_counts(exec)
{
if (num_batch_items > 0) {
iter_counts.resize_and_reset(num_batch_items);
res_norms.resize_and_reset(num_batch_items);
} else {
GKO_INVALID_STATE("Invalid num batch items passed in");
}
}
log_data(std::shared_ptr<const Executor> exec, size_type num_batch_items);

log_data(std::shared_ptr<const Executor> exec, size_type num_batch_items,
array<unsigned char>& workspace)
: res_norms(exec), iter_counts(exec)
{
const size_type reqd_workspace_size = num_batch_items * 32;

if (num_batch_items > 0 && !workspace.is_owning() &&
workspace.get_size() >= reqd_workspace_size) {
gko::detail::layout<2, 8> workspace_alias;
auto slot_1 = workspace_alias.get_slot(0);
auto slot_2 = workspace_alias.get_slot(1);
auto iter_alias = slot_1->create_alias<idx_type>(num_batch_items);
auto res_alias = slot_2->create_alias<real_type>(num_batch_items);

// Temporary storage mapping
auto err = workspace_alias.map_to_buffer(workspace.get_data(),
reqd_workspace_size);
GKO_ASSERT(err == 0);
iter_counts =
array<idx_type>::view(exec, num_batch_items, iter_alias.get());
res_norms =
array<real_type>::view(exec, num_batch_items, res_alias.get());
} else {
GKO_INVALID_STATE("invalid workspace or num batch items passed in");
}
}
array<unsigned char>& workspace);

/**
* Stores residual norm values for every linear system in the batch.
Expand Down
1 change: 0 additions & 1 deletion include/ginkgo/ginkgo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
#include <ginkgo/core/base/utils.hpp>
#include <ginkgo/core/base/utils_helper.hpp>
#include <ginkgo/core/base/version.hpp>
#include <ginkgo/core/base/workspace_aliases.hpp>

#include <ginkgo/core/config/property_tree.hpp>

Expand Down

0 comments on commit d5a325c

Please sign in to comment.