Skip to content

Commit

Permalink
add function to access device data from field + unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrdar committed Oct 7, 2024
1 parent 051dfc4 commit 9fe3543
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 48 deletions.
2 changes: 2 additions & 0 deletions src/atlas/array/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class Array : public util::Object {

const std::vector<int>& stridesf() const { return spec_.stridesf(); }

const ArrayStrides& device_stridesf() const { return spec_.strides(); }

bool contiguous() const { return spec_.contiguous(); }

bool hasDefaultLayout() const { return spec_.hasDefaultLayout(); }
Expand Down
7 changes: 7 additions & 0 deletions src/atlas/array/ArraySpec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,19 @@ const std::vector<int>& ArraySpec::stridesf() const {
return stridesf_;
}

const std::vector<int>& ArraySpec::device_stridesf() const {
return device_stridesf_;
}

void ArraySpec::allocate_fortran_specs() {
shapef_.resize(rank_);
stridesf_.resize(rank_);
device_stridesf_.resize(rank_);
device_stridesf_[rank_ - 1] = stridesf_[rank_ - 1];
for (idx_t j = 0; j < rank_; ++j) {
shapef_[j] = shape_[rank_ - 1 - layout_[j]];
stridesf_[j] = strides_[rank_ -1 - layout_[j]];
device_stridesf_[rank_ - j - 1] = device_stridesf_[rank_ - j] * shapef_[rank_ - j - 1];
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/atlas/array/ArraySpec.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ArraySpec {
ArrayAlignment alignment_;
std::vector<int> shapef_;
std::vector<int> stridesf_;
std::vector<int> device_stridesf_;
bool contiguous_;
bool default_layout_;

Expand Down Expand Up @@ -64,6 +65,7 @@ class ArraySpec {
const ArrayLayout& layout() const { return layout_; }
const std::vector<int>& shapef() const;
const std::vector<int>& stridesf() const;
const std::vector<int>& device_stridesf() const;
bool contiguous() const { return contiguous_; }
bool hasDefaultLayout() const { return default_layout_; }

Expand Down
3 changes: 3 additions & 0 deletions src/atlas/field/detail/FieldImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class FieldImpl : public util::Object {
/// @brief Strides of this field in Fortran style (reverse order of C style)
const std::vector<int>& stridesf() const { return array_->stridesf(); }

/// @brief Strides of this field on the device in Fortran style (reverse order of C style)
const std::vector<int>& device_stridesf() const { return array_->device_stridesf(); }

/// @brief Shape of this field (reverse order of Fortran style)
const array::ArrayShape& shape() const { return array_->shape(); }

Expand Down
28 changes: 28 additions & 0 deletions src/atlas/field/detail/FieldInterface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ void atlas__Field__data_specf(FieldImpl* This, Value*& data, int& rank, int*& sh
rank = This->shapef().size();
}

template <typename Value>
void atlas__Field__device_data_specf(FieldImpl* This, Value*& data, int& rank, int*& shapef, int*& stridesf) {
ATLAS_ASSERT(This != nullptr, "Cannot access data of uninitialised atlas_Field");
if (This->datatype() != array::make_datatype<Value>()) {
throw_Exception("Datatype mismatch for accessing field data");
}
data = This->array().device_data<Value>();
shapef = const_cast<int*>(This->shapef().data());
stridesf = const_cast<int*>(This->device_stridesf().data());
rank = This->shapef().size();
}

template <typename Value>
FieldImpl* atlas__Field__wrap_specf(const char* name, Value data[], int rank, int shapef[], int stridesf[]) {
array::ArrayShape shape;
Expand Down Expand Up @@ -189,6 +201,22 @@ void atlas__Field__data_double_specf(FieldImpl* This, double*& data, int& rank,
atlas__Field__data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_int_specf(FieldImpl* This, int*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_long_specf(FieldImpl* This, long*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_float_specf(FieldImpl* This, float*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

void atlas__Field__device_data_double_specf(FieldImpl* This, double*& data, int& rank, int*& shapef, int*& stridesf) {
atlas__Field__device_data_specf(This, data, rank, shapef, stridesf);
}

int atlas__Field__host_needs_update(const FieldImpl* This) {
return This->hostNeedsUpdate();
}
Expand Down
8 changes: 8 additions & 0 deletions src/atlas/field/detail/FieldInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ void atlas__Field__data_float_specf(FieldImpl* This, float*& field_data, int& ra
int*& field_stridesf);
void atlas__Field__data_double_specf(FieldImpl* This, double*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_int_specf(FieldImpl* This, int*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_long_specf(FieldImpl* This, long*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_float_specf(FieldImpl* This, float*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
void atlas__Field__device_data_double_specf(FieldImpl* This, double*& field_data, int& rank, int*& field_shapef,
int*& field_stridesf);
util::Metadata* atlas__Field__metadata(FieldImpl* This);
const functionspace::FunctionSpaceImpl* atlas__Field__functionspace(FieldImpl* This);
void atlas__Field__rename(FieldImpl* This, const char* name);
Expand Down
Loading

0 comments on commit 9fe3543

Please sign in to comment.