chacha20: Add avx2::StateWord methods for required operations
str4d committed Aug 28, 2021
1 parent 687f953 commit 3dba685
Showing 1 changed file with 76 additions and 35 deletions.
111 changes: 76 additions & 35 deletions chacha20/src/backend/avx2.rs
@@ -31,6 +31,59 @@ union StateWord {
avx: __m256i,
}

impl StateWord {
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_assign_epi32(&mut self, rhs: &Self) {
self.avx = _mm256_add_epi32(self.avx, rhs.avx);
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn xor_assign(&mut self, rhs: &Self) {
self.avx = _mm256_xor_si256(self.avx, rhs.avx);
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn shuffle_epi32<const MASK: i32>(&mut self) {
self.avx = _mm256_shuffle_epi32(self.avx, MASK);
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rol<const BY: i32, const REST: i32>(&mut self) {
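        // With REST = 32 - BY (as at the call sites below), the shift/XOR pair is a
        // rotate-left of each 32-bit lane by BY bits, e.g. `rol::<12, 20>()` rotates by 12.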
self.avx = _mm256_xor_si256(
_mm256_slli_epi32(self.avx, BY),
_mm256_srli_epi32(self.avx, REST),
);
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rol_8(&mut self) {
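        // A rotation by a whole byte needs no shifts: a single `_mm256_shuffle_epi8`
        // moves every byte of each 32-bit word up one position (the top byte wraps
        // around to the bottom), which is exactly a rotate-left by 8.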
self.avx = _mm256_shuffle_epi8(
self.avx,
_mm256_set_epi8(
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8, 11,
6, 5, 4, 7, 2, 1, 0, 3,
),
);
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rol_16(&mut self) {
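        // Same idea as `rol_8`: the shuffle swaps the two 16-bit halves of each
        // 32-bit word, i.e. a rotate by 16.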
self.avx = _mm256_shuffle_epi8(
self.avx,
_mm256_set_epi8(
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10,
5, 4, 7, 6, 1, 0, 3, 2,
),
);
}
}
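
For reference, a minimal scalar sketch of what each helper computes on every 32-bit lane (the AVX2 methods above do the same thing on all eight lanes of the 256-bit register at once; these free functions are illustrative only, not part of the crate):

    fn add_assign_epi32(x: u32, y: u32) -> u32 { x.wrapping_add(y) } // lane-wise wrapping add
    fn xor_assign(x: u32, y: u32) -> u32 { x ^ y }                   // lane-wise XOR
    fn rol<const BY: u32>(x: u32) -> u32 { x.rotate_left(BY) }       // covers rol::<BY, 32 - BY>, rol_8 and rol_16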

/// The ChaCha20 core function (AVX2 accelerated implementation for x86/x86_64)
// TODO(tarcieri): zeroize?
#[derive(Clone)]
@@ -104,16 +157,16 @@ impl<R: Rounds> Core<R> {
v2: &mut StateWord,
v3: &mut StateWord,
) {
- let v3_orig = v3.avx;
+ let v3_orig = *v3;

for _ in 0..(R::COUNT / 2) {
double_quarter_round(v0, v1, v2, v3);
}

- v0.avx = _mm256_add_epi32(v0.avx, self.v0.avx);
- v1.avx = _mm256_add_epi32(v1.avx, self.v1.avx);
- v2.avx = _mm256_add_epi32(v2.avx, self.v2.avx);
- v3.avx = _mm256_add_epi32(v3.avx, v3_orig);
+ v0.add_assign_epi32(&self.v0);
+ v1.add_assign_epi32(&self.v1);
+ v2.add_assign_epi32(&self.v2);
+ v3.add_assign_epi32(&v3_orig);
}
}
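// Note: the four `add_assign_epi32` calls after the round loop are ChaCha's
// feed-forward step (final state = permuted state + initial state). `v3` is
// supplied per call because it holds the words that change with the block
// counter, so its pre-round value is captured in `v3_orig` rather than read
// from `self`.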

@@ -221,9 +274,9 @@ unsafe fn rows_to_cols(
d: &mut StateWord,
) {
// c = ROR256_B(c); d = ROR256_C(d); a = ROR256_D(a);
- c.avx = _mm256_shuffle_epi32(c.avx, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
- d.avx = _mm256_shuffle_epi32(d.avx, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
- a.avx = _mm256_shuffle_epi32(a.avx, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
+ c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
+ d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
+ a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
}
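// `_MM_SHUFFLE(z, y, x, w)` packs four 2-bit lane selectors as
// (z << 6) | (y << 4) | (x << 2) | w, so _MM_SHUFFLE(0, 3, 2, 1) = 0b_00_11_10_01 = 0x39.
// With `_mm256_shuffle_epi32`, these constants rotate the four 32-bit lanes of each
// 128-bit half by one, two and three positions respectively, lining the diagonals of
// the ChaCha state up as columns so `add_xor_rot` can run the diagonal rounds with the
// same column-wise code; `cols_to_rows` below applies the inverse rotations.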

/// The goal of this function is to transform the state words from:
@@ -252,43 +305,31 @@ unsafe fn cols_to_rows(
d: &mut StateWord,
) {
// c = ROR256_D(c); d = ROR256_C(d); a = ROR256_B(a);
- c.avx = _mm256_shuffle_epi32(c.avx, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
- d.avx = _mm256_shuffle_epi32(d.avx, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
- a.avx = _mm256_shuffle_epi32(a.avx, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
+ c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
+ d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
+ a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
}
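// These are exactly the inverse lane rotations of `rows_to_cols` above,
// restoring the row-wise layout after the diagonal half of the double round.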

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_xor_rot(a: &mut StateWord, b: &mut StateWord, c: &mut StateWord, d: &mut StateWord) {
// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_16(d);
- a.avx = _mm256_add_epi32(a.avx, b.avx);
- d.avx = _mm256_xor_si256(d.avx, a.avx);
- d.avx = _mm256_shuffle_epi8(
-     d.avx,
-     _mm256_set_epi8(
-         13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 5,
-         4, 7, 6, 1, 0, 3, 2,
-     ),
- );
+ a.add_assign_epi32(b);
+ d.xor_assign(a);
+ d.rol_16();

// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_12(b);
- c.avx = _mm256_add_epi32(c.avx, d.avx);
- b.avx = _mm256_xor_si256(b.avx, c.avx);
- b.avx = _mm256_xor_si256(_mm256_slli_epi32(b.avx, 12), _mm256_srli_epi32(b.avx, 20));
+ c.add_assign_epi32(d);
+ b.xor_assign(c);
+ b.rol::<12, 20>();

// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_8(d);
- a.avx = _mm256_add_epi32(a.avx, b.avx);
- d.avx = _mm256_xor_si256(d.avx, a.avx);
- d.avx = _mm256_shuffle_epi8(
-     d.avx,
-     _mm256_set_epi8(
-         14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8, 11, 6,
-         5, 4, 7, 2, 1, 0, 3,
-     ),
- );
+ a.add_assign_epi32(b);
+ d.xor_assign(a);
+ d.rol_8();

// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_7(b);
- c.avx = _mm256_add_epi32(c.avx, d.avx);
- b.avx = _mm256_xor_si256(b.avx, c.avx);
- b.avx = _mm256_xor_si256(_mm256_slli_epi32(b.avx, 7), _mm256_srli_epi32(b.avx, 25));
+ c.add_assign_epi32(d);
+ b.xor_assign(c);
+ b.rol::<7, 25>();
}
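
For reference, the sequence in `add_xor_rot` is the ChaCha quarter round from RFC 8439, applied here to every 32-bit lane of `a`/`b`/`c`/`d` in parallel. A minimal scalar sketch (illustrative only, not part of the crate):

    fn quarter_round(a: &mut u32, b: &mut u32, c: &mut u32, d: &mut u32) {
        *a = a.wrapping_add(*b); *d ^= *a; *d = d.rotate_left(16);
        *c = c.wrapping_add(*d); *b ^= *c; *b = b.rotate_left(12);
        *a = a.wrapping_add(*b); *d ^= *a; *d = d.rotate_left(8);
        *c = c.wrapping_add(*d); *b ^= *c; *b = b.rotate_left(7);
    }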
