Skip to content

Commit 238f75f

Browse files
authored
Fix loadStateField evaluator to address #1092 (#1096)
Apparently linear access of Kokkos dynamic rank views is no longer working
1 parent 26d3dda commit 238f75f

File tree

2 files changed

+17
-26
lines changed

2 files changed

+17
-26
lines changed

src/evaluators/state/PHAL_LoadStateField.hpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,11 @@ class LoadStateFieldBase : public PHX::EvaluatorWithBaseImpl<Traits>,
3737

3838
using ExecutionSpace = typename PHX::Device::execution_space;
3939

40-
PHX::MDField<ScalarType> data;
40+
PHX::MDField<ScalarType> field;
4141
std::string fieldName;
4242
std::string stateName;
4343

4444
MDFieldMemoizer<Traits> memoizer;
45-
46-
MDFieldVectorRight<ScalarType> dataVec;
4745
};
4846

4947
template<typename EvalT, typename Traits>
@@ -65,13 +63,11 @@ class LoadStateField : public PHX::EvaluatorWithBaseImpl<Traits>,
6563

6664
using ExecutionSpace = typename PHX::Device::execution_space;
6765

68-
PHX::MDField<ParamScalarT> data;
66+
PHX::MDField<ParamScalarT> field;
6967
std::string fieldName;
7068
std::string stateName;
7169

7270
MDFieldMemoizer<Traits> memoizer;
73-
74-
MDFieldVectorRight<ParamScalarT> dataVec;
7571
};
7672

7773
// Shortcut names

src/evaluators/state/PHAL_LoadStateField_Def.hpp

+15-20
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ LoadStateFieldBase(const Teuchos::ParameterList& p)
2222
fieldName = p.get<std::string>("Field Name");
2323
stateName = p.get<std::string>("State Name");
2424

25-
PHX::MDField<ScalarType> f(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );
26-
data = f;
25+
field = PHX::MDField<ScalarType>(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );
2726

28-
this->addEvaluatedField(data);
27+
this->addEvaluatedField(field);
2928
this->setName("LoadStateField("+stateName+")"+PHX::print<EvalT>());
3029
}
3130

@@ -34,7 +33,7 @@ template<typename EvalT, typename Traits, typename ScalarType>
3433
void LoadStateFieldBase<EvalT, Traits, ScalarType>::postRegistrationSetup(typename Traits::SetupData d,
3534
PHX::FieldManager<Traits>& fm)
3635
{
37-
this->utils.setFieldData(data,fm);
36+
this->utils.setFieldData(field,fm);
3837

3938
d.fill_field_dependencies(this->dependentFields(),this->evaluatedFields());
4039
if (d.memoizer_active()) memoizer.enable_memoizer();
@@ -51,15 +50,13 @@ void LoadStateFieldBase<EvalT, Traits, ScalarType>::evaluateFields(typename Trai
5150
// whomever changed the data.
5251
const auto& stateToLoad = (*workset.stateArrayPtr)[stateName];
5352
auto stateData = stateToLoad.dev();
54-
const int stateToLoad_size = stateToLoad.size();
5553

56-
MDFieldVectorRight<ScalarType> g(data);
57-
dataVec = g;
54+
ALBANY_ASSERT (stateData.rank() <= 3, "Current implementation supports only views with rank up to 3. If larger rank is needed modify code below");
5855

5956
Kokkos::parallel_for(this->getName(),
60-
Kokkos::RangePolicy<ExecutionSpace>(0,data.size()),
61-
KOKKOS_CLASS_LAMBDA(const int i) {
62-
dataVec[i] = (i < stateToLoad_size) ? stateData(i) : 0.0;
57+
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<3>>({0,0,0},{stateData.extent(0),stateData.extent(1),stateData.extent(2)}),
58+
KOKKOS_CLASS_LAMBDA(const int i, const int j, const int k) {
59+
field.access(i,j,k) = stateData.access(i,j,k); //works also when rank is less than 3
6360
});
6461
}
6562

@@ -70,10 +67,10 @@ LoadStateField(const Teuchos::ParameterList& p)
7067
fieldName = p.get<std::string>("Field Name");
7168
stateName = p.get<std::string>("State Name");
7269

73-
PHX::MDField<ParamScalarT> f(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );
74-
data = f;
7570

76-
this->addEvaluatedField(data);
71+
field = PHX::MDField<ParamScalarT>(fieldName, p.get<Teuchos::RCP<PHX::DataLayout> >("State Field Layout") );
72+
73+
this->addEvaluatedField(field);
7774
this->setName("Load State Field"+PHX::print<EvalT>());
7875
}
7976

@@ -82,7 +79,7 @@ template<typename EvalT, typename Traits>
8279
void LoadStateField<EvalT, Traits>::postRegistrationSetup(typename Traits::SetupData d,
8380
PHX::FieldManager<Traits>& fm)
8481
{
85-
this->utils.setFieldData(data,fm);
82+
this->utils.setFieldData(field,fm);
8683

8784
d.fill_field_dependencies(this->dependentFields(),this->evaluatedFields());
8885
if (d.memoizer_active()) memoizer.enable_memoizer();
@@ -99,15 +96,13 @@ void LoadStateField<EvalT, Traits>::evaluateFields(typename Traits::EvalData wor
9996
// whomever changed the data.
10097
const auto& stateToLoad = (*workset.stateArrayPtr)[stateName];
10198
auto stateData = stateToLoad.dev();
102-
const int stateToLoad_size = stateToLoad.size();
10399

104-
MDFieldVectorRight<ParamScalarT> g(data);
105-
dataVec = g;
100+
ALBANY_ASSERT (stateData.rank() <= 3, "Current implementation supports only views with rank up to 3. If larger rank is needed modify code below");
106101

107102
Kokkos::parallel_for(this->getName(),
108-
Kokkos::RangePolicy<ExecutionSpace>(0,data.size()),
109-
KOKKOS_CLASS_LAMBDA(const int i) {
110-
dataVec[i] = (i < stateToLoad_size) ? stateData(i) : 0.0;
103+
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<3>>({0,0,0},{stateData.extent(0),stateData.extent(1),stateData.extent(2)}),
104+
KOKKOS_CLASS_LAMBDA(const int i, const int j, const int k) {
105+
field.access(i,j,k) = stateData.access(i,j,k); //works also when rank is less than 3
111106
});
112107
}
113108

0 commit comments

Comments
 (0)