Skip to content

Commit

Permalink
⚡️ Optimize keccak
Browse files Browse the repository at this point in the history
  • Loading branch information
Philogy committed Feb 17, 2025
1 parent fa564d4 commit 9f8cb34
Showing 1 changed file with 135 additions and 95 deletions.
230 changes: 135 additions & 95 deletions lib/src/keccak.rs
Original file line number Diff line number Diff line change
@@ -1,159 +1,199 @@
use std::ops::{Deref, DerefMut};
use tiny_keccak::keccakf;

const WORDS: usize = 25;
const BYTES: usize = WORDS * 8;
type Word = u32;

const STATE_BYTES: usize = 200;
const WORD_BYTES: usize = (Word::BITS as usize) / 8;
const WORDS: usize = STATE_BYTES / WORD_BYTES;

const DELIM: u8 = 0x01;

const BLOCK_SIZE: usize = 136;

/// Keccak-f[1600] sponge state, stored as `WORDS` machine words of type `Word`.
///
/// The storage is reinterpreted elsewhere both as raw bytes and as 64-bit
/// lanes for the permutation, so its layout must be pinned down:
///
/// * `repr(C)` guarantees the single field lives at offset 0 with no
///   reordering, making whole-struct reinterpretation layout-defined.
/// * `align(8)` guarantees the storage is suitably aligned to be viewed as
///   `u64` lanes even though `Word` itself may only require 4-byte alignment.
#[repr(C, align(8))]
#[derive(Debug, Clone)]
struct Keccak256State([Word; WORDS]);

impl Deref for Keccak256State {
type Target = [Word; WORDS];

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for Keccak256State {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl Default for Keccak256State {
fn default() -> Self {
Self([0; WORDS])
}
}

impl AsMut<[u8; STATE_BYTES]> for Keccak256State {
#[inline]
fn as_mut(&mut self) -> &mut [u8; STATE_BYTES] {
unsafe { core::mem::transmute(self) }
}
}

/// Incremental Keccak-256 hasher: bytes are fed in through `update` and a
/// 32-byte digest is produced by `finalize_and_reset`.
#[derive(Debug, Clone)]
pub struct Keccak256 {
// NOTE(review): `buffer` looks superseded by `state` — the visible `permute`
// and absorb paths only operate on `state`. Confirm whether this field is
// still needed.
buffer: [u64; WORDS],
/// Sponge state that input is absorbed into and that `permute` transforms.
state: Keccak256State,
/// Number of bytes already absorbed into the current, not yet permuted block.
offset: usize,
/// `true` until the first block of the current hash has been permuted. While
/// set, absorbed input overwrites the state instead of being XORed into it.
first_block: bool,
}

impl Default for Keccak256 {
fn default() -> Self {
Self {
buffer: [0u64; WORDS],
state: Default::default(),
offset: 0,
first_block: true,
}
}
}

const DELIM: u8 = 0x01;
const RATE: usize = 136;

/// Keccak256 struct optimized for repeated hashing & outputting 32-byte hashes.
impl Keccak256 {
#[inline]
fn keccak(&mut self) {
keccakf(&mut self.buffer)
fn permute(&mut self) {
keccakf(unsafe { core::mem::transmute(&mut self.state) })
}

/// Expects `offset` & `len` to be within the byte-bounds of `self.buffer`.
#[cfg(target_endian = "little")]
#[inline]
unsafe fn execute<F: FnOnce(&mut [u8])>(&mut self, offset: usize, len: usize, f: F) {
let buffer: &mut [u8; BYTES] = unsafe { core::mem::transmute(&mut self.buffer) };
f(buffer.get_unchecked_mut(offset..).get_unchecked_mut(..len));
/// Merges `src` into `dst`.
///
/// When `first` is set the destination is overwritten with `src`; otherwise
/// `src` is XORed into the destination.
fn absorb<T: std::ops::BitXorAssign<T> + Copy>(first: bool, dst: &mut T, src: &T) {
    let incoming = *src;

    if !first {
        *dst ^= incoming;
        return;
    }

    *dst = incoming;
}

#[cfg(target_endian = "big")]
/// # Safety
///
/// The caller must ensure that `block` is aligned to the width of `Word` and is at least `BLOCK_SIZE` bytes long.
#[inline]
fn execute<F: FnOnce(&mut [u8])>(&mut self, offset: usize, len: usize, f: F) {
fn swap_endianess(buffer: &mut [u64]) {
for item in buffer {
*item = item.swap_bytes();
}
}
unsafe fn absorb_aligned(&mut self, first: bool, block: &[u8]) {
let block: &[Word] =
std::slice::from_raw_parts(block.as_ptr() as *const Word, BLOCK_SIZE / WORD_BYTES);

let start = offset / 8;
let end = (offset + len + 7) / 8;
swap_endianess(&mut self.0[start..end]);
let buffer: &mut [u8; BYTES] = unsafe { core::mem::transmute(&mut self.0) };
f(&mut buffer[offset..][..len]);
swap_endianess(&mut self.0[start..end]);
}
for i in 0..BLOCK_SIZE / WORD_BYTES {
Self::absorb(first, &mut self.state[i], &block[i]);
}

unsafe fn xorin(&mut self, src: &[u8], offset: usize, len: usize) {
self.execute(offset, len, |dst| {
let len = dst.len();
let mut dst_ptr = dst.as_mut_ptr();
let mut src_ptr = src.as_ptr();
for _ in 0..len {
*dst_ptr ^= *src_ptr;
src_ptr = src_ptr.offset(1);
dst_ptr = dst_ptr.offset(1);
}
});
self.permute();
}
#[inline]
fn absorb_block(&mut self, first: bool, block: &[u8]) {
for (s, b) in self.state.iter_mut().zip(block.chunks_exact(WORD_BYTES)) {
Self::absorb(first, s, &Word::from_le_bytes(b.try_into().unwrap()));
}

unsafe fn setin(&mut self, src: &[u8], offset: usize, len: usize) {
self.execute(offset, len, |dst| {
let len = dst.len();
let mut dst_ptr = dst.as_mut_ptr();
let mut src_ptr = src.as_ptr();
for _ in 0..len {
*dst_ptr = *src_ptr;
src_ptr = src_ptr.offset(1);
dst_ptr = dst_ptr.offset(1);
}
});
self.permute();
}

pub fn update(&mut self, input: impl AsRef<[u8]>) {
let mut input = input.as_ref();
let mut rate = RATE - self.offset;
let mut offset = self.offset;
let offset = self.offset;
let rem = BLOCK_SIZE - self.offset;

// If input not long enough to fill block partially absorb.
if input.len() < rem {
let state_bytes = self.state.as_mut();
for (s, inp) in state_bytes.iter_mut().zip(input) {
Self::absorb(self.first_block, s, inp);
}
self.offset = offset + input.len();
return;
}

if self.first_block {
if input.len() >= rate {
unsafe {
self.setin(input, offset, rate);
}
self.keccak();
self.first_block = false;
input = &input[rate..];
rate = RATE;
offset = 0;
// If last block was incomplete, complete first.
if offset != 0 {
let (left, right) = input.split_at(rem);
input = right;

let state_bytes = self.state.as_mut();
for (s, inp) in state_bytes[offset..].iter_mut().zip(left) {
Self::absorb(self.first_block, s, inp);
}

self.permute();
self.first_block = false;
}

let align_offset = input.as_ptr().align_offset(WORD_BYTES);

if self.first_block && input.len() >= BLOCK_SIZE {
let (block, right) = input.split_at(BLOCK_SIZE);
input = right;

if align_offset == 0 {
unsafe { self.absorb_aligned(true, block) };
} else {
unsafe {
self.setin(input, offset, input.len());
}
self.offset = offset + input.len();
return;
self.absorb_block(true, block);
}

self.first_block = false;
}

while input.len() >= rate {
unsafe {
self.xorin(input, offset, rate);
if align_offset == 0 {
while input.len() >= BLOCK_SIZE {
let (block, right) = input.split_at(BLOCK_SIZE);
input = right;

// If `input` was aligned initially and we only split in aligned increments we know
// the resulting slice is aligned.
unsafe { self.absorb_aligned(false, block) };
}
} else {
while input.len() >= BLOCK_SIZE {
let (block, right) = input.split_at(BLOCK_SIZE);
input = right;

self.absorb_block(false, block);
}
self.keccak();
input = &input[rate..];
rate = RATE;
offset = 0;
}

unsafe {
self.xorin(input, offset, input.len());
let buffer = self.state.as_mut();
for i in 0..input.len() {
buffer[i] ^= input[i];
}
self.offset = offset + input.len();
self.offset = input.len();
}

/// Applies the Keccak padding to the currently buffered block.
///
/// While the hasher is still on its first block, the unfilled remainder of
/// the block may hold stale bytes, so it is zeroed before the padding bits
/// are XORed in. `DELIM` goes right after the buffered input and `0x80` into
/// the last byte of the block; both land on the same byte when only one byte
/// of the block is free.
fn pad(&mut self) {
    let offset = self.offset;
    let first = self.first_block;
    let buffer = self.state.as_mut();

    if first {
        buffer[offset..BLOCK_SIZE].fill(0);
    }

    // `update` keeps `offset` strictly below BLOCK_SIZE, so both indices are
    // in bounds.
    buffer[offset] ^= DELIM;
    buffer[BLOCK_SIZE - 1] ^= 0x80;
}

pub fn finalize_and_reset(&mut self, output: &mut [u8; 32]) {
if self.first_block {
unsafe {
self.execute(self.offset, RATE - self.offset, |b| {
for i in 0..b.len() {
b[i] = 0;
}
});
}
}

self.pad();

self.keccak();
self.permute();

let words_out: &mut [u64; 4] = unsafe { core::mem::transmute(output) };
for i in 0..4 {
words_out[i] = self.buffer[i];
let words_out: &mut [Word; 32 / WORD_BYTES] = unsafe { core::mem::transmute(output) };
for (a, b) in words_out.iter_mut().zip(*self.state) {
*a = b;
}

self.reset();
}

fn reset(&mut self) {
for i in RATE / 8..WORDS {
self.buffer[i] = 0;
for i in BLOCK_SIZE / WORD_BYTES..WORDS {
self.state[i] = 0;
}
self.first_block = true;
self.offset = 0;
Expand Down

0 comments on commit 9f8cb34

Please sign in to comment.