Skip to content

Commit

Permalink
Change to import `std::simd` once at module scope (`use std::simd;`) and reference items as `simd::…` instead of repeating the full `std::simd::` path.
Browse files Browse the repository at this point in the history
  • Loading branch information
yotarok committed Sep 28, 2023
1 parent b5bf9dc commit 171980a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 42 deletions.
21 changes: 10 additions & 11 deletions src/coding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use std::cell::RefCell;
#[cfg(feature = "experimental")]
use std::collections::BTreeSet;
use std::simd;
#[cfg(feature = "experimental")]
use std::sync::Arc;

Expand Down Expand Up @@ -93,15 +94,13 @@ pub fn encode_residual(config: &config::Prc, errors: &[i32], warmup_length: usiz
}

/// Pack scalars into `Vec` of `Simd`s.
pub fn pack_into_simd_vec<T, const LANES: usize>(
src: &[T],
dest: &mut Vec<std::simd::Simd<T, LANES>>,
) where
T: std::simd::SimdElement + From<i8>,
std::simd::LaneCount<LANES>: std::simd::SupportedLaneCount,
pub fn pack_into_simd_vec<T, const LANES: usize>(src: &[T], dest: &mut Vec<simd::Simd<T, LANES>>)
where
T: simd::SimdElement + From<i8>,
simd::LaneCount<LANES>: simd::SupportedLaneCount,
{
dest.clear();
let mut v = std::simd::Simd::<T, LANES>::splat(0i8.into());
let mut v = simd::Simd::<T, LANES>::splat(0i8.into());
for slice in src.chunks(LANES) {
if slice.len() < LANES {
v.as_mut_array()[0..slice.len()].copy_from_slice(slice);
Expand All @@ -114,10 +113,10 @@ pub fn pack_into_simd_vec<T, const LANES: usize>(
}

/// Unpack slice of `Simd` into `Vec` of elements.
pub fn unpack_simds<T, const LANES: usize>(src: &[std::simd::Simd<T, LANES>], dest: &mut Vec<T>)
pub fn unpack_simds<T, const LANES: usize>(src: &[simd::Simd<T, LANES>], dest: &mut Vec<T>)
where
T: std::simd::SimdElement + From<i8>,
std::simd::LaneCount<LANES>: std::simd::SupportedLaneCount,
T: simd::SimdElement + From<i8>,
simd::LaneCount<LANES>: simd::SupportedLaneCount,
{
dest.resize(src.len() * LANES, 0i8.into());
let mut offset = 0;
Expand All @@ -130,7 +129,7 @@ where

/// Helper struct holding working memory for fixed LPC.
struct FixedLpcEncoder {
errors: Vec<std::simd::i32x16>,
errors: Vec<simd::i32x16>,
/// Length of errors in the number of samples (scalars).
error_len_in_samples: usize,
/// Temporary buffer for unpacked error signal.
Expand Down
27 changes: 15 additions & 12 deletions src/lpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::rc::Rc;
use std::simd::SimdInt;

use serde::Deserialize;
use serde::Serialize;
Expand All @@ -26,6 +25,10 @@ use super::constant::MAX_LPC_ORDER;
use super::constant::QLPC_MAX_SHIFT;
use super::constant::QLPC_MIN_SHIFT;

use std::simd;

use simd::SimdInt;

/// Analysis window descriptor.
///
/// This enum is `Serializable` and `Deserializable` because this will be
Expand Down Expand Up @@ -136,17 +139,17 @@ fn dequantize_parameter(coef: i16, shift: i8) -> f32 {
const QLPC_SIMD_LANES: usize = 16usize;
const MAX_COEF_VECTORS: usize = (MAX_LPC_ORDER + (QLPC_SIMD_LANES - 1)) / QLPC_SIMD_LANES;

const LOW_WORD_MASK: std::simd::i32x16 = std::simd::i32x16::from_array([0x0000_FFFFi32; 16]);
const LOW_WORD_DENOM: std::simd::i32x16 = std::simd::i32x16::from_array([0x0001_0000i32; 16]);
const LOW_WORD_MASK: simd::i32x16 = simd::i32x16::from_array([0x0000_FFFFi32; 16]);
const LOW_WORD_DENOM: simd::i32x16 = simd::i32x16::from_array([0x0001_0000i32; 16]);
#[allow(dead_code)]
const HIGH_WORD_SHIFT: std::simd::i32x16 = std::simd::i32x16::from_array([16i32; 16]);
const HIGH_WORD_SHIFT: simd::i32x16 = simd::i32x16::from_array([16i32; 16]);

/// Shifts elements in a vector of `T` represented as a slice of `Simd<T, N>`.
#[inline]
fn shift_lanes_right<T, const N: usize>(val: T, vecs: &mut [std::simd::Simd<T, N>])
fn shift_lanes_right<T, const N: usize>(val: T, vecs: &mut [simd::Simd<T, N>])
where
T: std::simd::SimdElement,
std::simd::LaneCount<N>: std::simd::SupportedLaneCount,
T: simd::SimdElement,
simd::LaneCount<N>: simd::SupportedLaneCount,
{
let mut carry = val;
for v in vecs {
Expand All @@ -159,7 +162,7 @@ where
/// Quantized LPC coefficients.
#[derive(Clone, Debug)]
pub struct QuantizedParameters {
coefs: heapless::Vec<std::simd::i32x16, MAX_COEF_VECTORS>,
coefs: heapless::Vec<simd::i32x16, MAX_COEF_VECTORS>,
order: usize,
shift: i8,
precision: usize,
Expand Down Expand Up @@ -198,7 +201,7 @@ impl QuantizedParameters {

let mut coefs_v = heapless::Vec::new();
for arr in q_coefs.chunks(QLPC_SIMD_LANES) {
let mut v = std::simd::i32x16::splat(0);
let mut v = simd::i32x16::splat(0);
v.as_mut_array()[0..arr.len()].copy_from_slice(arr);
coefs_v
.push(v)
Expand Down Expand Up @@ -228,15 +231,15 @@ impl QuantizedParameters {
for p in errors.iter_mut().take(self.order()) {
*p = 0;
}
let mut window_h = heapless::Vec::<std::simd::i32x16, MAX_COEF_VECTORS>::new();
let mut window_l = heapless::Vec::<std::simd::i32x16, MAX_COEF_VECTORS>::new();
let mut window_h = heapless::Vec::<simd::i32x16, MAX_COEF_VECTORS>::new();
let mut window_l = heapless::Vec::<simd::i32x16, MAX_COEF_VECTORS>::new();

for i in 0..MAX_COEF_VECTORS {
let tau: isize = (self.order() as isize - 1) - (i * QLPC_SIMD_LANES) as isize;
if tau < 0 {
break;
}
let mut v = std::simd::i32x16::splat(0);
let mut v = simd::i32x16::splat(0);
for j in 0..QLPC_SIMD_LANES {
let j = j as isize;
if tau - j < 0 {
Expand Down
41 changes: 22 additions & 19 deletions src/rice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,45 @@
//! Functions for partitioned rice coding (PRC).

use std::cell::RefCell;
use std::simd::SimdPartialEq;
use std::simd::SimdPartialOrd;
use std::simd::SimdUint;

use super::constant::MAX_RICE_PARAMETER;
use super::constant::MAX_RICE_PARTITIONS;
use super::constant::MAX_RICE_PARTITION_ORDER;
use super::constant::MIN_RICE_PARTITION_SIZE;

use std::simd;

use simd::SimdPartialEq;
use simd::SimdPartialOrd;
use simd::SimdUint;

/// Table that contains the numbers of bits needed for a partition.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
struct PrcBitTable {
p_to_bits: std::simd::u32x16,
mask: std::simd::Mask<<u32 as std::simd::SimdElement>::Mask, 16>,
p_to_bits: simd::u32x16,
mask: simd::Mask<<u32 as simd::SimdElement>::Mask, 16>,
}

static ZEROS: std::simd::u32x16 = std::simd::u32x16::from_array([0u32; 16]);
static INDEX: std::simd::u32x16 =
std::simd::u32x16::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
static INDEX1: std::simd::u32x16 =
std::simd::u32x16::from_array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
static MAXES: std::simd::u32x16 = std::simd::u32x16::from_array([u32::MAX; 16]);
static ZEROS: simd::u32x16 = simd::u32x16::from_array([0u32; 16]);
static INDEX: simd::u32x16 =
simd::u32x16::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
static INDEX1: simd::u32x16 =
simd::u32x16::from_array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
static MAXES: simd::u32x16 = simd::u32x16::from_array([u32::MAX; 16]);

// The max value of p_to_bits is chosen so that the estimates don't overflow
// after being added together up to 2^4 = 16 times.
// The current version exploits the fact that `MAX_P_TO_BITS` is actually a bit mask, i.e.
// can be written as 2^N - 1 for faster processing. Do not use arbitrary value here.
static MAX_P_TO_BITS: u32 = (1 << 28) - 1;
static MAX_P_TO_BITS_VEC: std::simd::u32x16 = std::simd::u32x16::from_array([MAX_P_TO_BITS; 16]);
static MAX_P_TO_BITS_VEC: simd::u32x16 = simd::u32x16::from_array([MAX_P_TO_BITS; 16]);

impl PrcBitTable {
pub fn zero(max_p: usize) -> Self {
debug_assert!(max_p <= MAX_RICE_PARAMETER);
Self {
p_to_bits: ZEROS,
mask: INDEX.simd_le(std::simd::u32x16::splat(max_p as u32)),
mask: INDEX.simd_le(simd::u32x16::splat(max_p as u32)),
}
}

Expand All @@ -63,18 +66,18 @@ impl PrcBitTable {
/// Initializes PRC bit table from the error signal.
#[allow(unused_assignments, clippy::identity_op)]
fn init_with_errors(&mut self, errors: &[u32], offset: usize) {
let mut p_to_bits = std::simd::u32x16::splat(offset as u32)
+ std::simd::u32x16::splat(errors.len() as u32) * INDEX1;
let mut p_to_bits =
simd::u32x16::splat(offset as u32) + simd::u32x16::splat(errors.len() as u32) * INDEX1;

for v in errors {
// Below is faster than doing:
// vs = splat(*v) >> INDEX;
// or
// vs = std::simd::u32x16::from_array(std::array::from_fn(
// vs = simd::u32x16::from_array(std::array::from_fn(
// |i| v >> i));
// Perhaps due to smaller memory footprint by avoiding `splat`?
let v = *v;
let vs = std::simd::u32x16::from_array([
let vs = simd::u32x16::from_array([
v,
v >> 1,
v >> 2,
Expand Down Expand Up @@ -107,7 +110,7 @@ impl PrcBitTable {
let ret_bits = self.mask.select(self.p_to_bits, MAXES).reduce_min();
let ret_p = self
.p_to_bits
.simd_eq(std::simd::u32x16::splat(ret_bits))
.simd_eq(simd::u32x16::splat(ret_bits))
.select(INDEX, ZEROS)
.reduce_max();

Expand All @@ -117,7 +120,7 @@ impl PrcBitTable {
#[allow(unused_comparisons)]
#[inline]
pub fn merge(&self, other: &Self, offset: usize) -> Self {
let offset = std::simd::u32x16::splat(offset as u32);
let offset = simd::u32x16::splat(offset as u32);
let offset = self.mask.select(offset, ZEROS);
Self {
p_to_bits: self.p_to_bits + other.p_to_bits - offset,
Expand Down

0 comments on commit 171980a

Please sign in to comment.