From bea30762acaefee54ebaf3c68713b66414345e12 Mon Sep 17 00:00:00 2001 From: Tom Lin Date: Sat, 2 Sep 2023 05:56:26 +0100 Subject: [PATCH] Add working thrust implementation --- src/thrust/fasten.hpp | 86 ++++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/src/thrust/fasten.hpp b/src/thrust/fasten.hpp index 7696861..521138e 100644 --- a/src/thrust/fasten.hpp +++ b/src/thrust/fasten.hpp @@ -19,27 +19,22 @@ template class IMPL_CLS final : public Bude { public: - static void fasten_main(const Params &p, std::vector &results) { - - thrust::device_vector protein(p.protein); - thrust::device_vector ligand(p.ligand); - thrust::device_vector transforms_0(p.poses[0]); - thrust::device_vector transforms_1(p.poses[1]); - thrust::device_vector transforms_2(p.poses[2]); - thrust::device_vector transforms_3(p.poses[3]); - thrust::device_vector transforms_4(p.poses[4]); - thrust::device_vector transforms_5(p.poses[5]); - thrust::device_vector forcefield(p.forcefield); - thrust::device_vector energies(results.size()); - - thrust::device_vector> out(p.nposes() / PPWI); + static void fasten_main(const Params &p, // + thrust::device_vector &protein, thrust::device_vector &ligand, + thrust::device_vector &transforms_0, thrust::device_vector &transforms_1, + thrust::device_vector &transforms_2, thrust::device_vector &transforms_3, + thrust::device_vector &transforms_4, thrust::device_vector &transforms_5, + thrust::device_vector &forcefield, thrust::device_vector &energies) { thrust::counting_iterator groups(0); - thrust::transform( // - groups, // - groups + (p.nposes() / PPWI), // - out.begin(), // - [=] __device__ __host__(const int group) { + thrust::for_each( + groups, groups + (p.nposes() / PPWI), + [natlig = p.natlig(), natpro = p.natpro(), // + protein = protein.data(), ligand = ligand.data(), // + transforms_0 = transforms_0.data(), transforms_1 = transforms_1.data(), transforms_2 = transforms_2.data(), // + transforms_3 = transforms_3.data(), transforms_4 = transforms_4.data(), transforms_5 = transforms_5.data(), // + forcefield = forcefield.data(), // + energies = energies.data()] __device__ __host__(const int group) { std::array, 3>, PPWI> transform = {}; std::array etot = {}; @@ -69,8 +64,9 @@ template class IMPL_CLS final : public Bude { } // Loop over ligand atoms - for (const Atom &l_atom : ligand) { - const FFParams l_params = forcefield[l_atom.type]; + for (size_t il = 0; il < natlig; il++) { + const Atom &l_atom = ligand[il]; + const FFParams &l_params = forcefield[l_atom.type]; const int lhphb_ltz = l_params.hphb < ZERO; const int lhphb_gtz = l_params.hphb > ZERO; @@ -87,9 +83,10 @@ template class IMPL_CLS final : public Bude { } // Loop over protein atoms - for (const Atom &p_atom : protein) { - // // Load protein atom data - const FFParams p_params = forcefield[p_atom.type]; + for (size_t ip = 0; ip < natpro; ip++) { + // Load protein atom data + const Atom &p_atom = protein[ip]; + const FFParams &p_params = forcefield[p_atom.type]; const float radij = p_params.radius + l_params.radius; const float r_radij = ONE / radij; @@ -141,24 +138,11 @@ template class IMPL_CLS final : public Bude { } } - ////#pragma omp simd - // for (int l = 0; l < PPWI; l++) { - // etot[l] *= HALF; - // } - // - // return std::make_pair(group, etot); - - // Write result - //#pragma omp simd - // for (int l = 0; l < PPWI; l++) { - // energies[group * PPWI + l] = etot[l] *= HALF; - // } - +// Write result #pragma omp simd for (int l = 0; l < PPWI; l++) { - etot[l] *= HALF; + energies[group * PPWI + l] = etot[l] *= HALF; } - return etot; }); } @@ -192,7 +176,11 @@ template class IMPL_CLS final : public Bude { checkError(IMPL_FN__(GetDeviceCount(&count))); std::vector devices(count); for (int i = 0; i < count; ++i) { - IMPL_FN__(DeviceProp) props{}; + #if defined(__HIP_PLATFORM_HCC__) // can't use IMPL_TYPE__ here because of the extra _t suffix, thanks AMD + hipDeviceProp_t props{}; + #else + cudaDeviceProp props{}; + #endif checkError(IMPL_FN__(GetDeviceProperties(&props, i))); devices[i] = {i, std::string(props.name) + " (" + // std::to_string(props.totalGlobalMem / 1024 / 1024) + "MB;" + // @@ -227,13 +215,29 @@ template class IMPL_CLS final : public Bude { Sample sample(PPWI, wgsize, p.nposes()); + thrust::device_vector protein(p.protein); + thrust::device_vector ligand(p.ligand); + thrust::device_vector transforms_0(p.poses[0]); + thrust::device_vector transforms_1(p.poses[1]); + thrust::device_vector transforms_2(p.poses[2]); + thrust::device_vector transforms_3(p.poses[3]); + thrust::device_vector transforms_4(p.poses[4]); + thrust::device_vector transforms_5(p.poses[5]); + thrust::device_vector forcefield(p.forcefield); + thrust::device_vector energies(sample.energies.size()); + for (size_t i = 0; i < p.totalIterations(); ++i) { auto kernelStart = now(); - fasten_main(p, sample.energies); + fasten_main(p, protein, ligand, // + transforms_0, transforms_1, transforms_2, // + transforms_3, transforms_4, transforms_5, // + forcefield, energies // + ); synchronise(); auto kernelEnd = now(); sample.kernelTimes.emplace_back(kernelStart, kernelEnd); } + thrust::copy(energies.begin(), energies.end(), sample.energies.begin()); return sample; };