Skip to content

Commit

Permalink
Implement C code changes in Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Bossen committed Mar 18, 2024
1 parent 2285d1c commit de63e3e
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 298 deletions.
265 changes: 8 additions & 257 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,229 +663,6 @@ fn findoddzero(buf: &[u8]) -> bool {
.is_some()
}

/// Decodes the palette of one plane for block `b` and writes it to the palette buffer.
///
/// `pl` selects the plane and is used as the index `pli` (0 or 1) into the per-plane
/// arrays: the visible callers pass `false` for the luma palette and `true` (from
/// `read_pal_uv`) for the first chroma palette.
///
/// Steps, in the order the bitstream is consumed:
/// 1. Read the palette size symbol; `symbol + 2` is stored in `b.pal_sz_mut()[pli]`.
/// 2. Merge the left and above neighbour palettes from `t.al_pal` into `cache`,
///    skipping consecutive duplicates (this yields an ascending, duplicate-free list
///    provided each neighbour palette is itself ascending — TODO confirm at the writers).
/// 3. Read one equiprobable flag per cache entry to pick the reused entries.
/// 4. Read the remaining entries: a raw `bpc`-bit first value, then non-negative deltas.
/// 5. Two-way merge the reused and the new entries into the destination palette.
///
/// `sz_ctx` selects the CDF used for the palette size symbol. `bx4` indexes the
/// above-neighbour arrays and `by4` the left-neighbour arrays.
///
/// # Safety
///
/// Dereferences `t.ts` and `t.a`. NOTE(review): these look like raw pointers; callers
/// presumably must guarantee both are valid for the duration of the call — confirm
/// against the `Rav1dTaskContext` definition.
unsafe fn read_pal_plane(
t: &mut Rav1dTaskContext,
f: &mut Rav1dFrameData,
b: &mut Av1Block,
pl: bool,
sz_ctx: u8,
bx4: usize,
by4: usize,
) {
let pli = pl as usize;
// 1 for luma (`pl == false`), 0 for chroma: the minimum step added to every delta below.
let not_pl = !pl as u16;

let ts = &mut *t.ts;

// The decoded symbol is offset by 2 to obtain the palette size.
let pal_sz = rav1d_msac_decode_symbol_adapt8(
&mut ts.msac,
&mut ts.cdf.m.pal_sz[pli][sz_ctx as usize],
6,
) as u8
+ 2;
b.pal_sz_mut()[pli] = pal_sz;
let pal_sz = pal_sz as usize;
// `cache` holds the merged neighbour palettes; `used_cache` the entries selected for reuse.
let mut cache = <[u16; 16]>::default();
let mut used_cache = <[u16; 8]>::default();
// Number of entries still to consume from the left neighbour's palette.
let mut l_cache = if pl {
t.pal_sz_uv[1][by4]
} else {
t.l.pal_sz.0[by4]
};
let mut n_cache = 0;
// don't reuse above palette outside SB64 boundaries
let mut a_cache = if by4 & 15 != 0 {
if pl {
t.pal_sz_uv[0][bx4]
} else {
(*t.a).pal_sz.0[bx4]
}
} else {
0
};
// `t.al_pal[0]` is the above context (indexed by `bx4`), `t.al_pal[1]` the left one
// (indexed by `by4`). `l` and `a` are shrinking views whose head is the next entry.
let [a, l] = &mut t.al_pal;
let mut l = &l[by4][pli][..];
let mut a = &a[bx4][pli][..];

// fill/sort cache
// TODO: This logic could be replaced with `itertools`' `.merge` and `.dedup`, which would elide bounds checks.
while l_cache != 0 && a_cache != 0 {
if l[0] < a[0] {
// Left head is smaller: emit it unless it repeats the last cached value.
if n_cache == 0 || cache[n_cache - 1] != l[0] {
cache[n_cache] = l[0];
n_cache += 1;
}
l = &l[1..];
l_cache -= 1;
} else {
// Heads are equal: drop the left copy so the value is considered only once.
if a[0] == l[0] {
l = &l[1..];
l_cache -= 1;
}
if n_cache == 0 || cache[n_cache - 1] != a[0] {
cache[n_cache] = a[0];
n_cache += 1;
}
a = &a[1..];
a_cache -= 1;
}
}
// At most one of the two inputs still has entries; append them with the same
// duplicate check.
if l_cache != 0 {
loop {
if n_cache == 0 || cache[n_cache - 1] != l[0] {
cache[n_cache] = l[0];
n_cache += 1;
}
l = &l[1..];
l_cache -= 1;
if !(l_cache > 0) {
break;
}
}
} else if a_cache != 0 {
loop {
if n_cache == 0 || cache[n_cache - 1] != a[0] {
cache[n_cache] = a[0];
n_cache += 1;
}
a = &a[1..];
a_cache -= 1;
if !(a_cache > 0) {
break;
}
}
}
let cache = &cache[..n_cache];

// find reused cache entries
// TODO: Bounds checks could be elided with more complex iterators.
// One flag is read per cache entry until either the cache is exhausted or the
// palette is already full; `i` counts the entries selected so far.
let mut i = 0;
for cache in cache {
if !(i < pal_sz) {
break;
}
if rav1d_msac_decode_bool_equi(&mut ts.msac) {
used_cache[i] = *cache;
i += 1;
}
}
let used_cache = &used_cache[..i];

// parse new entries
// The destination is the frame-thread palette storage when `t.frame_thread.pass != 0`
// and the task-local scratch palette otherwise.
let pal = if t.frame_thread.pass != 0 {
&mut f.frame_thread.pal[(((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize) as usize][pli]
} else {
&mut t.scratch.c2rust_unnamed_0.pal[pli]
};
let pal = &mut pal[..pal_sz];
// `i` equals the number of reused entries here; new entries are written after them,
// starting at `pal[i]`, and merged with `used_cache` afterwards.
if i < pal.len() {
// The first new entry is coded as a raw `bpc`-bit value.
let mut prev = rav1d_msac_decode_bools(&mut ts.msac, f.cur.p.bpc as u32) as u16;
pal[i] = prev;
i += 1;

if i < pal.len() {
// Initial delta width: `bpc - 3` plus a 2-bit value read from the bitstream.
let mut bits = f.cur.p.bpc as u32 + rav1d_msac_decode_bools(&mut ts.msac, 2) - 3;
let max = (1 << f.cur.p.bpc) - 1;

loop {
let delta = rav1d_msac_decode_bools(&mut ts.msac, bits) as u16;
// Each entry is the previous one plus the delta (plus 1 for luma), clamped to `max`.
prev = cmp::min(prev + delta + not_pl, max);
pal[i] = prev;
i += 1;
if prev + not_pl >= max {
// No larger value can follow: the remaining entries are all set to `max`
// without reading further bits.
pal[i..].fill(max);
break;
} else {
// Shrink the delta width to what the remaining value range requires.
bits = cmp::min(bits, 1 + ulog2((max - prev - not_pl) as u32) as u32);
if !(i < pal.len()) {
break;
}
}
}
}

// merge cache+new entries
// `n` walks `used_cache`, `m` walks the new entries stored at `pal[used_cache.len()..]`.
// The write index satisfies `i == n + (m - used_cache.len()) <= m`, so a new entry is
// never overwritten before it has been read.
let mut n = 0;
let mut m = used_cache.len();
for i in 0..pal.len() {
if n < used_cache.len() && (m >= pal.len() || used_cache[n] <= pal[m]) {
pal[i] = used_cache[n];
n += 1;
} else {
pal[i] = pal[m];
m += 1;
}
}
} else {
// The palette was filled entirely from the cache; copy the reused entries as-is.
pal[..used_cache.len()].copy_from_slice(&used_cache);
}

// Optional debug dump of the decoder state, the cache and the final palette.
if debug_block_info!(f, t) {
print!(
"Post-pal[pl={},sz={},cache_size={},used_cache={}]: r={}, cache=",
pli,
pal_sz,
cache.len(),
used_cache.len(),
ts.msac.rng
);
for (n, cache) in cache.iter().enumerate() {
print!("{}{:02x}", if n != 0 { ' ' } else { '[' }, cache);
}
print!("{}, pal=", if cache.len() != 0 { "]" } else { "[]" });
for (n, pal) in pal.iter().enumerate() {
print!("{}{:02x}", if n != 0 { ' ' } else { '[' }, pal);
}
println!("]");
}
}

/// Reads both chroma palettes of block `b`.
///
/// The first chroma palette (index 1) is delegated to `read_pal_plane`; the second one
/// (index 2) is decoded here. A single equiprobable flag selects its coding mode:
///
/// * flag set: a raw `bpc`-bit first value followed by signed deltas, each result
///   masked with `(1 << bpc) - 1`;
/// * flag clear: every entry is an independent raw `bpc`-bit value.
///
/// The number of entries is `b.pal_sz()[1]`, i.e. the size of the first chroma palette.
///
/// # Safety
///
/// Dereferences `t.ts` and calls the `unsafe fn read_pal_plane`; the caller must uphold
/// whatever that function requires of `t`.
unsafe fn read_pal_uv(
    t: &mut Rav1dTaskContext,
    f: &mut Rav1dFrameData,
    b: &mut Av1Block,
    sz_ctx: u8,
    bx4: usize,
    by4: usize,
) {
    read_pal_plane(t, f, b, true, sz_ctx, bx4, by4);

    // V pal coding
    let ts = &mut *t.ts;
    let bpc = f.cur.p.bpc;

    // Pick the destination: task-local scratch when no frame-thread pass is active,
    // otherwise the per-block slot in the frame-thread palette storage.
    let plane = if t.frame_thread.pass == 0 {
        &mut t.scratch.c2rust_unnamed_0.pal[2]
    } else {
        let row = ((t.by >> 1) + (t.bx & 1)) as isize;
        let col = ((t.bx >> 1) + (t.by & 1)) as isize;
        let slot = (row * (f.b4_stride >> 1) + col) as usize;
        &mut f.frame_thread.pal[slot][2]
    };
    let entries = &mut plane[..b.pal_sz()[1] as usize];

    let delta_coded = rav1d_msac_decode_bool_equi(&mut ts.msac);
    if delta_coded {
        // Delta width first, then the raw starting value.
        let bits = bpc as u32 + rav1d_msac_decode_bools(&mut ts.msac, 2) - 4;
        let mut prev = rav1d_msac_decode_bools(&mut ts.msac, bpc as c_uint) as u16;
        entries[0] = prev;
        let max = (1 << bpc) - 1;
        for entry in entries[1..].iter_mut() {
            let magnitude = rav1d_msac_decode_bools(&mut ts.msac, bits) as i16;
            // A sign flag is only present in the bitstream for non-zero magnitudes.
            let delta = if magnitude != 0 && rav1d_msac_decode_bool_equi(&mut ts.msac) {
                -magnitude
            } else {
                magnitude
            };
            prev = ((prev as i16 + delta) as u16) & max;
            *entry = prev;
        }
    } else {
        for entry in entries.iter_mut() {
            *entry = rav1d_msac_decode_bools(&mut ts.msac, bpc as c_uint) as u16;
        }
    }

    if debug_block_info!(f, t) {
        print!("Post-pal[pl=2]: r={} ", ts.msac.rng);
        for (n, value) in entries.iter().enumerate() {
            let lead = if n == 0 { '[' } else { ' ' };
            print!("{}{:02x}", lead, value);
        }
        println!("]");
    }
}

fn order_palette(
pal_idx: &[u8],
stride: usize,
Expand Down Expand Up @@ -2022,7 +1799,7 @@ unsafe fn decode_b_inner(
println!("Post-y_pal[{}]: r={}", use_y_pal, ts.msac.rng);
}
if use_y_pal {
read_pal_plane(t, f, b, false, sz_ctx, bx4 as usize, by4 as usize);
(bd_fn.read_pal_plane)(t, f, b, false, sz_ctx, bx4 as usize, by4 as usize);
}
}

Expand All @@ -2037,7 +1814,7 @@ unsafe fn decode_b_inner(
}
if use_uv_pal {
// see aomedia bug 2183 for why we use luma coordinates
read_pal_uv(t, f, b, sz_ctx, bx4 as usize, by4 as usize);
(bd_fn.read_pal_uv)(t, f, b, sz_ctx, bx4 as usize, by4 as usize);
}
}
}
Expand Down Expand Up @@ -2220,19 +1997,7 @@ unsafe fn decode_b_inner(
},
);
if b.pal_sz()[0] != 0 {
let pal = if t.frame_thread.pass != 0 {
let index = ((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize;
&f.frame_thread.pal[index as usize][0]
} else {
&t.scratch.c2rust_unnamed_0.pal[0]
};
for al_pal in &mut t.al_pal[0][bx4 as usize..][..bw4 as usize] {
al_pal[0] = *pal;
}
for al_pal in &mut t.al_pal[1][by4 as usize..][..bh4 as usize] {
al_pal[0] = *pal;
}
(bd_fn.copy_pal_block_y)(t, f, bx4, by4, bw4, bh4);
}
if has_chroma {
CaseSet::<32, false>::many(
Expand All @@ -2244,22 +2009,7 @@ unsafe fn decode_b_inner(
},
);
if b.pal_sz()[1] != 0 {
let pal = if t.frame_thread.pass != 0 {
let index = ((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize;
&f.frame_thread.pal[index as usize]
} else {
&t.scratch.c2rust_unnamed_0.pal
};
// see aomedia bug 2183 for why we use luma coordinates here
for pl in 1..=2 {
for x in 0..bw4 {
t.al_pal[0][(bx4 + x) as usize][pl] = pal[pl];
}
for y in 0..bh4 {
t.al_pal[1][(by4 + y) as usize][pl] = pal[pl];
}
}
(bd_fn.copy_pal_block_uv)(t, f, bx4, by4, bw4, bh4);
}
}
if f.frame_hdr().frame_type.is_inter_or_switch() || f.frame_hdr().allow_intrabc {
Expand Down Expand Up @@ -4354,9 +4104,10 @@ pub(crate) unsafe fn rav1d_decode_frame_init(

if frame_hdr.allow_screen_content_tools {
// TODO: Fallible allocation
f.frame_thread
.pal
.resize(num_sb128 as usize * 16 * 16, Default::default());
f.frame_thread.pal.resize(
num_sb128 as usize * 16 * 16 * 8 * 3 << hbd,
Default::default(),
);

let pal_idx_sz = num_sb128 * size_mul[1] as c_int;
// TODO: Fallible allocation
Expand Down
Loading

0 comments on commit de63e3e

Please sign in to comment.