Skip to content

Commit

Permalink
Merge pull request #481 from llaniewski/unstable-refactor
Browse files Browse the repository at this point in the history
Further refactoring of the unstable branch
  • Loading branch information
llaniewski authored Jan 8, 2024
2 parents b01c27f + eed8630 commit 6faf88e
Show file tree
Hide file tree
Showing 34 changed files with 582 additions and 958 deletions.
40 changes: 27 additions & 13 deletions src/ArbLattice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,37 @@ void ArbLattice::SetFirstTabs(int tab_in, int tab_out) {
setSnapOut(tab_out);
}

void ArbLattice::getQuantity(int quant, real_t* host_tab, real_t scale) {



std::vector<big_flag_t> ArbLattice::getFlags() const { throw std::runtime_error{"UNIMPLEMENTED"}; return {}; };
std::vector<real_t> ArbLattice::getField(const Model::Field& f) { throw std::runtime_error{"UNIMPLEMENTED"}; return {}; };
std::vector<real_t> ArbLattice::getFieldAdj(const Model::Field& f) { throw std::runtime_error{"UNIMPLEMENTED"}; return {}; };
void ArbLattice::setFlags(const std::vector<big_flag_t>& x) { throw std::runtime_error{"UNIMPLEMENTED"}; return; };
void ArbLattice::setField(const Model::Field& f, const std::vector<real_t>& x) { throw std::runtime_error{"UNIMPLEMENTED"}; return; };
void ArbLattice::setFieldAdjZero(const Model::Field& f) { throw std::runtime_error{"UNIMPLEMENTED"}; return; };


std::vector<real_t> ArbLattice::getQuantity(const Model::Quantity& q, real_t scale) {
size_t size = getLocalSize();
int comp = q.getComp();
std::vector<real_t> ret(size*comp);
setSnapIn(Snap);
#ifdef ADJOINT
setAdjSnapIn(aSnap);
#endif
launcher.getQuantity(quant, host_tab, scale, data);
launcher.getQuantity(q.id, ret.data(), scale, data);
return ret;
}

std::vector<real_t> ArbLattice::getCoord(const Model::Coord& d, real_t scale) {
size_t size = getLocalSize();
std::vector<real_t> ret(size);
for (size_t i = 0; i < size; ++i) {
size_t j = local_permutation[i];
ret[j] = connect.coord(d.id, i)*scale;
}
return ret;
}

#include <iostream>
Expand Down Expand Up @@ -718,17 +743,6 @@ static int loadImpl(const std::string& filename, storage_t* device_ptr, size_t s
CudaMemcpy(device_ptr, tab.data(), size * sizeof(storage_t), CudaMemcpyHostToDevice);
return EXIT_SUCCESS;
}

/// TODO section
int ArbLattice::loadComp(const std::string& filename, const std::string& comp) {
throw std::runtime_error{"UNIMPLEMENTED"};
return -1;
}
int ArbLattice::saveComp(const std::string& filename, const std::string& comp) const {
throw std::runtime_error{"UNIMPLEMENTED"};
return -1;
}

void ArbLattice::savePrimal(const std::string& filename, int snap_ind) const {
if (saveImpl(filename, getSnapPtr(snap_ind), sizes.snaps_pitch * NF)) throw std::runtime_error{"savePrimal failed"};
}
Expand Down
16 changes: 13 additions & 3 deletions src/ArbLattice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,19 @@ class ArbLattice : public LatticeBase {
int reinitialize(size_t num_snaps_, const std::map<std::string, int>& setting_zones, pugi::xml_node arb_node); /// Init if passed args differ from those passed at construction or the last call to reinitialize (avoid duplicating work)
size_t getLocalSize() const final { return connect.chunk_end - connect.chunk_begin; }
size_t getGlobalSize() const final { return connect.num_nodes_global; }
void getQuantity(int quant, real_t* host_tab, real_t scale); /// Write GPU data to \p host_tab


virtual std::vector<int> shape() const { return {static_cast<int>(getLocalSize())}; };
virtual std::vector<real_t> getQuantity(const Model::Quantity& q, real_t scale = 1) ;
virtual std::vector<big_flag_t> getFlags() const;
virtual std::vector<real_t> getField(const Model::Field& f);
virtual std::vector<real_t> getFieldAdj(const Model::Field& f);
virtual std::vector<real_t> getCoord(const Model::Coord& q, real_t scale = 1);

virtual void setFlags(const std::vector<big_flag_t>& x);
virtual void setField(const Model::Field& f, const std::vector<real_t>& x);
virtual void setFieldAdjZero(const Model::Field& f);

const ArbVTUGeom& getVTUGeom() const { return vtu_geom; }
Span<const flag_t> getNodeTypes() const { return {node_types_host.data(), node_types_host.size()}; } /// Get host view of node types (permuted)
const ArbLatticeConnectivity& getConnectivity() const { return connect; }
Expand Down Expand Up @@ -112,8 +124,6 @@ class ArbLattice : public LatticeBase {
storage_t* getAdjointSnapPtr(int snap_ind); /// Get device pointer to the specified adjoint snap, snap_ind must be 0 or 1
#endif

int loadComp(const std::string& filename, const std::string& comp) final; /// TODO
int saveComp(const std::string& filename, const std::string& comp) const final; /// TODO
int loadPrimal(const std::string& filename, int snap_ind) final; /// TODO
void savePrimal(const std::string& filename, int snap_ind) const final; /// TODO
#ifdef ADJOINT
Expand Down
8 changes: 7 additions & 1 deletion src/ArbLatticeAccess.hpp.Rt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ real_to_storage = function(val,f) storage_convert("real_to_storage",val,f)
#ifndef ARBLATTICEACCESS_HPP
#define ARBLATTICEACCESS_HPP

#include <cassert>

#include "ArbLatticeContainer.hpp"
#include "StorageConversions.h"
#include "cross.h"
Expand Down Expand Up @@ -113,10 +115,14 @@ for (s in rows(all_stages)) { ?>
template<class Node>
CudaDeviceFunction void pop<?%s s$suffix ?>(Node& node) const {
<?R
for (d in rows(Density)[s$load.densities]) {
dens = Density;
dens$load = s$load.densities;
for (d in rows(dens)) if (d$load) {
f = rows(Fields)[[match(d$field, Fields$name)]]
dp = c(-d$dx, -d$dy, -d$dz)
cat(paste0(" node.", d$name, " = load_", f$nicename, "<"), paste(dp, collapse=", "), " >();\n")
} else if (!is.na(d$default)) { ?>
<?%s paste("node",d$name,sep=".") ?> = <?%f d$default ?>; <?R
} ?> }
template<class Node>
CudaDeviceFunction void push<?%s s$suffix ?>(const Node& node) const {
Expand Down
12 changes: 3 additions & 9 deletions src/CartConnectivity.hpp.Rt → src/CartConnectivity.hpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
<?R
source("conf.R")
c_header();
?>

#ifndef CARTCONNECTIVITY_HPP
#define CARTCONNECTIVITY_HPP

#include "cross.h"
#include "pinned_allocator.hpp"
#include "Region.h"

/// Information on connectivity of a processor
struct NodeInfo {
lbRegion region; ///< Local Lattice region <?R
for (m in Margin) { ?>
int <?%s m$side ?>; ///< MPI rank of the processor on [<?%2d m$dx ?>,<?%2d m$dy?>,<?%2d m$dz?>] side <?R
} ?>
lbRegion region; ///< Local Lattice region
int side[27];
};

struct CartConnectivity {
Expand Down
Loading

0 comments on commit 6faf88e

Please sign in to comment.