diff --git a/src/refmvs.rs b/src/refmvs.rs
index a4f054c26..7d2009774 100644
--- a/src/refmvs.rs
+++ b/src/refmvs.rs
@@ -43,6 +43,14 @@ extern "C" {
         col_start8: c_int,
         row_start8: c_int,
     );
+    fn dav1d_load_tmvs_sse4( // SSE4 assembly implementation of load_tmvs (src/x86/refmvs.asm)
+        rf: *const refmvs_frame,
+        tile_row_idx: c_int,
+        col_start8: c_int,
+        col_end8: c_int,
+        row_start8: c_int,
+        row_end8: c_int,
+    );
 }

 #[cfg(all(feature = "asm", target_arch = "x86_64"))]
@@ -142,6 +150,9 @@ pub struct refmvs_block(pub refmvs_block_unaligned);

 #[repr(C)]
 pub(crate) struct refmvs_frame {
+    /// A pointer to a [`refmvs_frame`] may be passed to a [`load_tmvs_fn`] function.
+    /// However, the [`Self::frm_hdr`] pointer is not accessed in such a function (see [`load_tmvs_c`]).
+    /// Thus, it is safe to have a pointer to [`Rav1dFrameHeader`] instead of [`Dav1dFrameHeader`] here.
     pub frm_hdr: *const Rav1dFrameHeader,
     pub iw4: c_int,
     pub ih4: c_int,
@@ -1645,8 +1656,14 @@ unsafe fn refmvs_dsp_init_x86(c: *mut Rav1dRefmvsDSPContext) {

     (*c).save_tmvs = Some(dav1d_save_tmvs_ssse3);

+    if !flags.contains(CpuFlags::SSE41) { // nothing below is installed without SSE4.1
+        return;
+    }
+
     #[cfg(target_arch = "x86_64")]
     {
+        (*c).load_tmvs = Some(dav1d_load_tmvs_sse4); // 64-bit only: the asm routine sits inside %if ARCH_X86_64
+
         if !flags.contains(CpuFlags::AVX2) {
             return;
         }
diff --git a/src/x86/refmvs.asm b/src/x86/refmvs.asm
index 06f555db1..d95861fa1 100644
--- a/src/x86/refmvs.asm
+++ b/src/x86/refmvs.asm
@@ -47,6 +47,10 @@ SECTION_RODATA 64
 %endmacro

 %if ARCH_X86_64
+mv_proj:       dw 0, 16384, 8192, 5461, 4096, 3276, 2730, 2340  ; entry d is 16384 / d rounded down (entry 0 is 0); indexed by ref2ref in load_tmvs
+               dw 2048, 1820, 1638, 1489, 1365, 1260, 1170, 1092
+               dw 1024, 963, 910, 862, 819, 780, 744, 712
+               dw 682, 655, 630, 606, 585, 564, 546, 528
 splat_mv_shuf: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3
                db 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7
                db 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
@@ -61,6 +65,7 @@ cond_shuf512: db 3, 3, 3, 3, 7, 7, 7, 7, 7, 7, 7, 7, 3, 3, 3, 3
 save_cond0: db 0x80, 0x81, 0x82, 0x83, 0x89, 0x84, 0x00, 0x00
 save_cond1: db 0x84, 0x85, 0x86, 0x87, 0x88, 0x80, 0x00, 0x00
 pb_128: times 16 db 128
+pq_8192: dq 8192  ; rounding term added before the >> 14 in load_tmvs

 save_tmvs_ssse3_table: SAVE_TMVS_TABLE 2, 16, ssse3
                        SAVE_TMVS_TABLE 4, 8, ssse3
@@ -329,6 +334,225 @@ cglobal splat_mv, 4, 5, 3, rr, a, bx4, bw4, bh4
     RET

 %if ARCH_X86_64
+INIT_XMM sse4
+; refmvs_frame *rf, int tile_row_idx,
+; int col_start8, int col_end8, int row_start8, int row_end8
+cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \
+                                    stride, rp_proj, roff, troff, \
+                                    xendi, xstarti, iw8, ih8, dst
+    xor         r14d, r14d  ; r14d = 0, the zero source for the cmov instructions below
+    cmp         dword [rfq+212], 1 ; n_tile_threads
+    mov         ih8d, [rfq+20] ; rf->ih8
+    mov         iw8d, [rfq+16] ; rf->iw8
+    mov         xstartd, xstartd  ; 32-bit self-move zero-extends the int argument to 64 bits
+    mov         xendd, xendd  ; same for col_end8
+    cmove       tridxd, r14d  ; tile_row_idx = 0 when n_tile_threads == 1 (flags from the cmp above)
+    lea         xstartid, [xstartq-8]  ; col_start8 - 8
+    lea         xendid, [xendq+8]  ; col_end8 + 8
+    mov         strideq, [rfq+184]
+    mov         rp_projq, [rfq+176]
+    cmp         ih8d, yendd
+    mov         [rsp+0x30], strideq  ; spill stride
+    cmovs       yendd, ih8d  ; yend = imin(yend, ih8)
+    test        xstartid, xstartid
+    cmovs       xstartid, r14d  ; xstarti = imax(col_start8 - 8, 0)
+    cmp         iw8d, xendid
+    cmovs       xendid, iw8d  ; xendi = imin(col_end8 + 8, iw8)
+    mov         troffq, strideq
+    shl         troffq, 4
+    imul        troffq, tridxq  ; troff = 16 * stride * tridx
+    mov         dstd, ystartd
+    and         dstd, 15
+    imul        dstq, strideq
+    add         dstq, troffq ; (16 * tridx + (ystart & 15)) * stride
+    lea         dstq, [dstq*5]  ; entries are 5 bytes apart (4-byte mv followed by a 1-byte ref)
+    add         dstq, rp_projq
+    lea         troffq, [troffq*5] ; 16 * tridx * stride * 5
+    lea         r13d, [xendq*5]  ; r13 = col_end8 * 5
+    lea         r12, [strideq*5]  ; r12 = stride * 5, named stride5 from here on
+    DEFINE_ARGS rf, w5, xstart, xend, ystart, yend, h, x5, \
+                _, troff, xendi, xstarti, stride5, _, dst
+    lea         w5d, [xstartq*5]
+    add         r7, troffq ; rp_proj + tile_row_offset
+    mov         hd, yendd
+    mov         [rsp+0x28], r7  ; spill rp_proj + tile_row_offset
+    add         dstq, r13  ; dst points at column col_end8 of the first row
+    sub         w5q, r13  ; w5 = (col_start8 - col_end8) * 5, a negative byte offset counted up to 0
+    sub         hd, ystartd  ; h = yend - ystart rows to initialise
+.init_xloop_start:
+    mov         x5q, w5q
+    test        w5b, 1  ; 5 is odd, so bit 0 of w5 is the parity of the entry count
+    jz          .init_2blk
+    mov         dword [dstq+x5q], 0x80008000  ; odd count: initialise one entry's mv first
+    add         x5q, 5
+    jz          .init_next_row  ; that was the only entry in the row
+.init_2blk:
+    mov         dword [dstq+x5q+0], 0x80008000  ; two entries per iteration; only the 4 mv bytes are written
+    mov         dword [dstq+x5q+5], 0x80008000
+    add         x5q, 10
+    jl          .init_2blk
+.init_next_row:
+    add         dstq, stride5q
+    dec         hd
+    jg          .init_xloop_start
+    DEFINE_ARGS rf, _, xstart, xend, ystart, yend, n7, stride, \
+                _, _, xendi, xstarti, stride5, _, n
+    mov         r13d, [rfq+152] ; rf->n_mfmvs
+    test        r13d, r13d
+    jz          .ret  ; nothing to project
+    mov         [rsp+0x0c], r13d  ; spill n_mfmvs
+    mov         strideq, [rsp+0x30]
+    movddup     m3, [pq_8192]  ; m3 = 8192 in both qword lanes
+    mov         r9d, ystartd
+    mov         [rsp+0x38], yendd  ; spill yend (already clamped to ih8)
+    mov         [rsp+0x20], xstartid  ; spill xstarti
+    xor         nd, nd  ; n = 0
+    xor         n7d, n7d  ; n7 = n * 7, advanced by 7 at .next_n
+    imul        r9, strideq ; ystart * stride
+    mov         [rsp+0x48], rfq  ; spill rf
+    mov         [rsp+0x18], stride5q  ; spill stride5
+    lea         r7, [r9*5]
+    mov         [rsp+0x24], ystartd  ; spill ystart
+    mov         [rsp+0x00], r7  ; spill ystart * stride * 5
+.nloop:
+    DEFINE_ARGS y, off, xstart, xend, ystart, rf, n7, refsign, \
+                ref, rp_ref, xendi, xstarti, _, _, n
+    mov         rfq, [rsp+0x48]
+    mov         refd, [rfq+56+nq*4] ; ref2cur
+    cmp         refd, 0x80000000
+    je          .next_n  ; skip this n when ref2cur == 0x80000000
+    mov         [rsp+0x40], refd  ; spill ref2cur
+    mov         offq, [rsp+0x00] ; ystart * stride * 5
+    movzx       refd, byte [rfq+53+nq] ; rf->mfmv_ref[n]
+    lea         refsignq, [refq-4]  ; ref_sign = ref - 4
+    mov         rp_refq, [rfq+168]
+    movq        m2, refsignq  ; m2 = ref_sign, xor-ed into the offset below
+    add         offq, [rp_refq+refq*8] ; r = rp_ref[ref] + row_offset
+    mov         [rsp+0x14], nd  ; spill n (r14 is reused inside the loops)
+    mov         yd, ystartd
+.yloop:
+    mov         r11d, [rsp+0x24] ; ystart
+    mov         r12d, [rsp+0x38] ; yend
+    mov         r14d, yd
+    and         r14d, ~7 ; y_sb_align
+    cmp         r11d, r14d
+    cmovs       r11d, r14d ; imax(y_sb_align, ystart)
+    mov         [rsp+0x44], r11d ; y_proj_start
+    add         r14d, 8
+    cmp         r12d, r14d
+    cmovs       r14d, r12d ; imin(y_sb_align + 8, yend)
+    mov         [rsp+0x3c], r14d ; y_proj_end
+    DEFINE_ARGS y, src, xstart, xend, frac, rf, n7, mv, \
+                ref, x, xendi, mvx, mvy, rb, ref2ref
+    mov         xd, [rsp+0x20] ; xstarti
+.xloop:
+    lea         rbd, [xq*5]
+    add         rbq, srcq  ; rb = &r[x]
+    movsx       refd, byte [rbq+4]  ; b_ref = rb->ref
+    test        refd, refd
+    jz          .next_x_bad_ref  ; ref == 0: skip this entry
+    mov         rfq, [rsp+0x48]
+    lea         r14d, [16+n7q+refq]
+    mov         ref2refd, [rfq+r14*4] ; rf->mfmv_ref2ref[n][b_ref-1]
+    test        ref2refd, ref2refd
+    jz          .next_x_bad_ref  ; ref2ref == 0: skip this entry
+    lea         fracq, [mv_proj]
+    movzx       fracd, word [fracq+ref2refq*2]  ; frac = mv_proj[ref2ref]
+    mov         mvd, [rbq]  ; packed 32-bit mv of the source entry
+    imul        fracd, [rsp+0x40] ; ref2cur
+    pmovsxwq    m0, [rbq]  ; sign-extend both 16-bit mv components to 64 bits
+    movd        m1, fracd
+    punpcklqdq  m1, m1  ; frac in both qword lanes
+    pmuldq      m0, m1 ; mv * frac
+    pshufd      m1, m0, q3311
+    paddd       m0, m3
+    paddd       m0, m1
+    psrad       m0, 14 ; offset = (xy + (xy >> 31) + 8192) >> 14
+    pabsd       m1, m0
+    packssdw    m0, m0
+    psrld       m1, 6
+    packuswb    m1, m1
+    pxor        m0, m2 ; offset ^ ref_sign
+    psignd      m1, m0 ; apply_sign(abs(offset) >> 6, offset ^ refsign)
+    movq        mvxq, m1
+    lea         mvyd, [mvxq+yq] ; ypos
+    sar         mvxq, 32  ; keep the upper dword, sign-extended
+    DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, \
+                ref, x, xendi, mvx, ypos, rb, ref2ref
+    cmp         yposd, [rsp+0x44] ; y_proj_start
+    jl          .next_x_bad_pos_y
+    cmp         yposd, [rsp+0x3c] ; y_proj_end
+    jge         .next_x_bad_pos_y
+    and         yposd, 15
+    add         mvxq, xq ; xpos
+    imul        yposq, [rsp+0x30] ; pos = (ypos & 15) * stride
+    DEFINE_ARGS y, src, xstart, xend, dst, _, n7, mv, \
+                ref, x, xendi, xpos, pos, rb, ref2ref
+    mov         dstq, [rsp+0x28] ; dst = rp_proj + tile_row_offset
+    add         posq, xposq ; pos += xpos
+    lea         posq, [posq*5]
+    add         dstq, posq ; dst += pos5
+    jmp         .write_loop_entry
+.write_loop:
+    add         rbq, 5  ; next source entry
+    cmp         refb, byte [rbq+4]
+    jne         .xloop  ; different ref: redo the projection
+    cmp         mvd, [rbq]
+    jne         .xloop  ; different mv: redo the projection
+    add         dstq, 5  ; same ref and mv: reuse the result one column to the right
+    inc         xposd
+.write_loop_entry:
+    mov         r12d, xd
+    and         r12d, ~7  ; x rounded down to a multiple of 8
+    lea         r5d, [r12-8]
+    cmp         r5d, xstartd
+    cmovs       r5d, xstartd ; x_proj_start
+    cmp         xposd, r5d
+    jl          .next_xpos
+    add         r12d, 16
+    cmp         xendd, r12d
+    cmovs       r12d, xendd ; x_proj_end
+    cmp         xposd, r12d
+    jge         .next_xpos
+    mov         [dstq+0], mvd  ; store the source mv unchanged
+    mov         byte [dstq+4], ref2refb  ; store ref2ref as the entry's ref byte
+.next_xpos:
+    inc         xd
+    cmp         xd, xendid
+    jl          .write_loop
+.next_y:
+    DEFINE_ARGS y, src, xstart, xend, ystart, _, n7, _, _, x, xendi, _, _, _, n
+    add         srcq, [rsp+0x18] ; stride5
+    inc         yd
+    cmp         yd, [rsp+0x38] ; yend
+    jne         .yloop
+    mov         nd, [rsp+0x14]  ; reload n
+    mov         ystartd, [rsp+0x24]  ; reload ystart
+.next_n:
+    add         n7d, 7
+    inc         nd
+    cmp         nd, [rsp+0x0c] ; n_mfmvs
+    jne         .nloop
+.ret:
+    RET
+.next_x:
+    DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, ref, x, xendi, _, _, rb, _
+    add         rbq, 5  ; next source entry
+    cmp         refb, byte [rbq+4]
+    jne         .xloop
+    cmp         mvd, [rbq]
+    jne         .xloop
+.next_x_bad_pos_y:
+    inc         xd  ; same ref and mv give the same rejected ypos, so skip without recomputing
+    cmp         xd, xendid
+    jl          .next_x
+    jmp         .next_y
+.next_x_bad_ref:
+    inc         xd
+    cmp         xd, xendid
+    jl          .xloop
+    jmp         .next_y
+
 INIT_YMM avx2
 ; refmvs_temporal_block *rp, ptrdiff_t stride,
 ; refmvs_block **rr, uint8_t *ref_sign,
diff --git a/src/x86/refmvs.h b/src/x86/refmvs.h
index 9dafa78b1..c9978561e 100644
--- a/src/x86/refmvs.h
+++ b/src/x86/refmvs.h
@@ -28,6 +28,8 @@
 #include "src/cpu.h"
 #include "src/refmvs.h"

+decl_load_tmvs_fn(dav1d_load_tmvs_sse4); // prototype of the SSE4 assembly load_tmvs
+
 decl_save_tmvs_fn(dav1d_save_tmvs_ssse3);
 decl_save_tmvs_fn(dav1d_save_tmvs_avx2);
 decl_save_tmvs_fn(dav1d_save_tmvs_avx512icl);
@@ -47,7 +49,10 @@ static ALWAYS_INLINE void refmvs_dsp_init_x86(Dav1dRefmvsDSPContext *const c) {

     c->save_tmvs = dav1d_save_tmvs_ssse3;

+    if (!(flags & DAV1D_X86_CPU_FLAG_SSE41)) return; // nothing below is installed without SSE4.1
 #if ARCH_X86_64
+    c->load_tmvs = dav1d_load_tmvs_sse4; // 64-bit builds only
+
     if (!(flags & DAV1D_X86_CPU_FLAG_AVX2)) return;

     c->save_tmvs = dav1d_save_tmvs_avx2;
diff --git a/tests/checkasm/refmvs.c b/tests/checkasm/refmvs.c
index 7118c4fd1..4d082cb3f 100644
--- a/tests/checkasm/refmvs.c
+++ b/tests/checkasm/refmvs.c
@@ -29,6 +29,200 @@

 #include

+static inline int gen_mv(const int total_bits, int spel_bits) { // random signed mv component for the tests
+    int bits = rnd() & ((1 << spel_bits) - 1); // spel_bits random low bits
+    do {
+        bits |= (rnd() & 1) << spel_bits;
+    } while (rnd() & 1 && ++spel_bits < total_bits);
+    // the do/while makes it relatively more likely to be close to zero (fpel)
+    // than far away
+    return rnd() & 1 ? -bits : bits; // random sign
+}
+
+#define ARRAY_SIZE(n) (sizeof(n)/sizeof(*(n))) // element count of a true array (not a pointer)
+
+static inline int get_min_mv_val(const int idx) { // lower bound of magnitude bucket idx: 0..9, 10..90, 100..900, 1000..9000, then multiples of 10000
+    if (idx <= 9) return idx;
+    else if (idx <= 18) return (idx - 9) * 10;
+    else if (idx <= 27) return (idx - 18) * 100;
+    else if (idx <= 36) return (idx - 27) * 1000;
+    else return (idx - 36) * 10000;
+}
+
+static inline void gen_tmv(refmvs_temporal_block *const rb, const int *ref2ref) { // random temporal block: ref in 0..6, mv from the bucket tables below
+    rb->ref = rnd() % 7;
+    if (!rb->ref) return; // ref == 0: mv is left unset
+    static const int x_prob[] = { // per-bucket weights; the 32 entries sum to 100000000
+        26447556, 6800591, 3708783, 2198592, 1635940, 1145901, 1052602, 1261759,
+        1099739, 755108, 6075404, 4355916, 3254908, 2897157, 2273676, 2154432,
+        1937436, 1694818, 1466863, 10203087, 5241546, 3328819, 2187483, 1458997,
+        1030842, 806863, 587219, 525024, 1858953, 422368, 114626, 16992
+    };
+    static const int y_prob[] = { // per-bucket weights; the 32 entries sum to 100000000
+        33845001, 7591218, 6425971, 4115838, 4032161, 2515962, 2614601, 2343656,
+        2898897, 1397254, 10125350, 5124449, 3232914, 2185499, 1608775, 1342585,
+        980208, 795714, 649665, 3369250, 1298716, 486002, 279588, 235990,
+        110318, 89372, 66895, 46980, 153322, 32960, 4500, 389
+    };
+    const int prob = rnd() % 100000000; // one draw, used for both the x and the y bucket
+    int acc = 0;
+    for (unsigned i = 0; i < ARRAY_SIZE(x_prob); i++) {
+        acc += x_prob[i]; // running cumulative weight
+        if (prob < acc) {
+            const int min = get_min_mv_val(i);
+            const int max = get_min_mv_val(i + 1);
+            const int val = min + rnd() % (max - min); // value in [min, max)
+            rb->mv.x = iclip(val * ref2ref[rb->ref], -(1 << 15), (1 << 15) - 1); // scale by ref2ref[ref], clamp to the int16 range
+            break;
+        }
+    }
+    acc = 0;
+    for (unsigned i = 0; i < ARRAY_SIZE(y_prob); i++) {
+        acc += y_prob[i];
+        if (prob < acc) {
+            const int min = get_min_mv_val(i);
+            const int max = get_min_mv_val(i + 1);
+            const int val = min + rnd() % (max - min);
+            rb->mv.y = iclip(val * ref2ref[rb->ref], -(1 << 15), (1 << 15) - 1);
+            break;
+        }
+    }
+}
+
+static inline int get_ref2cur(void) { // random signed distance with magnitude 1..11
+    const int prob = rnd() % 100;
+    static const uint8_t ref2cur[11] = { 35, 55, 67, 73, 78, 83, 84, 87, 90, 93, 100 }; // cumulative percentages for magnitudes 1..11
+    for (int i = 0; i < 11; i++)
+        if (prob < ref2cur[i])
+            return rnd() & 1 ? -(i + 1) : i + 1; // random sign
+    return 0; // fallback after the loop; the last table entry is 100 and prob is rnd() % 100
+}
+
+static inline int get_seqlen(void) { // random length of a run of identical blocks
+    int len = 0, max_len; // len == 0 means "draw from 1..max_len" below
+    const int prob = rnd() % 100000;
+    // =1   =2   =3   =4   <8   =8   <16  =16  <32  =32  <48  =48  <64   =64   >64  eq240
+    //  5   17   1.5  16   5    10   5    7    4    3    1.5  2    1     2     20   15    chimera blocks
+    // 25   38   2.5  19   3.5  5.5  2    1.87 .86  .4   .18  .2   .067  .165  .478 .28   chimera sequences
+
+    if (prob < 25000) len = 1; // =1 5%
+    else if (prob < 63000) len = 2; // =2 17%
+    else if (prob < 65500) len = 3; // =3 1.5%
+    else if (prob < 84500) len = 4; // =4 16%
+    else if (prob < 88000) max_len = 7; // <8 5% (43.5% tot <8)
+    else if (prob < 93500) len = 8; // =8 10%
+    else if (prob < 95500) max_len = 15; // <16 5%
+    else if (prob < 97370) len = 16; // =16 7%
+    else if (prob < 98230) max_len = 31; // <32 4%
+    else if (prob < 98630) len = 32; // =32 3%
+    else if (prob < 98810) max_len = 47; // <48 1.5%
+    else if (prob < 99010) len = 48; // =48 2%
+    else if (prob < 99077) max_len = 63; // <64 1%
+    else if (prob < 99242) len = 64; // =64 2%
+    else if (prob < 99720) max_len = 239; // <240 5%
+    else len = 240; // =240 15%
+
+    if (!len) len = 1 + rnd() % max_len; // only reached on branches that set max_len
+    return len;
+}
+
+static inline void init_rp_ref(refmvs_frame const *const rf, // fill the rp_ref planes used by rf->mfmv_ref[] with runs of random blocks
+                               const int col_start8, const int col_end8,
+                               const int row_start8, const int row_end8)
+{
+    const int col_start8i = imax(col_start8 - 8, 0); // widen by 8 columns each side, clamped to [0, iw8]
+    const int col_end8i = imin(col_end8 + 8, rf->iw8);
+    for (int n = 0; n < rf->n_mfmvs; n++) {
+        refmvs_temporal_block *rp_ref = rf->rp_ref[rf->mfmv_ref[n]];
+        for (int i = row_start8; i < imin(row_end8, rf->ih8); i++) {
+            for (int j = col_start8i; j < col_end8i;) { // j is advanced by the run loop below
+                refmvs_temporal_block rb;
+                gen_tmv(&rb, rf->mfmv_ref2ref[n]);
+                for (int k = get_seqlen(); k && j < col_end8i; k--, j++) // repeat the same block for the run length
+                    rp_ref[i * rf->iw8 + j] = rb;
+            }
+        }
+    }
+}
+
+static void check_load_tmvs(const Dav1dRefmvsDSPContext *const c) { // compare c->load_tmvs against the reference on random input
+    refmvs_temporal_block *rp_ref[7] = {0}; // allocated on demand below
+    refmvs_temporal_block c_rp_proj[240 * 63]; // output of the reference run
+    refmvs_temporal_block a_rp_proj[240 * 63]; // output of the implementation under test
+    refmvs_frame rf = {
+        .rp_ref = rp_ref,
+        .rp_stride = 240, .iw8 = 240, .ih8 = 63,
+        .n_mfmvs = 3
+    };
+    const size_t rp_ref_sz = rf.ih8 * rf.rp_stride * sizeof(refmvs_temporal_block);
+
+    declare_func(void, const refmvs_frame *rf, int tile_row_idx,
+                 int col_start8, int col_end8, int row_start8, int row_end8);
+
+    if (check_func(c->load_tmvs, "load_tmvs")) {
+        const int row_start8 = (rnd() & 3) << 4; // 0, 16, 32 or 48
+        const int row_end8 = row_start8 + 16; // can exceed ih8 (63) when row_start8 == 48
+        const int col_start8 = rnd() & 31;
+        const int col_end8 = rf.iw8 - (rnd() & 31);
+
+        for (int n = 0; n < rf.n_mfmvs; n++) {
+            rf.mfmv_ref[n] = rnd() % 7;
+            rf.mfmv_ref2cur[n] = get_ref2cur();
+            for (int r = 0; r < 7; r++)
+                rf.mfmv_ref2ref[n][r] = rnd() & 31; // 0..31, matching the 32 mv_proj entries in the asm
+        }
+        for (int n = 0; n < rf.n_mfmvs; n++) {
+            refmvs_temporal_block **p_rp_ref = &rp_ref[rf.mfmv_ref[n]];
+            if (!*p_rp_ref) // mfmv_ref[] values may repeat; allocate each plane once
+                *p_rp_ref = malloc(rp_ref_sz);
+        }
+        init_rp_ref(&rf, 0, rf.iw8, row_start8, row_end8); // fill the full width of the tested rows
+        for (int i = 0; i < rf.iw8 * rf.ih8; i++) {
+            c_rp_proj[i].mv.n = a_rp_proj[i].mv.n = 0xdeadbeef; // identical sentinel so untouched entries compare equal
+            c_rp_proj[i].ref = a_rp_proj[i].ref = 0xdd;
+        }
+
+        rf.n_tile_threads = 1;
+
+        rf.rp_proj = c_rp_proj;
+        call_ref(&rf, 0, col_start8, col_end8, row_start8, row_end8);
+        rf.rp_proj = a_rp_proj;
+        call_new(&rf, 0, col_start8, col_end8, row_start8, row_end8);
+
+        for (int i = 0; i < rf.ih8; i++)
+            for (int j = 0; j < rf.iw8; j++)
+                if (c_rp_proj[i * rf.iw8 + j].mv.n != a_rp_proj[i * rf.iw8 + j].mv.n ||
+                    (c_rp_proj[i * rf.iw8 + j].ref != a_rp_proj[i * rf.iw8 + j].ref &&
+                     c_rp_proj[i * rf.iw8 + j].mv.n != INVALID_MV)) // a ref mismatch is ignored where the reference mv is INVALID_MV
+                {
+                    if (fail()) {
+                        fprintf(stderr, "[%d][%d] c_rp.mv.x = 0x%x a_rp.mv.x = 0x%x\n",
+                                i, j, c_rp_proj[i * rf.iw8 + j].mv.x, a_rp_proj[i * rf.iw8 + j].mv.x);
+                        fprintf(stderr, "[%d][%d] c_rp.mv.y = 0x%x a_rp.mv.y = 0x%x\n",
+                                i, j, c_rp_proj[i * rf.iw8 + j].mv.y, a_rp_proj[i * rf.iw8 + j].mv.y);
+                        fprintf(stderr, "[%d][%d] c_rp.ref = %u a_rp.ref = %u\n",
+                                i, j, c_rp_proj[i * rf.iw8 + j].ref, a_rp_proj[i * rf.iw8 + j].ref);
+                    }
+                }
+
+        if (checkasm_bench_func()) {
+            for (int n = 0; n < rf.n_mfmvs; n++) {
+                rf.mfmv_ref2cur[n] = 1; // fixed parameters for the benchmark run
+                for (int r = 0; r < 7; r++)
+                    rf.mfmv_ref2ref[n][r] = 1;
+            }
+            bench_new(&rf, 0, 0, rf.iw8, row_start8, row_end8); // benchmark over the full width
+        }
+
+        for (int n = 0; n < rf.n_mfmvs; n++) {
+            free(rp_ref[rf.mfmv_ref[n]]);
+            rp_ref[rf.mfmv_ref[n]] = NULL; // a repeated mfmv_ref[] value then frees NULL instead of freeing twice
+        }
+    }
+
+    report("load_tmvs");
+}
+
 static void check_save_tmvs(const Dav1dRefmvsDSPContext *const c) {
     refmvs_block *rr[31];
     refmvs_block r[31 * 256];
@@ -58,10 +252,10 @@ static void check_save_tmvs(const Dav1dRefmvsDSPContext *const c) {
                 while (j + ((dav1d_block_dimensions[bs][0] + 1) >> 1) > col_end8)
                     bs++;
                 rr[i * 2][j * 2 + 1] = (refmvs_block) {
-                    .mv.mv[0].x = -(rnd() & 1) * (rnd() & 8191),
-                    .mv.mv[0].y = -(rnd() & 1) * (rnd() & 8191),
-                    .mv.mv[1].x = -(rnd() & 1) * (rnd() & 8191),
-                    .mv.mv[1].y = -(rnd() & 1) * (rnd() & 8191),
+                    .mv.mv[0].x = gen_mv(14, 10), // total_bits = 14, spel_bits = 10
+                    .mv.mv[0].y = gen_mv(14, 10),
+                    .mv.mv[1].x = gen_mv(14, 10),
+                    .mv.mv[1].y = gen_mv(14, 10),
                     .ref.ref = { (rnd() % 9) - 1, (rnd() % 9) - 1 },
                     .bs = bs
                 };
@@ -152,6 +346,7 @@ void checkasm_check_refmvs(void) {
     Dav1dRefmvsDSPContext c;
     dav1d_refmvs_dsp_init(&c);

+    check_load_tmvs(&c); // new test, run before the existing ones
     check_save_tmvs(&c);
     check_splat_mv(&c);
 }