Skip to content

Commit

Permalink
Merge pull request #5259 from ikbuibui/shortcircuit_binning
Browse files Browse the repository at this point in the history
Short circuit axis binning if any previous axis is invalid
  • Loading branch information
psychocoderHPC authored Feb 4, 2025
2 parents 3b860d5 + a5bd660 commit 8b64dbb
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 49 deletions.
34 changes: 18 additions & 16 deletions include/picongpu/plugins/binning/Axis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@

#pragma once

#include "picongpu/plugins/binning/DomainInfo.hpp"
#include "picongpu/plugins/binning/UnitConversion.hpp"

#include <array>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>
#include <utility>

namespace picongpu
{
Expand Down Expand Up @@ -82,7 +80,7 @@ namespace picongpu
// @todo mark functions which are mandatory for each type of axis


template<typename T_BinningFunctor>
template<typename T_AttrFunctor>
class GenericAxis
{
public:
Expand All @@ -91,17 +89,24 @@ namespace picongpu
struct GenericAxisKernel
{
uint32_t n_bins;
T_BinningFunctor getBinIdx;
T_AttrFunctor getAttributeValue;

constexpr GenericAxisKernel(uint32_t n_bins, T_BinningFunctor binFunctor)
constexpr GenericAxisKernel(uint32_t n_bins, T_AttrFunctor attrFunctor)
: n_bins{n_bins}
, getBinIdx{binFunctor}
, getAttributeValue{attrFunctor}
{
}

// Forwards arguments to getAttributeValue
template<typename... Args>
ALPAKA_FN_ACC auto getBinIdx(const Args&... args) const
{
return std::make_pair(true, getAttributeValue(args...));
}
};
GenericAxisKernel gAK;

GenericAxis(uint32_t n_bins, T_BinningFunctor binFunctor) : gAK{n_bins, binFunctor}
GenericAxis(uint32_t n_bins, T_AttrFunctor attrFunctor) : gAK{n_bins, attrFunctor}
{
}

Expand All @@ -119,20 +124,17 @@ namespace picongpu
std::array<double, numUnits> units;
struct BoolAxisKernel
{
uint32_t n_bins;
static constexpr uint32_t n_bins = 2u;
T_AttrFunctor getAttributeValue;

constexpr BoolAxisKernel(T_AttrFunctor attrFunctor) : n_bins{2u}, getAttributeValue{attrFunctor}
constexpr BoolAxisKernel(T_AttrFunctor attrFunctor) : getAttributeValue{attrFunctor}
{
}

template<typename T_Worker, typename T_Particle>
ALPAKA_FN_ACC uint32_t
getBinIdx(const DomainInfo& domainInfo, const T_Worker& worker, const T_Particle& particle) const
template<typename... Args>
ALPAKA_FN_ACC std::pair<bool, bool> getBinIdx(const Args&... args) const
{
// static cast to bool ?
// bool val = getAttributeValue(worker, particle);
return 0;
return {true, getAttributeValue(args...)};
}
};

Expand Down
39 changes: 17 additions & 22 deletions include/picongpu/plugins/binning/BinningFunctors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,15 @@ namespace picongpu

DINLINE FunctorParticle() = default;

template<
typename T_Worker,
typename T_HistBox,
typename T_DepositionFunctor,
typename T_AxisTuple,
typename T_Particle,
uint32_t T_nAxes>
template<typename T_HistBox, typename T_DepositionFunctor, uint32_t T_nAxes>
DINLINE void operator()(
T_Worker const& worker,
auto const& worker,
T_HistBox histBox,
T_DepositionFunctor const& quantityFunctor,
T_AxisTuple const& axes,
auto const& axes,
DomainInfo const& domainInfo,
DataSpace<T_nAxes> const& extents,
T_Particle const& particle) const
auto const& particle) const
{
using DepositionType = typename T_HistBox::ValueType;

Expand All @@ -64,7 +58,15 @@ namespace picongpu
{
uint32_t i = 0;
// This assumes n_bins and getBinIdx exist
((binsDataspace[i++] = tupleArgs.getBinIdx(domainInfo, worker, particle, validIdx)), ...);
validIdx
= ((
[&]
{
auto [isValid, binIdx] = tupleArgs.getBinIdx(domainInfo, worker, particle);
binsDataspace[i++] = binIdx;
return isValid;
}())
&& ...);
},
axes);

Expand All @@ -85,21 +87,14 @@ namespace picongpu

HINLINE BinningFunctor() = default;

template<
typename T_Worker,
typename TParticlesBox,
typename T_HistBox,
typename T_DepositionFunctor,
typename T_AxisTuple,
typename T_Mapping,
uint32_t T_nAxes>
template<typename T_HistBox, typename T_DepositionFunctor, typename T_Mapping, uint32_t T_nAxes>
DINLINE void operator()(
T_Worker const& worker,
auto const& worker,
T_HistBox binningBox,
TParticlesBox particlesBox,
auto particlesBox,
pmacc::DataSpace<simDim> const& localOffset,
pmacc::DataSpace<simDim> const& globalOffset,
T_AxisTuple const& axisTuple,
auto const& axisTuple,
T_DepositionFunctor const& quantityFunctor,
DataSpace<T_nAxes> const& extents,
auto const& filter,
Expand Down
8 changes: 3 additions & 5 deletions include/picongpu/plugins/binning/axis/LinearAxis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ namespace picongpu
}

template<typename T_Worker, typename T_Particle>
ALPAKA_FN_ACC uint32_t getBinIdx(
ALPAKA_FN_ACC std::pair<bool, uint32_t> getBinIdx(
const DomainInfo& domainInfo,
const T_Worker& worker,
const T_Particle& particle,
bool& validIdx) const
const T_Particle& particle) const
{
auto val = getAttributeValue(domainInfo, worker, particle);

Expand Down Expand Up @@ -128,8 +127,7 @@ namespace picongpu
else
binIdx = nBins - 1;
}
validIdx = validIdx && enableBinning;
return binIdx;
return {enableBinning, binIdx};
}
};

Expand Down
9 changes: 3 additions & 6 deletions include/picongpu/plugins/binning/axis/LogAxis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,10 @@ namespace picongpu


template<typename T_Worker, typename T_Particle>
ALPAKA_FN_ACC uint32_t getBinIdx(
ALPAKA_FN_ACC std::pair<bool, uint32_t> getBinIdx(
const DomainInfo& domainInfo,
const T_Worker& worker,
const T_Particle& particle,
bool& validIdx) const
const T_Particle& particle) const
{
auto val = getAttributeValue(domainInfo, worker, particle);

Expand Down Expand Up @@ -155,9 +154,7 @@ namespace picongpu
binIdx = nBins - 1;
}
}

validIdx = validIdx && enableBinning;
return binIdx;
return {enableBinning, binIdx};
}
};

Expand Down

0 comments on commit 8b64dbb

Please sign in to comment.