Skip to content

Commit

Permalink
Fix misaligned add for log by storing 64bit int
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Mar 22, 2024
1 parent c82d81c commit 1cfe7e8
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 25 deletions.
6 changes: 4 additions & 2 deletions common/cuda_hip/log/batch_logger.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ template <typename RealType>
class SimpleFinalLogger final {
public:
using real_type = RealType;
using idx_type = int64;

SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters)
SimpleFinalLogger(real_type* const batch_residuals,
idx_type* const batch_iters)
: final_residuals_{batch_residuals}, final_iters_{batch_iters}
{}

Expand All @@ -24,5 +26,5 @@ public:

private:
real_type* const final_residuals_;
int* const final_iters_;
idx_type* const final_iters_;
};
6 changes: 3 additions & 3 deletions core/log/batch_logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ namespace log {

template <typename ValueType>
void BatchConvergence<ValueType>::on_batch_solver_completed(
const array<int>& iteration_count,
const array<int64>& iteration_count,
const array<remove_complex<ValueType>>& residual_norm) const
{
if (this->iteration_count_.get_size() == 0) {
this->iteration_count_ = gko::array<int>(iteration_count.get_executor(),
iteration_count.get_size());
this->iteration_count_ = gko::array<int64>(
iteration_count.get_executor(), iteration_count.get_size());
}
if (this->residual_norm_.get_size() == 0) {
this->residual_norm_ = gko::array<remove_complex<ValueType>>(
Expand Down
6 changes: 4 additions & 2 deletions dpcpp/log/batch_logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ template <typename RealType>
class SimpleFinalLogger final {
public:
using real_type = remove_complex<RealType>;
using idx_type = int64;

SimpleFinalLogger(real_type* const batch_residuals, int* const batch_iters)
SimpleFinalLogger(real_type* const batch_residuals,
indx_type* const batch_iters)
: final_residuals_{batch_residuals}, final_iters_{batch_iters}
{}

Expand All @@ -43,7 +45,7 @@ class SimpleFinalLogger final {

private:
real_type* const final_residuals_;
int* const final_iters_;
idx_type* const final_iters_;
};


Expand Down
24 changes: 13 additions & 11 deletions include/ginkgo/core/log/batch_logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ namespace detail {
template <typename ValueType>
struct log_data final {
using real_type = remove_complex<ValueType>;
using idx_type = int64;

log_data(std::shared_ptr<const Executor> exec, size_type num_batch_items)
: res_norms(exec), iter_counts(exec)
{
const size_type workspace_size =
num_batch_items * (sizeof(real_type) + sizeof(int));
num_batch_items * (sizeof(real_type) + sizeof(idx_type));
if (num_batch_items > 0) {
iter_counts.resize_and_reset(num_batch_items);
res_norms.resize_and_reset(num_batch_items);
Expand All @@ -52,16 +53,17 @@ struct log_data final {
: res_norms(exec), iter_counts(exec)
{
const size_type workspace_size =
num_batch_items * (sizeof(real_type) + sizeof(int));
num_batch_items * (sizeof(real_type) + sizeof(idx_type));
if (num_batch_items > 0 && !workspace.is_owning() &&
workspace.get_size() >= workspace_size) {
iter_counts =
array<int>::view(exec, num_batch_items,
reinterpret_cast<int*>(workspace.get_data()));
iter_counts = array<idx_type>::view(
exec, num_batch_items,
reinterpret_cast<idx_type*>(workspace.get_data()));
res_norms = array<real_type>::view(
exec, num_batch_items,
reinterpret_cast<real_type*>(workspace.get_data() +
(sizeof(int) * num_batch_items)));
reinterpret_cast<real_type*>(
workspace.get_data() +
(sizeof(idx_type) * num_batch_items)));
} else {
GKO_INVALID_STATE("invalid workspace or num batch items passed in");
}
Expand All @@ -75,7 +77,7 @@ struct log_data final {
/**
* Stores convergence iteration counts for every matrix in the batch
*/
array<int> iter_counts;
array<idx_type> iter_counts;
};


Expand All @@ -101,7 +103,7 @@ class BatchConvergence final : public gko::log::Logger {
using mask_type = gko::log::Logger::mask_type;

void on_batch_solver_completed(
const array<int>& iteration_count,
const array<int64>& iteration_count,
const array<real_type>& residual_norm) const override;

/**
Expand All @@ -127,7 +129,7 @@ class BatchConvergence final : public gko::log::Logger {
/**
* @return The number of iterations for entire batch
*/
const array<int>& get_num_iterations() const noexcept
const array<int64>& get_num_iterations() const noexcept
{
return iteration_count_;
}
Expand All @@ -147,7 +149,7 @@ class BatchConvergence final : public gko::log::Logger {
{}

private:
mutable array<int> iteration_count_{};
mutable array<int64> iteration_count_{};
mutable array<real_type> residual_norm_{};
};

Expand Down
4 changes: 2 additions & 2 deletions include/ginkgo/core/log/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ public: \
* @param residual_norms the array storing the residual norms.
*/
virtual void on_batch_solver_completed(
const array<int>& iters, const array<double>& residual_norms) const
const array<int64>& iters, const array<double>& residual_norms) const
{}

/**
Expand All @@ -577,7 +577,7 @@ public: \
* @param residual_norms the array storing the residual norms.
*/
virtual void on_batch_solver_completed(
const array<int>& iters, const array<float>& residual_norms) const
const array<int64>& iters, const array<float>& residual_norms) const
{}

public:
Expand Down
2 changes: 1 addition & 1 deletion include/ginkgo/core/solver/batch_solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class EnableBatchSolver
preconditioner_ = std::move(id);
}
const size_type workspace_size = system_matrix->get_num_batch_items() *
(sizeof(real_type) + sizeof(int));
(sizeof(real_type) + sizeof(int64));
workspace_.set_executor(exec);
workspace_.resize_and_reset(workspace_size);
}
Expand Down
12 changes: 8 additions & 4 deletions reference/log/batch_logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ namespace batch_log {
template <typename RealType>
class SimpleFinalLogger final {
public:
using real_type = RealType;
using idx_type = int64;

/**
* Constructor
*
Expand All @@ -31,7 +34,8 @@ class SimpleFinalLogger final {
* @param batch_iters final iteration counts for each
* linear system in the batch.
*/
SimpleFinalLogger(RealType* const batch_residuals, int* const batch_iters)
SimpleFinalLogger(real_type* const batch_residuals,
idx_type* const batch_iters)
: final_residuals_{batch_residuals}, final_iters_{batch_iters}
{}

Expand All @@ -43,15 +47,15 @@ class SimpleFinalLogger final {
* @param res_norm Norm of final residual norm
*/
void log_iteration(const size_type batch_idx, const int iter,
const RealType res_norm)
const real_type res_norm)
{
final_iters_[batch_idx] = iter;
final_residuals_[batch_idx] = res_norm;
}

private:
RealType* const final_residuals_;
int* const final_iters_;
real_type* const final_residuals_;
idx_type* const final_iters_;
};


Expand Down

0 comments on commit 1cfe7e8

Please sign in to comment.