Skip to content

Commit

Permalink
simd SDosageAbsDotprod
Browse files Browse the repository at this point in the history
  • Loading branch information
chrchang committed Oct 11, 2023
1 parent 97ab45e commit 4e42b43
Showing 1 changed file with 98 additions and 11 deletions.
109 changes: 98 additions & 11 deletions 2.0/plink2_ld.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3312,7 +3312,7 @@ uint64_t DosageUnsignedNomissDotprod(const Dosage* dosage_vec0, const Dosage* do
}
}

int64_t DosageSignedDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
int64_t SDosageDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
const __m256i* dphase_delta0_iter = R_CAST(const __m256i*, dphase_delta0);
const __m256i* dphase_delta1_iter = R_CAST(const __m256i*, dphase_delta1);
const __m256i m16 = _mm256_set1_epi64x(kMask0000FFFF);
Expand Down Expand Up @@ -3361,6 +3361,44 @@ int64_t DosageSignedDotprod(const SDosage* dphase_delta0, const SDosage* dphase_
dotprod += UniVecHsum32(acc_lo) + 65536 * UniVecHsum32(acc_hi);
}
}

uint64_t SDosageAbsDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
const __m256i* dphase_delta0_iter = R_CAST(const __m256i*, dphase_delta0);
const __m256i* dphase_delta1_iter = R_CAST(const __m256i*, dphase_delta1);
const __m256i m16 = _mm256_set1_epi64x(kMask0000FFFF);
uint64_t dotprod = 0;
for (uint32_t vecs_left = vec_ct; ; ) {
__m256i dotprod_lo = _mm256_setzero_si256();
__m256i dotprod_hi = _mm256_setzero_si256();
const __m256i* dphase_delta0_stop;
if (vecs_left < 4096) {
if (!vecs_left) {
return dotprod;
}
dphase_delta0_stop = &(dphase_delta0_iter[vecs_left]);
vecs_left = 0;
} else {
dphase_delta0_stop = &(dphase_delta0_iter[4096]);
vecs_left -= 4096;
}
do {
const __m256i absdosage0 = _mm256_abs_epi16(*dphase_delta0_iter++);
const __m256i absdosage1 = _mm256_abs_epi16(*dphase_delta1_iter++);

__m256i lo16 = _mm256_mullo_epi16(absdosage0, absdosage1);
__m256i hi16 = _mm256_mulhi_epu16(absdosage0, absdosage1);
lo16 = _mm256_add_epi64(_mm256_and_si256(lo16, m16), _mm256_and_si256(_mm256_srli_epi64(lo16, 16), m16));
hi16 = _mm256_and_si256(_mm256_add_epi64(hi16, _mm256_srli_epi64(hi16, 16)), m16);
dotprod_lo = _mm256_add_epi64(dotprod_lo, lo16);
dotprod_hi = _mm256_add_epi64(dotprod_hi, hi16);
} while (dphase_delta0_iter < dphase_delta0_stop);
UniVec acc_lo;
UniVec acc_hi;
acc_lo.vw = R_CAST(VecW, dotprod_lo);
acc_hi.vw = R_CAST(VecW, dotprod_hi);
dotprod += UniVecHsum32(acc_lo) + 65536 * UniVecHsum32(acc_hi);
}
}
# else // !USE_AVX2
void FillDosageHet(const Dosage* dosage_vec, uint32_t dosagev_ct, Dosage* dosage_het) {
const __m128i* dosage_vvec_iter = R_CAST(const __m128i*, dosage_vec);
Expand Down Expand Up @@ -3550,7 +3588,7 @@ uint64_t DosageUnsignedNomissDotprod(const Dosage* dosage_vec0, const Dosage* do
}
}

int64_t DosageSignedDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
int64_t SDosageDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
const __m128i* dphase_delta0_iter = R_CAST(const __m128i*, dphase_delta0);
const __m128i* dphase_delta1_iter = R_CAST(const __m128i*, dphase_delta1);
const __m128i m16 = _mm_set1_epi64x(kMask0000FFFF);
Expand Down Expand Up @@ -3594,6 +3632,55 @@ int64_t DosageSignedDotprod(const SDosage* dphase_delta0, const SDosage* dphase_
dotprod += UniVecHsum32(acc_lo) + 65536 * UniVecHsum32(acc_hi);
}
}

uint64_t SDosageAbsDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
const __m128i* dphase_delta0_iter = R_CAST(const __m128i*, dphase_delta0);
const __m128i* dphase_delta1_iter = R_CAST(const __m128i*, dphase_delta1);
const __m128i m16 = _mm_set1_epi64x(kMask0000FFFF);
# ifndef USE_SSE42
const __m128i zero = _mm_setzero_si128();
# endif
uint64_t dotprod = 0;
for (uint32_t vecs_left = vec_ct; ; ) {
__m128i dotprod_lo = _mm_setzero_si128();
__m128i dotprod_hi = _mm_setzero_si128();
const __m128i* dphase_delta0_stop;
if (vecs_left < 8192) {
if (!vecs_left) {
return dotprod;
}
dphase_delta0_stop = &(dphase_delta0_iter[vecs_left]);
vecs_left = 0;
} else {
dphase_delta0_stop = &(dphase_delta0_iter[8192]);
vecs_left -= 8192;
}
do {
const __m128i dosage0 = *dphase_delta0_iter++;
const __m128i dosage1 = *dphase_delta1_iter++;
# ifdef USE_SSE42
const __m128i absdosage0 = _mm_abs_epi16(dosage0);
const __m128i absdosage1 = _mm_abs_epi16(dosage1);
# else
const __m128i negdosage0 = _mm_sub_epi16(zero, dosage0);
const __m128i negdosage1 = _mm_sub_epi16(zero, dosage1);
const __m128i absdosage0 = _mm_max_epi16(dosage0, negdosage0);
const __m128i absdosage1 = _mm_max_epi16(dosage1, negdosage1);
# endif
__m128i lo16 = _mm_mullo_epi16(absdosage0, absdosage1);
__m128i hi16 = _mm_mulhi_epu16(absdosage0, absdosage1);
lo16 = _mm_add_epi64(_mm_and_si128(lo16, m16), _mm_and_si128(_mm_srli_epi64(lo16, 16), m16));
hi16 = _mm_and_si128(_mm_add_epi64(hi16, _mm_srli_epi64(hi16, 16)), m16);
dotprod_lo = _mm_add_epi64(dotprod_lo, lo16);
dotprod_hi = _mm_add_epi64(dotprod_hi, hi16);
} while (dphase_delta0_iter < dphase_delta0_stop);
UniVec acc_lo;
UniVec acc_hi;
acc_lo.vw = R_CAST(VecW, dotprod_lo);
acc_hi.vw = R_CAST(VecW, dotprod_hi);
dotprod += UniVecHsum32(acc_lo) + 65536 * UniVecHsum32(acc_hi);
}
}
# endif // !USE_AVX2
#else // !__LP64__
void FillDosageHet(const Dosage* dosage_vec, uint32_t dosagev_ct, Dosage* dosage_het) {
Expand Down Expand Up @@ -3661,7 +3748,7 @@ uint64_t DosageUnsignedNomissDotprod(const Dosage* dosage_vec0, const Dosage* do
return dotprod;
}

int64_t DosageSignedDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
int64_t SDosageDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
const uint32_t sample_ctav = vec_ct * kDosagePerVec;
int64_t dotprod = 0;
for (uint32_t sample_idx = 0; sample_idx != sample_ctav; ++sample_idx) {
Expand All @@ -3671,9 +3758,8 @@ int64_t DosageSignedDotprod(const SDosage* dphase_delta0, const SDosage* dphase_
}
return dotprod;
}
#endif
// todo: optimize this
uint64_t DosageAbsDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {

uint64_t SDosageAbsDotprod(const SDosage* dphase_delta0, const SDosage* dphase_delta1, uint32_t vec_ct) {
const uint32_t sample_ctav = vec_ct * kDosagePerVec;
uint64_t dotprod = 0;
for (uint32_t sample_idx = 0; sample_idx != sample_ctav; ++sample_idx) {
Expand All @@ -3683,6 +3769,7 @@ uint64_t DosageAbsDotprod(const SDosage* dphase_delta0, const SDosage* dphase_de
}
return dotprod;
}
#endif

uint32_t DosageR2Prod(const Dosage* dosage_vec0, const uintptr_t* nm_bitvec0, const Dosage* dosage_vec1, const uintptr_t* nm_bitvec1, uint32_t sample_ct, uint32_t nm_ct0, uint32_t nm_ct1, uint64_t* __restrict nmaj_dosages, uint64_t* __restrict dosageprod_ptr) {
const uint32_t sample_ctl = BitCtToWordCt(sample_ct);
Expand Down Expand Up @@ -4172,8 +4259,8 @@ PglErr LdConsole(const uintptr_t* variant_include, const ChrInfo* cip, const cha
hethet_present = (hethet_dosageprod != 0);
uint64_t uhethet_dosageprod = hethet_dosageprod;
if (use_phase && hethet_present) {
dosageprod = S_CAST(int64_t, dosageprod) + DosageSignedDotprod(main_dphase_deltas[0], main_dphase_deltas[1], founder_dosagev_ct);
uhethet_dosageprod -= DosageAbsDotprod(main_dphase_deltas[0], main_dphase_deltas[1], founder_dosagev_ct);
dosageprod = S_CAST(int64_t, dosageprod) + SDosageDotprod(main_dphase_deltas[0], main_dphase_deltas[1], founder_dosagev_ct);
uhethet_dosageprod -= SDosageAbsDotprod(main_dphase_deltas[0], main_dphase_deltas[1], founder_dosagev_ct);
}
nmajsums_d[0] = u63tod(nmaj_dosages[0]) * kRecipDosageMid;
nmajsums_d[1] = u63tod(nmaj_dosages[1]) * kRecipDosageMid;
Expand Down Expand Up @@ -4205,7 +4292,7 @@ PglErr LdConsole(const uintptr_t* variant_include, const ChrInfo* cip, const cha
uint64_t invalid_uhethet_dosageprod = male_hethet_dosageprod;
if (use_phase) {
BitvecInvmask(R_CAST(uintptr_t*, x_male_dosage_invmask), founder_dosagev_ct * kWordsPerVec, R_CAST(uintptr_t*, main_dphase_deltas[0]));
invalid_uhethet_dosageprod -= DosageAbsDotprod(main_dphase_deltas[0], main_dphase_deltas[1], founder_dosagev_ct);
invalid_uhethet_dosageprod -= SDosageAbsDotprod(main_dphase_deltas[0], main_dphase_deltas[1], founder_dosagev_ct);
}
unknown_hethet_d -= u63tod(invalid_uhethet_dosageprod) * kRecipDosageMidSq;
known_dotprod_d += u63tod(invalid_uhethet_dosageprod) * (kRecipDosageMidSq * 0.5);
Expand Down Expand Up @@ -5537,8 +5624,8 @@ uint32_t ComputeR2DosagePhasedStats(const R2DosageVariant* dp0, const R2DosageVa
if ((phase_type == kR2PhaseTypePresent) && (uhethet_dosageprod != 0)) {
const SDosage* dphase_delta0 = dp0->dense_dphase_delta;
const SDosage* dphase_delta1 = dp1->dense_dphase_delta;
dosageprod = S_CAST(int64_t, dosageprod) + DosageSignedDotprod(dphase_delta0, dphase_delta1, sample_dosagev_ct);
uhethet_dosageprod -= DosageAbsDotprod(dphase_delta0, dphase_delta1, sample_dosagev_ct);
dosageprod = S_CAST(int64_t, dosageprod) + SDosageDotprod(dphase_delta0, dphase_delta1, sample_dosagev_ct);
uhethet_dosageprod -= SDosageAbsDotprod(dphase_delta0, dphase_delta1, sample_dosagev_ct);
}
}
nmajsums_d[0] = u63tod(nmaj_dosages[0]) * kRecipDosageMid;
Expand Down

0 comments on commit 4e42b43

Please sign in to comment.