Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fn dav1d_load_tmvs_sse4: backport x86_64 asm function from dav1d 1.3.0 #821

Merged
merged 4 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/refmvs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ extern "C" {
col_start8: c_int,
row_start8: c_int,
);
fn dav1d_load_tmvs_sse4(
rf: *const refmvs_frame,
tile_row_idx: c_int,
col_start8: c_int,
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
);
}

#[cfg(all(feature = "asm", target_arch = "x86_64"))]
Expand Down Expand Up @@ -142,6 +150,9 @@ pub struct refmvs_block(pub refmvs_block_unaligned);

#[repr(C)]
pub(crate) struct refmvs_frame {
/// A pointer to a `struct refmvs_frame` may be passed to a function of type `load_tmvs_fn`.
/// However, the `frm_hdr` pointer is not accessed in such a function. Thus, it is
/// safe to have a pointer to `Rav1dFrameHeader` instead of `Dav1dFrameHeader` here
fbossen marked this conversation as resolved.
Show resolved Hide resolved
pub frm_hdr: *const Rav1dFrameHeader,
pub iw4: c_int,
pub ih4: c_int,
Expand Down Expand Up @@ -1645,8 +1656,14 @@ unsafe fn refmvs_dsp_init_x86(c: *mut Rav1dRefmvsDSPContext) {

(*c).save_tmvs = Some(dav1d_save_tmvs_ssse3);

if !flags.contains(CpuFlags::SSE41) {
return;
}

#[cfg(target_arch = "x86_64")]
{
(*c).load_tmvs = Some(dav1d_load_tmvs_sse4);

if !flags.contains(CpuFlags::AVX2) {
return;
}
Expand Down
224 changes: 224 additions & 0 deletions src/x86/refmvs.asm
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ SECTION_RODATA 64
%endmacro

%if ARCH_X86_64
mv_proj: dw 0, 16384, 8192, 5461, 4096, 3276, 2730, 2340
dw 2048, 1820, 1638, 1489, 1365, 1260, 1170, 1092
dw 1024, 963, 910, 862, 819, 780, 744, 712
dw 682, 655, 630, 606, 585, 564, 546, 528
splat_mv_shuf: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3
db 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7
db 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
Expand All @@ -61,6 +65,7 @@ cond_shuf512: db 3, 3, 3, 3, 7, 7, 7, 7, 7, 7, 7, 7, 3, 3, 3, 3
save_cond0: db 0x80, 0x81, 0x82, 0x83, 0x89, 0x84, 0x00, 0x00
save_cond1: db 0x84, 0x85, 0x86, 0x87, 0x88, 0x80, 0x00, 0x00
pb_128: times 16 db 128
pq_8192: dq 8192

save_tmvs_ssse3_table: SAVE_TMVS_TABLE 2, 16, ssse3
SAVE_TMVS_TABLE 4, 8, ssse3
Expand Down Expand Up @@ -329,6 +334,225 @@ cglobal splat_mv, 4, 5, 3, rr, a, bx4, bw4, bh4
RET

%if ARCH_X86_64
INIT_XMM sse4
; refmvs_frame *rf, int tile_row_idx,
; int col_start8, int col_end8, int row_start8, int row_end8
cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \
stride, rp_proj, roff, troff, \
xendi, xstarti, iw8, ih8, dst
xor r14d, r14d
cmp dword [rfq+212], 1 ; n_tile_threads
mov ih8d, [rfq+20] ; rf->ih8
mov iw8d, [rfq+16] ; rf->iw8
mov xstartd, xstartd
mov xendd, xendd
cmove tridxd, r14d
lea xstartid, [xstartq-8]
lea xendid, [xendq+8]
mov strideq, [rfq+184]
mov rp_projq, [rfq+176]
cmp ih8d, yendd
mov [rsp+0x30], strideq
cmovs yendd, ih8d
test xstartid, xstartid
cmovs xstartid, r14d
cmp iw8d, xendid
cmovs xendid, iw8d
mov troffq, strideq
shl troffq, 4
imul troffq, tridxq
mov dstd, ystartd
and dstd, 15
imul dstq, strideq
add dstq, troffq ; (16 * tridx + (ystart & 15)) * stride
lea dstq, [dstq*5]
add dstq, rp_projq
lea troffq, [troffq*5] ; 16 * tridx * stride * 5
lea r13d, [xendq*5]
lea r12, [strideq*5]
DEFINE_ARGS rf, w5, xstart, xend, ystart, yend, h, x5, \
_, troff, xendi, xstarti, stride5, _, dst
lea w5d, [xstartq*5]
add r7, troffq ; rp_proj + tile_row_offset
mov hd, yendd
mov [rsp+0x28], r7
add dstq, r13
sub w5q, r13
sub hd, ystartd
.init_xloop_start:
mov x5q, w5q
test w5b, 1
jz .init_2blk
mov dword [dstq+x5q], 0x80008000
add x5q, 5
jz .init_next_row
.init_2blk:
mov dword [dstq+x5q+0], 0x80008000
mov dword [dstq+x5q+5], 0x80008000
add x5q, 10
jl .init_2blk
.init_next_row:
add dstq, stride5q
dec hd
jg .init_xloop_start
DEFINE_ARGS rf, _, xstart, xend, ystart, yend, n7, stride, \
_, _, xendi, xstarti, stride5, _, n
mov r13d, [rfq+152] ; rf->n_mfmvs
test r13d, r13d
jz .ret
mov [rsp+0x0c], r13d
mov strideq, [rsp+0x30]
movddup m3, [pq_8192]
mov r9d, ystartd
mov [rsp+0x38], yendd
mov [rsp+0x20], xstartid
xor nd, nd
xor n7d, n7d
imul r9, strideq ; ystart * stride
mov [rsp+0x48], rfq
mov [rsp+0x18], stride5q
lea r7, [r9*5]
mov [rsp+0x24], ystartd
mov [rsp+0x00], r7
.nloop:
DEFINE_ARGS y, off, xstart, xend, ystart, rf, n7, refsign, \
ref, rp_ref, xendi, xstarti, _, _, n
mov rfq, [rsp+0x48]
mov refd, [rfq+56+nq*4] ; ref2cur
cmp refd, 0x80000000
je .next_n
mov [rsp+0x40], refd
mov offq, [rsp+0x00] ; ystart * stride * 5
movzx refd, byte [rfq+53+nq] ; rf->mfmv_ref[n]
lea refsignq, [refq-4]
mov rp_refq, [rfq+168]
movq m2, refsignq
add offq, [rp_refq+refq*8] ; r = rp_ref[ref] + row_offset
mov [rsp+0x14], nd
mov yd, ystartd
.yloop:
mov r11d, [rsp+0x24] ; ystart
mov r12d, [rsp+0x38] ; yend
mov r14d, yd
and r14d, ~7 ; y_sb_align
cmp r11d, r14d
cmovs r11d, r14d ; imax(y_sb_align, ystart)
mov [rsp+0x44], r11d ; y_proj_start
add r14d, 8
cmp r12d, r14d
cmovs r14d, r12d ; imin(y_sb_align + 8, yend)
mov [rsp+0x3c], r14d ; y_proj_end
DEFINE_ARGS y, src, xstart, xend, frac, rf, n7, mv, \
ref, x, xendi, mvx, mvy, rb, ref2ref
mov xd, [rsp+0x20] ; xstarti
.xloop:
lea rbd, [xq*5]
add rbq, srcq
movsx refd, byte [rbq+4]
test refd, refd
jz .next_x_bad_ref
mov rfq, [rsp+0x48]
lea r14d, [16+n7q+refq]
mov ref2refd, [rfq+r14*4] ; rf->mfmv_ref2ref[n][b_ref-1]
test ref2refd, ref2refd
jz .next_x_bad_ref
lea fracq, [mv_proj]
movzx fracd, word [fracq+ref2refq*2]
mov mvd, [rbq]
imul fracd, [rsp+0x40] ; ref2cur
pmovsxwq m0, [rbq]
movd m1, fracd
punpcklqdq m1, m1
pmuldq m0, m1 ; mv * frac
pshufd m1, m0, q3311
paddd m0, m3
paddd m0, m1
psrad m0, 14 ; offset = (xy + (xy >> 31) + 8192) >> 14
pabsd m1, m0
packssdw m0, m0
psrld m1, 6
packuswb m1, m1
pxor m0, m2 ; offset ^ ref_sign
psignd m1, m0 ; apply_sign(abs(offset) >> 6, offset ^ refsign)
movq mvxq, m1
lea mvyd, [mvxq+yq] ; ypos
sar mvxq, 32
DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, \
ref, x, xendi, mvx, ypos, rb, ref2ref
cmp yposd, [rsp+0x44] ; y_proj_start
jl .next_x_bad_pos_y
cmp yposd, [rsp+0x3c] ; y_proj_end
jge .next_x_bad_pos_y
and yposd, 15
add mvxq, xq ; xpos
imul yposq, [rsp+0x30] ; pos = (ypos & 15) * stride
DEFINE_ARGS y, src, xstart, xend, dst, _, n7, mv, \
ref, x, xendi, xpos, pos, rb, ref2ref
mov dstq, [rsp+0x28] ; dst = rp_proj + tile_row_offset
add posq, xposq ; pos += xpos
lea posq, [posq*5]
add dstq, posq ; dst += pos5
jmp .write_loop_entry
.write_loop:
add rbq, 5
cmp refb, byte [rbq+4]
jne .xloop
cmp mvd, [rbq]
jne .xloop
add dstq, 5
inc xposd
.write_loop_entry:
mov r12d, xd
and r12d, ~7
lea r5d, [r12-8]
cmp r5d, xstartd
cmovs r5d, xstartd ; x_proj_start
cmp xposd, r5d
jl .next_xpos
add r12d, 16
cmp xendd, r12d
cmovs r12d, xendd ; x_proj_end
cmp xposd, r12d
jge .next_xpos
mov [dstq+0], mvd
mov byte [dstq+4], ref2refb
.next_xpos:
inc xd
cmp xd, xendid
jl .write_loop
.next_y:
DEFINE_ARGS y, src, xstart, xend, ystart, _, n7, _, _, x, xendi, _, _, _, n
add srcq, [rsp+0x18] ; stride5
inc yd
cmp yd, [rsp+0x38] ; yend
jne .yloop
mov nd, [rsp+0x14]
mov ystartd, [rsp+0x24]
.next_n:
add n7d, 7
inc nd
cmp nd, [rsp+0x0c] ; n_mfmvs
jne .nloop
.ret:
RET
.next_x:
DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, ref, x, xendi, _, _, rb, _
add rbq, 5
cmp refb, byte [rbq+4]
jne .xloop
cmp mvd, [rbq]
jne .xloop
.next_x_bad_pos_y:
inc xd
cmp xd, xendid
jl .next_x
jmp .next_y
.next_x_bad_ref:
inc xd
cmp xd, xendid
jl .xloop
jmp .next_y

INIT_YMM avx2
; refmvs_temporal_block *rp, ptrdiff_t stride,
; refmvs_block **rr, uint8_t *ref_sign,
Expand Down
5 changes: 5 additions & 0 deletions src/x86/refmvs.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "src/cpu.h"
#include "src/refmvs.h"

decl_load_tmvs_fn(dav1d_load_tmvs_sse4);

decl_save_tmvs_fn(dav1d_save_tmvs_ssse3);
decl_save_tmvs_fn(dav1d_save_tmvs_avx2);
decl_save_tmvs_fn(dav1d_save_tmvs_avx512icl);
Expand All @@ -47,7 +49,10 @@ static ALWAYS_INLINE void refmvs_dsp_init_x86(Dav1dRefmvsDSPContext *const c) {

c->save_tmvs = dav1d_save_tmvs_ssse3;

if (!(flags & DAV1D_X86_CPU_FLAG_SSE41)) return;
#if ARCH_X86_64
c->load_tmvs = dav1d_load_tmvs_sse4;

if (!(flags & DAV1D_X86_CPU_FLAG_AVX2)) return;

c->save_tmvs = dav1d_save_tmvs_avx2;
Expand Down
Loading
Loading