Enabled shots in HPC virtualization
Signed-off-by: Daniel Claudino <[email protected]>
danclaudino committed Dec 18, 2023
1 parent d1edaa7 commit 59bf2ee
Showing 5 changed files with 295 additions and 86 deletions.
18 changes: 16 additions & 2 deletions quantum/plugins/decorators/hpc-virtualization/MPIProxy.cpp
@@ -5,13 +5,27 @@ Copyright (C) 2018-2021 Oak Ridge National Laboratory (UT-Battelle) **/

#include "MPIProxy.hpp"

#include "mpi.h"

#include <cstdlib>

#include <iostream>
#include <algorithm>

// Specializations mapping C++ element types to their MPI datatype tags:
template <>
MPI_Datatype MPIDataTypeResolver<int>::getMPIDatatype() {
  return MPI_INT;
}

template <>
MPI_Datatype MPIDataTypeResolver<double>::getMPIDatatype() {
  return MPI_DOUBLE;
}

template <>
MPI_Datatype MPIDataTypeResolver<char>::getMPIDatatype() {
  return MPI_CHAR;
}

namespace xacc {

//Temporary buffers:
@@ -129,4 +143,4 @@ std::shared_ptr<ProcessGroup> ProcessGroup::split(int my_subgroup) const
return subgroup;
}

} //namespace xacc
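
Only int, double, and char are resolved by the specializations above. As a hedged sketch, one further element type could be supported by following the same pattern; the float specialization below is an illustration, not part of this commit:

// Hypothetical extension (not in this commit): resolve float elements.
template <>
MPI_Datatype MPIDataTypeResolver<float>::getMPIDatatype() {
  return MPI_FLOAT;
}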
53 changes: 52 additions & 1 deletion quantum/plugins/decorators/hpc-virtualization/MPIProxy.hpp
@@ -9,6 +9,13 @@ Copyright (C) 2018-2021 Oak Ridge National Laboratory (UT-Battelle) **/
#include <vector>
#include <memory>
#include <cassert>
#include "mpi.h"

// Resolves a C++ element type T to its MPI datatype tag
// (specializations for int, double, and char live in MPIProxy.cpp).
template <typename T>
class MPIDataTypeResolver {
public:
  MPI_Datatype getMPIDatatype();
};

namespace xacc {

@@ -142,13 +149,57 @@ class ProcessGroup {
different MPI processes, thus putting them into disjoint subgroups. **/
std::shared_ptr<ProcessGroup> split(int my_subgroup) const;


// Convenience wrappers around common MPI collectives.

// These could be folded into a single function, but that would
// stretch template specialization further than it is worth here.

// Broadcasts a single element (int/char/double) from rank 0.
// The element is taken by reference so that non-root ranks
// actually receive the broadcast value.
template<typename T>
void broadcast(T &element) {
  MPIDataTypeResolver<T> resolver;
  MPI_Datatype mpiType = resolver.getMPIDatatype();
  MPI_Bcast(&element, 1, mpiType, 0,
            this->getMPICommProxy().getRef<MPI_Comm>());
}

// Broadcasts a vector from rank 0. The vector must already be
// sized identically on every rank, since MPI_Bcast does not
// communicate lengths.
template<typename T>
void broadcast(std::vector<T> &vec) {
  MPIDataTypeResolver<T> resolver;
  MPI_Datatype mpiType = resolver.getMPIDatatype();
  MPI_Bcast(vec.data(), static_cast<int>(vec.size()), mpiType, 0,
            this->getMPICommProxy().getRef<MPI_Comm>());
}


// Gathers the contents of each rank's local vector into a global
// vector on all ranks (MPI_Allgatherv). nLocalData holds the element
// count contributed by each rank; shift holds the displacement at
// which each rank's block starts in the global vector. The caller
// must size `global` to the sum of all counts beforehand.
template<typename T>
void allGatherv(std::vector<T> &local,
                std::vector<T> &global,
                std::vector<int> &nLocalData,
                std::vector<int> &shift) {
  MPIDataTypeResolver<T> resolver;
  MPI_Datatype mpiType = resolver.getMPIDatatype();
  MPI_Allgatherv(local.data(), static_cast<int>(local.size()), mpiType,
                 global.data(), nLocalData.data(),
                 shift.data(), mpiType,
                 this->getMPICommProxy().getRef<MPI_Comm>());
}

protected:

std::vector<unsigned int> process_ranks_; //global ranks of the MPI processes forming the process group
MPICommProxy intra_comm_; //associated MPI intra-communicator
std::size_t mem_per_process_; //dynamic memory limit per process (bytes)

};

} //namespace xacc

#endif //XACC_MPI_COMM_PROXY_HPP_
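
The wrappers assume the caller prepares the count and displacement arrays itself. Below is a minimal usage sketch, assuming a ProcessGroup `group` whose intra-communicator spans `nRanks` processes; the names `gatherExample`, `myRank`, and `nRanks` are illustrative and not part of this commit:

#include "MPIProxy.hpp"
#include <vector>

void gatherExample(xacc::ProcessGroup &group, int myRank, int nRanks) {
  // Uneven per-rank payloads: rank r contributes r+1 doubles.
  std::vector<double> local(myRank + 1, static_cast<double>(myRank));

  // Each rank must know every rank's count; gather those first.
  std::vector<int> nLocalData(nRanks);
  int mine = static_cast<int>(local.size());
  MPI_Allgather(&mine, 1, MPI_INT, nLocalData.data(), 1, MPI_INT,
                group.getMPICommProxy().getRef<MPI_Comm>());

  // Displacements are the exclusive prefix sum of the counts.
  std::vector<int> shift(nRanks, 0);
  for (int r = 1; r < nRanks; ++r)
    shift[r] = shift[r - 1] + nLocalData[r - 1];

  // Size the destination to the total count, then gather everywhere.
  std::vector<double> global(shift[nRanks - 1] + nLocalData[nRanks - 1]);
  group.allGatherv(local, global, nLocalData, shift);

  // A single-element broadcast from rank 0, e.g. a status flag:
  int status = 0;
  group.broadcast(status);
}

Since MPI_Allgatherv takes per-rank counts and displacements, the exclusive prefix sum above is the standard way to lay the blocks out contiguously in the destination vector.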