a little optimization of vadd/vsub
herumi committed Aug 26, 2024
1 parent be17da7 commit 6bf0400
Showing 2 changed files with 38 additions and 12 deletions.
src/gen_bint_x64.py: 34 additions & 8 deletions
@@ -63,7 +63,7 @@ def gen_vsubPre(mont, vN=1):
def gen_vadd(mont, vN=1):
SUF = 'A' if vN == 2 else ''
with FuncProc(MSM_PRE+'vadd'+SUF):
with StackFrame(3, 0, useRCX=True, vNum=mont.N*2+2, vType=T_ZMM) as sf:
with StackFrame(3, 0, useRCX=True, vNum=mont.N*2+3, vType=T_ZMM) as sf:
regs = list(reversed(sf.v))
W = mont.W
N = mont.N
@@ -75,6 +75,7 @@ def gen_vadd(mont, vN=1):
t = pops(regs, N)
vmask = pops(regs, 1)[0]
c = pops(regs, 1)[0]
zero = pops(regs, 1)[0]

mov(rax, mont.mask)
vpbroadcastq(vmask, rax)
@@ -104,7 +105,6 @@ def gen_vadd(mont, vN=1):
if i > 0:
vpsubq(t[i], t[i], c)
vpsrlq(c, t[i], S)
un(vpandq)(t, t, vmask)
else:
# a little faster
# s = x+y
@@ -124,14 +124,13 @@
if i > 0:
vpsubq(t[i], t[i], c);
vpsrlq(c, t[i], S)
vpandq(t[i], t[i], vmask)

vpxorq(vmask, vmask, vmask)
vpcmpgtq(k1, c, vmask) # k1 = t<0
vpxorq(zero, zero, zero)
vpcmpeqq(k1, c, zero) # k1 = t>=0
# z = select(k1, s, t)
for i in range(N):
vmovdqa64(t[i]|k1, s[i])
un(vmovdqa64)(ptr(z), t)
vpandq(s[i]|k1, t[i], vmask)
un(vmovdqa64)(ptr(z), s)

if vN == 2:
add(x, 64)
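
The gen_vadd rewrite reserves one extra ZMM register (vNum goes from mont.N*2+2 to mont.N*2+3) for a dedicated zero register and drops the per-limb vpandq from the t = s - p loop. The old tail zeroed vmask, built k1 with vpcmpgtq(k1, c, vmask) (k1 = t < 0), merged s into t with a masked vmovdqa64, and stored t; the new tail compares the final borrow against zero with vpcmpeqq(k1, c, zero) (k1 = t >= 0) and lets a single masked vpandq both apply the limb mask and perform the select into s, which is then stored. The following is a per-lane scalar sketch of the resulting algorithm, assuming W-bit limbs in 64-bit lanes (W = mont.W, e.g. 52), x, y < p, and textbook carry/borrow handling; the generated code works on whole ZMM registers and extracts carries with vector shifts, so treat this as a model, not the emitted assembly.

#include <cstdint>
#include <vector>

// Per-lane scalar model of the generated vadd: z = x + y (mod p) with W-bit
// limbs held in 64-bit lanes. W and N stand in for mont.W / mont.N.
static const int W = 52;
static const uint64_t MASK = (uint64_t(1) << W) - 1;

std::vector<uint64_t> vaddModel(const std::vector<uint64_t>& x,
                                const std::vector<uint64_t>& y,
                                const std::vector<uint64_t>& p)
{
    const size_t N = p.size();
    std::vector<uint64_t> s(N), t(N);
    uint64_t c = 0;
    // s = x + y with carry propagation
    for (size_t i = 0; i < N; i++) {
        uint64_t v = x[i] + y[i] + c;
        s[i] = v & MASK;
        c = v >> W;
    }
    // c is 0 here when p leaves headroom in N W-bit limbs (assumed)
    // t = s - p; the old code applied vpandq to every t[i] in this loop
    uint64_t b = 0;
    for (size_t i = 0; i < N; i++) {
        uint64_t v = s[i] - p[i] - b;
        b = v >> 63;        // borrow out of this limb
        t[i] = v;           // left unmasked, as in the new code
    }
    // k1 = (b == 0), i.e. s >= p: the masked vpandq writes t & MASK into s
    if (b == 0) {
        for (size_t i = 0; i < N; i++) s[i] = t[i] & MASK;
    }
    return s;               // un(vmovdqa64)(ptr(z), s)
}

Folding the mask into the final masked vpandq removes N standalone vpandq instructions per call at the price of one extra ZMM register (the new zero).
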
@@ -226,6 +225,32 @@ def vmulUnitAdd(z, px, y, N, H, t):
vpxorq(z[N], z[N], z[N])
vmulH(z[N], t, y)

def gen_vmul(mont):
with FuncProc(MSM_PRE+'vmul'):
with StackFrame(3, 0, vNum=mont.N*2+4, vType=T_ZMM) as sf:
regs = list(reversed(sf.v))
W = mont.W
N = mont.N
pz = sf.p[0]
px = sf.p[1]
py = sf.p[2]

t = pops(regs, N*2)
vmask = pops(regs, 1)[0]
c = pops(regs, 1)[0]
y = pops(regs, 1)[0]
H = pops(regs, 1)[0]

mov(rax, mont.mask)
vpbroadcastq(vmask, rax)

un = genUnrollFunc()

vmovdqa64(y, ptr(py))
un(vmovdqa64)(t[0:N], ptr(pz))
vmulUnitAdd(t, px, y, N, H, c)
un(vmovdqa64)(ptr(pz), t[0:N+1])

def msm_data(mont):
makeLabel(C_p)
dq_(', '.join(map(hex, mont.toArray(mont.p))))
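
The new gen_vmul wraps vmulUnitAdd: it loads the N-limb accumulator t[0:N] from pz, loads one multiplier word per lane from py, multiplies by the limbs at px, and writes the N+1-limb result back to pz. At the value level this is a single-word multiply-accumulate built from lo/hi 52-bit partial products in the style of the IFMA instructions behind vmulH; the sketch below models one lane, and the limb width, the redundant (not fully carried) output form, and the helper name are assumptions rather than code taken from the generator.

#include <cstdint>
#include <vector>

// Per-lane value model of vmulUnitAdd as used by gen_vmul:
//   out[0..N] = z[0..N-1] + x[0..N-1] * y, with W-bit limbs in 64-bit lanes.
std::vector<uint64_t> vmulUnitAddModel(const std::vector<uint64_t>& z,
                                       const std::vector<uint64_t>& x,
                                       uint64_t y)
{
    const int W = 52;                                // assumed mont.W
    const uint64_t MASK = (uint64_t(1) << W) - 1;
    const size_t N = x.size();
    std::vector<uint64_t> out(z);
    out.resize(N + 1, 0);                            // vpxorq(z[N], z[N], z[N])
    for (size_t i = 0; i < N; i++) {
        unsigned __int128 prod = (unsigned __int128)x[i] * y;
        out[i]     += (uint64_t)prod & MASK;         // low 52 bits (vpmadd52luq-style)
        out[i + 1] += (uint64_t)(prod >> W);         // high bits (vmulH / vpmadd52huq-style)
    }
    return out;                                      // redundant form; carries not propagated
}

Keeping the output in N+1 limbs without propagating carries matches the generator storing t[0:N+1] straight back to pz.
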
@@ -234,9 +259,10 @@ def msm_code(mont):
for vN in [1, 2]:
gen_vaddPre(mont, vN)
gen_vsubPre(mont, vN)
gen_vadd(mont, vN)

gen_vadd(mont)
gen_vsub(mont)
gen_vmul(mont)

SUF='_fast'
param=None
src/msm_avx.cpp: 4 additions & 4 deletions
@@ -31,8 +31,8 @@ void mcl_c5_vsubPreA(VecA *, const VecA *, const VecA *);

void mcl_c5_vadd(Vec *, const Vec *, const Vec *);
void mcl_c5_vsub(Vec *, const Vec *, const Vec *);
//void mcl_c5_vaddA(VecA *, const VecA *, const VecA *);

void mcl_c5_vmul(Vec *, const Vec *, const Vec *);
void mcl_c5_vaddA(VecA *, const VecA *, const VecA *);

}

@@ -169,7 +169,7 @@ inline void vadd(Vec *z, const Vec *x, const Vec *y)
{
mcl_c5_vadd(z, x, y);
}
#if 0
#if 1
template<>
inline void vadd(VecA *z, const VecA *x, const VecA *y)
{
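
On the C++ side the commit declares mcl_c5_vmul, un-comments mcl_c5_vaddA, turns the #if 0 guard around the VecA specialization of vadd into #if 1, and re-enables the vaddA benchmark line, matching the generator now emitting gen_vadd for both vN=1 and vN=2. The code below is only a self-contained sketch of that specialization-based dispatch pattern; Vec, VecA, and the stub bodies are placeholders, not the definitions used in msm_avx.cpp.

#include <cstdint>
#include <cstdio>

// Placeholder types standing in for the real Vec / VecA.
struct Vec  { uint64_t v[8]; };      // one 8-lane vector element (assumed shape)
struct VecA { Vec lo, hi; };         // the 2x-wide variant (assumed shape)

// Stubs standing in for the generated asm entry points mcl_c5_vadd / mcl_c5_vaddA.
static void mcl_c5_vadd(Vec *z, const Vec *x, const Vec *y)
{
    for (int i = 0; i < 8; i++) z->v[i] = x->v[i] + y->v[i];
}
static void mcl_c5_vaddA(VecA *z, const VecA *x, const VecA *y)
{
    mcl_c5_vadd(&z->lo, &x->lo, &y->lo);
    mcl_c5_vadd(&z->hi, &x->hi, &y->hi);
}

// Generic template with explicit specializations routing each width to asm.
template<class V> void vadd(V *z, const V *x, const V *y);

template<>
inline void vadd<Vec>(Vec *z, const Vec *x, const Vec *y)
{
    mcl_c5_vadd(z, x, y);            // vN=1 path
}

#if 1 // was 0: enabled now that the generator also emits the vN=2 add
template<>
inline void vadd<VecA>(VecA *z, const VecA *x, const VecA *y)
{
    mcl_c5_vaddA(z, x, y);           // vN=2 path
}
#endif

int main()
{
    VecA x = {}, y = {}, z = {};
    x.lo.v[0] = 1;
    y.lo.v[0] = 2;
    vadd(&z, &x, &y);                // picks the VecA specialization
    printf("%llu\n", (unsigned long long)z.lo.v[0]); // prints 3
}

Dispatching through explicit specializations keeps the calling code width-agnostic while letting each width hit its own hand-written asm entry point.
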
@@ -1727,7 +1727,7 @@ CYBOZU_TEST_AUTO(vaddPre)
CYBOZU_BENCH_C("asm vsubPreA", C, mcl_c5_vsubPreA, za.v, za.v, xa.v);
CYBOZU_BENCH_C("asm vadd", C, mcl_c5_vadd, z[0].v, z[0].v, x[0].v);
CYBOZU_BENCH_C("asm vsub", C, mcl_c5_vsub, z[0].v, z[0].v, x[0].v);
// CYBOZU_BENCH_C("asm vaddA", C, mcl_c5_vaddA, za.v, za.v, xa.v);
CYBOZU_BENCH_C("asm vaddA", C, mcl_c5_vaddA, za.v, za.v, xa.v);
#endif
CYBOZU_BENCH_C("vadd::Vec", C, vadd, z[0].v, z[0].v, x[0].v);
CYBOZU_BENCH_C("vsub::Vec", C, vsub, z[0].v, z[0].v, x[0].v);