Skip to content

Commit

Permalink
struct Rav1dState: Split out the fields of Rav1dContext that are …
Browse files Browse the repository at this point in the history
…mutated on the main thread and put them in a `Mutex`.

Our previous code where we passed a `&Rav1dContext` to the worker threads
and mutated the `&mut Rav1dContext` through `DAV1D_API`s like `fn rav1d_flush` was unsound,
as we were mutating `Rav1dContext` while a `&Rav1dContext` exists.
This moves all of those fields mutated on the main thread from the `DAV1D_API`s into `Rav1dState`,
and then puts `Rav1dState` into a `Mutex` (`.try_lock()`ed) into `Rav1dContext`.
We only need to `.try_lock()` it a few places in the `DAV1D_API`s;
otherwise, we can just pass `c: &Rav1dContext, state: &mut Rav1dState` args to most `fn`s.
Now `Rav1dContext` is always accessed through a `&` except during its construction in `fn rav1d_open`.

This also lets us remove the `Mutex` around `cached_error_props` and the `Atomic` around `frame_flags`,
since they are now accessed through a `&mut Rav1dState`.
  • Loading branch information
kkysen committed Jun 12, 2024
1 parent 7c9de93 commit 72521fa
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 321 deletions.
167 changes: 106 additions & 61 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ use crate::src::internal::Rav1dFrameContext;
use crate::src::internal::Rav1dFrameContext_frame_thread;
use crate::src::internal::Rav1dFrameContext_lf;
use crate::src::internal::Rav1dFrameData;
use crate::src::internal::Rav1dState;
use crate::src::internal::Rav1dTaskContext;
use crate::src::internal::Rav1dTileState;
use crate::src::internal::Rav1dTileStateContext;
Expand Down Expand Up @@ -170,7 +171,6 @@ use crate::src::warpmv::rav1d_find_affine_int;
use crate::src::warpmv::rav1d_get_shear_params;
use crate::src::warpmv::rav1d_set_affine_mv2d;
use libc::ptrdiff_t;
use parking_lot::Mutex;
use std::array;
use std::cmp;
use std::ffi::c_int;
Expand Down Expand Up @@ -4992,18 +4992,18 @@ fn get_upscale_x0(in_w: c_int, out_w: c_int, step: c_int) -> c_int {
x0 & 0x3fff
}

pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
pub fn rav1d_submit_frame(c: &Rav1dContext, state: &mut Rav1dState) -> Rav1dResult {
// wait for c->out_delayed[next] and move into c->out if visible
let (fc, out, _task_thread_lock) = if c.fc.len() > 1 {
let mut task_thread_lock = c.task_thread.lock.lock();
let next = c.frame_thread.next;
c.frame_thread.next = (c.frame_thread.next + 1) % c.fc.len() as u32;
let next = state.frame_thread.next;
state.frame_thread.next = (state.frame_thread.next + 1) % c.fc.len() as u32;

let fc = &c.fc[next as usize];
while !fc.task_thread.finished.load(Ordering::SeqCst) {
fc.task_thread.cond.wait(&mut task_thread_lock);
}
let out_delayed = &mut c.frame_thread.out_delayed[next as usize];
let out_delayed = &mut state.frame_thread.out_delayed[next as usize];
if out_delayed.p.data.is_some() || fc.task_thread.error.load(Ordering::SeqCst) != 0 {
let first = c.task_thread.first.load(Ordering::SeqCst);
if first as usize + 1 < c.fc.len() {
Expand All @@ -5025,32 +5025,32 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
}
let error = &mut *fc.task_thread.retval.try_lock().unwrap();
if error.is_some() {
c.cached_error = mem::take(&mut *error);
*c.cached_error_props.get_mut() = out_delayed.p.m.clone();
state.cached_error = mem::take(&mut *error);
state.cached_error_props = out_delayed.p.m.clone();
let _ = mem::take(out_delayed);
} else if out_delayed.p.data.is_some() {
let progress = out_delayed.progress.as_ref().unwrap()[1].load(Ordering::Relaxed);
if (out_delayed.visible || c.output_invisible_frames) && progress != FRAME_ERROR {
c.out = out_delayed.clone();
c.event_flags |= out_delayed.flags.into();
state.out = out_delayed.clone();
state.event_flags |= out_delayed.flags.into();
}
let _ = mem::take(out_delayed);
}
(fc, out_delayed, Some(task_thread_lock))
} else {
(&c.fc[0], &mut c.out, None)
(&c.fc[0], &mut state.out, None)
};

let mut f = fc.data.try_write().unwrap();
f.seq_hdr = c.seq_hdr.clone();
f.frame_hdr = mem::take(&mut c.frame_hdr);
f.seq_hdr = state.seq_hdr.clone();
f.frame_hdr = mem::take(&mut state.frame_hdr);
let seq_hdr = f.seq_hdr.clone().unwrap();

fn on_error(
fc: &Rav1dFrameContext,
f: &mut Rav1dFrameData,
out: &mut Rav1dThreadPicture,
cached_error_props: &Mutex<Rav1dDataProps>,
cached_error_props: &mut Rav1dDataProps,
m: &Rav1dDataProps,
) {
fc.task_thread.error.store(1, Ordering::Relaxed);
Expand All @@ -5070,7 +5070,7 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
let _ = mem::take(&mut f.mvs);
let _ = mem::take(&mut f.seq_hdr);
let _ = mem::take(&mut f.frame_hdr);
*cached_error_props.lock() = m.clone();
*cached_error_props = m.clone();

f.tiles.clear();
fc.task_thread.finished.store(true, Ordering::SeqCst);
Expand All @@ -5081,7 +5081,13 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
Some(dsp) => f.dsp = dsp,
None => {
writeln!(c.logger, "Compiled without support for {bpc}-bit decoding",);
on_error(fc, &mut f, out, &c.cached_error_props, &c.in_0.m);
on_error(
fc,
&mut f,
out,
&mut state.cached_error_props,
&state.in_0.m,
);
return Err(ENOPROTOOPT);
}
};
Expand All @@ -5095,34 +5101,53 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
if frame_hdr.frame_type.is_inter_or_switch() {
if frame_hdr.primary_ref_frame != RAV1D_PRIMARY_REF_NONE {
let pri_ref = frame_hdr.refidx[frame_hdr.primary_ref_frame as usize] as usize;
if c.refs[pri_ref].p.p.data.is_none() {
on_error(fc, &mut f, out, &c.cached_error_props, &c.in_0.m);
if state.refs[pri_ref].p.p.data.is_none() {
on_error(
fc,
&mut f,
out,
&mut state.cached_error_props,
&state.in_0.m,
);
return Err(EINVAL);
}
}
for i in 0..7 {
let refidx = frame_hdr.refidx[i] as usize;
if c.refs[refidx].p.p.data.is_none()
|| (frame_hdr.size.width[0] * 2) < c.refs[refidx].p.p.p.w
|| (frame_hdr.size.height * 2) < c.refs[refidx].p.p.p.h
|| frame_hdr.size.width[0] > c.refs[refidx].p.p.p.w * 16
|| frame_hdr.size.height > c.refs[refidx].p.p.p.h * 16
|| seq_hdr.layout != c.refs[refidx].p.p.p.layout
|| bpc != c.refs[refidx].p.p.p.bpc
if state.refs[refidx].p.p.data.is_none()
|| (frame_hdr.size.width[0] * 2) < state.refs[refidx].p.p.p.w
|| (frame_hdr.size.height * 2) < state.refs[refidx].p.p.p.h
|| frame_hdr.size.width[0] > state.refs[refidx].p.p.p.w * 16
|| frame_hdr.size.height > state.refs[refidx].p.p.p.h * 16
|| seq_hdr.layout != state.refs[refidx].p.p.p.layout
|| bpc != state.refs[refidx].p.p.p.bpc
{
for j in 0..i {
let _ = mem::take(&mut f.refp[j]);
}
on_error(fc, &mut f, out, &c.cached_error_props, &c.in_0.m);
on_error(
fc,
&mut f,
out,
&mut state.cached_error_props,
&state.in_0.m,
);
return Err(EINVAL);
}
f.refp[i] = c.refs[refidx].p.clone();
ref_coded_width[i] = c.refs[refidx].p.p.frame_hdr.as_ref().unwrap().size.width[0];
if frame_hdr.size.width[0] != c.refs[refidx].p.p.p.w
|| frame_hdr.size.height != c.refs[refidx].p.p.p.h
f.refp[i] = state.refs[refidx].p.clone();
ref_coded_width[i] = state.refs[refidx]
.p
.p
.frame_hdr
.as_ref()
.unwrap()
.size
.width[0];
if frame_hdr.size.width[0] != state.refs[refidx].p.p.p.w
|| frame_hdr.size.height != state.refs[refidx].p.p.p.h
{
f.svc[i][0].scale = scale_fac(c.refs[refidx].p.p.p.w, frame_hdr.size.width[0]);
f.svc[i][1].scale = scale_fac(c.refs[refidx].p.p.p.h, frame_hdr.size.height);
f.svc[i][0].scale = scale_fac(state.refs[refidx].p.p.p.w, frame_hdr.size.width[0]);
f.svc[i][1].scale = scale_fac(state.refs[refidx].p.p.p.h, frame_hdr.size.height);
f.svc[i][0].step = f.svc[i][0].scale + 8 >> 4;
f.svc[i][1].step = f.svc[i][1].scale + 8 >> 4;
} else {
Expand All @@ -5141,13 +5166,19 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
*fc.in_cdf.try_write().unwrap() = rav1d_cdf_thread_init_static(frame_hdr.quant.yac);
} else {
let pri_ref = frame_hdr.refidx[frame_hdr.primary_ref_frame as usize] as usize;
*fc.in_cdf.try_write().unwrap() = c.cdf[pri_ref].clone();
*fc.in_cdf.try_write().unwrap() = state.cdf[pri_ref].clone();
}
if frame_hdr.refresh_context != 0 {
let res = rav1d_cdf_thread_alloc(c.fc.len() > 1);
match res {
Err(e) => {
on_error(fc, &mut f, out, &c.cached_error_props, &c.in_0.m);
on_error(
fc,
&mut f,
out,
&mut state.cached_error_props,
&state.in_0.m,
);
return Err(e);
}
Ok(res) => {
Expand All @@ -5158,7 +5189,7 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {

// FIXME qsort so tiles are in order (for frame threading)
f.tiles.clear();
mem::swap(&mut f.tiles, &mut c.tiles);
mem::swap(&mut f.tiles, &mut state.tiles);
fc.task_thread
.finished
.store(f.tiles.is_empty(), Ordering::SeqCst);
Expand All @@ -5167,22 +5198,28 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {

// We must take itut_t35 out of the context before the call so borrowck can
// see we mutably borrow `c.itut_t35` disjointly from the task thread lock.
let itut_t35 = mem::take(&mut c.itut_t35);
let itut_t35 = mem::take(&mut state.itut_t35);
let res = rav1d_thread_picture_alloc(
&c.fc,
&c.logger,
&c.allocator,
c.content_light.clone(),
c.mastering_display.clone(),
state.content_light.clone(),
state.mastering_display.clone(),
c.output_invisible_frames,
c.max_spatial_id,
&c.frame_flags,
state.max_spatial_id,
&mut state.frame_flags,
&mut f,
bpc,
itut_t35,
);
if res.is_err() {
on_error(fc, &mut f, out, &c.cached_error_props, &c.in_0.m);
on_error(
fc,
&mut f,
out,
&mut state.cached_error_props,
&state.in_0.m,
);
return res;
}

Expand All @@ -5195,7 +5232,7 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
let res =
rav1d_picture_alloc_copy(&c.logger, &mut f.cur, frame_hdr.size.width[0], &f.sr_cur.p);
if res.is_err() {
on_error(fc, f, out, &c.cached_error_props, &c.in_0.m);
on_error(fc, f, out, &mut state.cached_error_props, &state.in_0.m);
return res;
}
} else {
Expand All @@ -5215,7 +5252,7 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
if c.fc.len() == 1 {
if frame_hdr.show_frame != 0 || c.output_invisible_frames {
*out = f.sr_cur.clone();
c.event_flags |= f.sr_cur.flags.into();
state.event_flags |= f.sr_cur.flags.into();
}
} else {
*out = f.sr_cur.clone();
Expand Down Expand Up @@ -5262,11 +5299,11 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
let ref_w = (ref_coded_width[i] + 7 >> 3) << 1;
let ref_h = (f.refp[i].p.p.h + 7 >> 3) << 1;
if ref_w == f.bw && ref_h == f.bh {
f.ref_mvs[i] = c.refs[refidx].refmvs.clone();
f.ref_mvs[i] = state.refs[refidx].refmvs.clone();
} else {
f.ref_mvs[i] = None;
}
f.refrefpoc[i] = c.refs[refidx].refpoc;
f.refrefpoc[i] = state.refs[refidx].refpoc;
}
} else {
f.ref_mvs.fill_with(Default::default);
Expand All @@ -5289,7 +5326,9 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
let ref_w = (ref_coded_width[pri_ref] + 7 >> 3) << 1;
let ref_h = (f.refp[pri_ref].p.p.h + 7 >> 3) << 1;
if ref_w == f.bw && ref_h == f.bh {
f.prev_segmap = c.refs[frame_hdr.refidx[pri_ref] as usize].segmap.clone();
f.prev_segmap = state.refs[frame_hdr.refidx[pri_ref] as usize]
.segmap
.clone();
}
}

Expand Down Expand Up @@ -5326,43 +5365,49 @@ pub fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
let refresh_frame_flags = frame_hdr.refresh_frame_flags as c_uint;
for i in 0..8 {
if refresh_frame_flags & (1 << i) != 0 {
if c.refs[i].p.p.frame_hdr.is_some() {
let _ = mem::take(&mut c.refs[i].p);
if state.refs[i].p.p.frame_hdr.is_some() {
let _ = mem::take(&mut state.refs[i].p);
}
c.refs[i].p = f.sr_cur.clone();
state.refs[i].p = f.sr_cur.clone();

if frame_hdr.refresh_context != 0 {
c.cdf[i] = f.out_cdf.clone();
state.cdf[i] = f.out_cdf.clone();
} else {
c.cdf[i] = fc.in_cdf.try_read().unwrap().clone();
state.cdf[i] = fc.in_cdf.try_read().unwrap().clone();
}

c.refs[i].segmap = f.cur_segmap.clone();
let _ = mem::take(&mut c.refs[i].refmvs);
state.refs[i].segmap = f.cur_segmap.clone();
let _ = mem::take(&mut state.refs[i].refmvs);
if !frame_hdr.allow_intrabc {
c.refs[i].refmvs = f.mvs.clone();
state.refs[i].refmvs = f.mvs.clone();
}
c.refs[i].refpoc = f.refpoc;
state.refs[i].refpoc = f.refpoc;
}
}
drop(f);

if c.fc.len() == 1 {
let res = rav1d_decode_frame(c, &fc);
if res.is_err() {
let _ = mem::take(&mut c.out);
let _ = mem::take(&mut state.out);
for i in 0..8 {
if refresh_frame_flags & (1 << i) != 0 {
if c.refs[i].p.p.frame_hdr.is_some() {
let _ = mem::take(&mut c.refs[i].p);
if state.refs[i].p.p.frame_hdr.is_some() {
let _ = mem::take(&mut state.refs[i].p);
}
let _ = mem::take(&mut c.cdf[i]);
let _ = mem::take(&mut c.refs[i].segmap);
let _ = mem::take(&mut c.refs[i].refmvs);
let _ = mem::take(&mut state.cdf[i]);
let _ = mem::take(&mut state.refs[i].segmap);
let _ = mem::take(&mut state.refs[i].refmvs);
}
}
let mut f = fc.data.try_write().unwrap();
on_error(fc, &mut f, &mut c.out, &c.cached_error_props, &c.in_0.m);
on_error(
fc,
&mut f,
&mut state.out,
&mut state.cached_error_props,
&state.in_0.m,
);
return res;
}
} else {
Expand Down
Loading

0 comments on commit 72521fa

Please sign in to comment.