diff options
author | Henrik Gramner <gramner@twoorioles.com> | 2019-08-10 15:34:59 +0300 |
---|---|---|
committer | Henrik Gramner <henrik@gramner.com> | 2019-08-13 19:51:49 +0300 |
commit | 61dcd11ba7364a0c6a9cef49e7b99722feee133e (patch) | |
tree | 247147d86efbdd95a7541c9835cc7aa4561cb12f | |
parent | e29fd5c0016fec27c88a36ac6f6eaaf416d91330 (diff) |
x86: Add an msac function for coefficient hi_tok decoding

This particular sequence is executed often enough to justify having
a separate, slightly more optimized code path instead of just chaining
multiple generic symbol decoding function calls together.
-rw-r--r-- | src/msac.c | 16 | ||||
-rw-r--r-- | src/msac.h | 4 | ||||
-rw-r--r-- | src/recon_tmpl.c | 141 | ||||
-rw-r--r-- | src/x86/msac.asm | 176 | ||||
-rw-r--r-- | src/x86/msac.h | 2 | ||||
-rw-r--r-- | tests/checkasm/msac.c | 37 |
6 files changed, 242 insertions, 134 deletions
@@ -171,6 +171,22 @@ unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *const s, return bit; } +unsigned dav1d_msac_decode_hi_tok_c(MsacContext *const s, uint16_t *const cdf) { + unsigned tok_br = dav1d_msac_decode_symbol_adapt4(s, cdf, 3); + unsigned tok = 3 + tok_br; + if (tok_br == 3) { + tok_br = dav1d_msac_decode_symbol_adapt4(s, cdf, 3); + tok = 6 + tok_br; + if (tok_br == 3) { + tok_br = dav1d_msac_decode_symbol_adapt4(s, cdf, 3); + tok = 9 + tok_br; + if (tok_br == 3) + tok = 12 + dav1d_msac_decode_symbol_adapt4(s, cdf, 3); + } + } + return tok; +} + void dav1d_msac_init(MsacContext *const s, const uint8_t *const data, const size_t sz, const int disable_cdf_update_flag) { @@ -58,6 +58,7 @@ unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf, unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf); unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s); unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f); +unsigned dav1d_msac_decode_hi_tok_c(MsacContext *s, uint16_t *cdf); int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k); /* Supported n_symbols ranges: adapt4: 1-4, adapt8: 1-7, adapt16: 3-15 */ @@ -79,6 +80,9 @@ int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k); #ifndef dav1d_msac_decode_bool #define dav1d_msac_decode_bool dav1d_msac_decode_bool_c #endif +#ifndef dav1d_msac_decode_hi_tok +#define dav1d_msac_decode_hi_tok dav1d_msac_decode_hi_tok_c +#endif static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) { unsigned v = 0; diff --git a/src/recon_tmpl.c b/src/recon_tmpl.c index 71e7e80..22caa53 100644 --- a/src/recon_tmpl.c +++ b/src/recon_tmpl.c @@ -199,40 +199,13 @@ static int decode_coefs(Dav1dTileContext *const t, printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng); - // hi tok if (tok_br == 2) { -#define dbg_print_hi_tok(i, tok, tok_br) \ - if (dbg)\ - 
printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",\ - imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok, tok_br,\ - ts->msac.rng) const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride); - - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], 3); - tok = 3 + tok_br; - dbg_print_hi_tok(eob, tok + tok_br, tok_br); - - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], 3); - tok = 6 + tok_br; - dbg_print_hi_tok(eob, tok + tok_br, tok_br); - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], - 3); - tok = 9 + tok_br; - dbg_print_hi_tok(eob, tok + tok_br, tok_br); - if (tok_br == 3) { - tok = 12 + - dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], - 3); - dbg_print_hi_tok(eob, tok + tok_br, tok_br); - } - } - } + tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]); + if (dbg) + printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", + imin(t_dim->ctx, 3), chroma, br_ctx, eob, rc, tok, + ts->msac.rng); } cf[rc] = tok; @@ -249,37 +222,14 @@ static int decode_coefs(Dav1dTileContext *const t, printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng); - // hi tok if (tok == 3) { const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride); - - int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], 3); - tok = 3 + tok_br; - dbg_print_hi_tok(i, tok + tok_br, tok_br); - - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], 3); - - tok = 6 + tok_br; - dbg_print_hi_tok(i, tok + tok_br, tok_br); - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], - 3); - tok = 9 + tok_br; - dbg_print_hi_tok(i, tok + tok_br, tok_br); - if (tok_br == 3) { - tok = 12 + dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], - 3); - dbg_print_hi_tok(i, tok + tok_br, tok_br); - } - } - } + tok = dav1d_msac_decode_hi_tok(&ts->msac, 
br_cdf[br_ctx]); + if (dbg) + printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n", + imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok, + ts->msac.rng); } -#undef dbg_print_hi_tok cf[rc] = tok; levels[x * stride + y] = (uint8_t) tok; } @@ -292,43 +242,13 @@ static int decode_coefs(Dav1dTileContext *const t, printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n", t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng); - // hi tok if (dc_tok == 3) { -#define dbg_print_hi_tok(dc_tok, tok_br) \ - if (dbg) \ - printf("Post-dc_hi_tok[%d][%d][%d][%d->%d]: r=%d\n", \ - imin(t_dim->ctx, 3), chroma, br_ctx, tok_br, dc_tok, ts->msac.rng); - const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride); - - int tok_br = - dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[br_ctx], 3); - dc_tok = 3 + tok_br; - - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], 3); - dc_tok = 6 + tok_br; - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], - 3); - dc_tok = 9 + tok_br; - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - if (tok_br == 3) { - dc_tok = 12 + - dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[br_ctx], - 3); - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - } - } - } + dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]); + if (dbg) + printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n", + imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng); } -#undef dbg_print_hi_tok } } else { // dc-only uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0]; @@ -338,38 +258,13 @@ static int decode_coefs(Dav1dTileContext *const t, printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n", t_dim->ctx, chroma, 0, dc_tok, ts->msac.rng); - // hi tok if (tok_br == 2) { -#define dbg_print_hi_tok(dc_tok, tok_br) \ - if (dbg) \ - printf("Post-dc_hi_tok[%d][%d][0][%d->%d]: r=%d\n", \ - imin(t_dim->ctx, 3), chroma, tok_br, dc_tok, 
ts->msac.rng); - - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3); - dc_tok = 3 + tok_br; - - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3); - dc_tok = 6 + tok_br; - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - if (tok_br == 3) { - tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[0], 3); - dc_tok = 9 + tok_br; - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - if (tok_br == 3) { - dc_tok = 12 + - dav1d_msac_decode_symbol_adapt4(&ts->msac, - br_cdf[0], 3); - dbg_print_hi_tok(dc_tok + tok_br, tok_br); - } - } - } + dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[0]); + if (dbg) + printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n", + imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng); } } -#undef dbg_print_hi_tok // residual and sign int dc_sign = 1 << 6; diff --git a/src/x86/msac.asm b/src/x86/msac.asm index b896f74..e974ba0 100644 --- a/src/x86/msac.asm +++ b/src/x86/msac.asm @@ -27,7 +27,7 @@ SECTION_RODATA 64 ; avoids cacheline splits -dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0 +min_prob: dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0 pw_0xff00: times 8 dw 0xff00 pw_32: times 8 dw 32 @@ -35,21 +35,24 @@ pw_32: times 8 dw 32 %define resp resq %define movp movq %define c_shuf q3333 -%define DECODE_SYMBOL_ADAPT_INIT +%macro DECODE_SYMBOL_ADAPT_INIT 0-1 +%endmacro %else %define resp resd %define movp movd %define c_shuf q1111 -%macro DECODE_SYMBOL_ADAPT_INIT 0 +%macro DECODE_SYMBOL_ADAPT_INIT 0-1 0 ; hi_tok mov t0, r0m mov t1, r1m +%if %1 == 0 mov t2, r2m +%endif %if STACK_ALIGNMENT >= 16 - sub esp, 40 + sub esp, 40-%1*4 %else mov eax, esp and esp, ~15 - sub esp, 40 + sub esp, 40-%1*4 mov [esp], eax %endif %endmacro @@ -69,13 +72,13 @@ endstruc SECTION .text %if WIN64 -DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3 -%define buf rsp+8 ; shadow space +DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3, 8 +%define buf rsp+stack_offset+8 ; 
shadow space %elif UNIX64 -DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0 +DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0, 8 %define buf rsp-40 ; red zone %else -DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2 +DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2, 3 %define buf esp+8 %endif @@ -440,3 +443,158 @@ cglobal msac_decode_bool, 0, 6, 0 movzx eax, al %endif jmp m(msac_decode_symbol_adapt4).renorm3 + +%macro HI_TOK 1 ; update_cdf +%if ARCH_X86_64 == 0 + mov eax, -24 +%endif +%%loop: +%if %1 + movzx t2d, word [t1+3*2] +%endif + mova m1, m0 + pshuflw m2, m2, q0000 + psrlw m1, 6 + movd [buf+12], m2 + pand m2, m4 + psllw m1, 7 + pmulhuw m1, m2 +%if ARCH_X86_64 == 0 + add eax, 5 + mov [buf+8], eax +%endif + pshuflw m3, m3, c_shuf + paddw m1, m5 + movq [buf+16], m1 + psubusw m1, m3 + pxor m2, m2 + pcmpeqw m1, m2 + pmovmskb eax, m1 +%if %1 + lea ecx, [t2+80] + pcmpeqw m2, m2 + shr ecx, 4 + cmp t2d, 32 + adc t2d, 0 + movd m3, ecx + pavgw m2, m1 + psubw m2, m0 + psubw m0, m1 + psraw m2, m3 + paddw m0, m2 + movq [t1], m0 + mov [t1+3*2], t2w +%endif + tzcnt eax, eax + movzx ecx, word [buf+rax+16] + movzx t2d, word [buf+rax+14] + not t4 +%if ARCH_X86_64 + add t6d, 5 +%endif + sub eax, 5 ; setup for merging the tok_br and tok branches + sub t2d, ecx + shl rcx, gprsize*8-16 + add t4, rcx + bsr ecx, t2d + xor ecx, 15 + shl t2d, cl + shl t4, cl + movd m2, t2d + mov [t7+msac.rng], t2d + not t4 + sub t5d, ecx + jge %%end + mov t2, [t7+msac.buf] + mov rcx, [t7+msac.end] +%if UNIX64 == 0 + push t8 +%endif + lea t8, [t2+gprsize] + cmp t8, rcx + ja %%refill_eob + mov t2, [t2] + lea ecx, [t5+23] + add t5d, 16 + shr ecx, 3 + bswap t2 + sub t8, rcx + shl ecx, 3 + shr t2, cl + sub ecx, t5d + mov t5d, gprsize*8-16 + shl t2, cl + mov [t7+msac.buf], t8 +%if UNIX64 == 0 + pop t8 +%endif + sub t5d, ecx + xor t4, t2 +%%end: + movp m3, t4 +%if ARCH_X86_64 + add t6d, eax ; CF = tok_br < 3 || tok == 15 + jnc %%loop + lea eax, [t6+30] +%else + add eax, [buf+8] + jnc %%loop + add eax, 30 +%if STACK_ALIGNMENT >= 16 + add esp, 36 
+%else + mov esp, [esp] +%endif +%endif + mov [t7+msac.dif], t4 + shr eax, 1 + mov [t7+msac.cnt], t5d + RET +%%refill_eob: + mov t8, rcx + mov ecx, gprsize*8-24 + sub ecx, t5d +%%refill_eob_loop: + cmp t2, t8 + jae %%refill_eob_end + movzx t5d, byte [t2] + inc t2 + shl t5, cl + xor t4, t5 + sub ecx, 8 + jge %%refill_eob_loop +%%refill_eob_end: +%if UNIX64 == 0 + pop t8 +%endif + mov t5d, gprsize*8-24 + mov [t7+msac.buf], t2 + sub t5d, ecx + jmp %%end +%endmacro + +cglobal msac_decode_hi_tok, 0, 7 + ARCH_X86_64, 6 + DECODE_SYMBOL_ADAPT_INIT 1 +%if ARCH_X86_64 == 0 && PIC + LEA t2, min_prob+12*2 + %define base t2-(min_prob+12*2) +%else + %define base 0 +%endif + movq m0, [t1] + movd m2, [t0+msac.rng] + mov eax, [t0+msac.update_cdf] + movq m4, [base+pw_0xff00] + movp m3, [t0+msac.dif] + movq m5, [base+min_prob+12*2] + mov t4, [t0+msac.dif] + mov t5d, [t0+msac.cnt] +%if ARCH_X86_64 + mov t6d, -24 +%endif + movifnidn t7, t0 + test eax, eax + jz .no_update_cdf + HI_TOK 1 +.no_update_cdf: + HI_TOK 0 diff --git a/src/x86/msac.h b/src/x86/msac.h index 3d8b76f..48954c1 100644 --- a/src/x86/msac.h +++ b/src/x86/msac.h @@ -37,11 +37,13 @@ unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf, unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf); unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s); unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f); +unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf); #if ARCH_X86_64 || defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2) #define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2 #define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2 +#define dav1d_msac_decode_hi_tok dav1d_msac_decode_hi_tok_sse2 #endif #define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_sse2 diff --git a/tests/checkasm/msac.c 
b/tests/checkasm/msac.c index 3b75055..c1681d7 100644 --- a/tests/checkasm/msac.c +++ b/tests/checkasm/msac.c @@ -38,7 +38,7 @@ /* The normal code doesn't use function pointers */ typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf, size_t n_symbols); -typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf); +typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf); typedef unsigned (*decode_bool_equi_fn)(MsacContext *s); typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f); @@ -46,9 +46,10 @@ typedef struct { decode_symbol_adapt_fn symbol_adapt4; decode_symbol_adapt_fn symbol_adapt8; decode_symbol_adapt_fn symbol_adapt16; - decode_bool_adapt_fn bool_adapt; + decode_adapt_fn bool_adapt; decode_bool_equi_fn bool_equi; decode_bool_fn bool; + decode_adapt_fn hi_tok; } MsacDSPContext; static void randomize_cdf(uint16_t *const cdf, const int n) { @@ -199,6 +200,35 @@ static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) { report("decode_bool"); } +static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) { + ALIGN_STK_16(uint16_t, cdf, 2, [16]); + MsacContext s_c, s_a; + + if (check_func(c->hi_tok, "msac_decode_hi_tok")) { + declare_func(unsigned, MsacContext *s, uint16_t *cdf); + for (int cdf_update = 0; cdf_update <= 1; cdf_update++) { + dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update); + s_a = s_c; + randomize_cdf(cdf[0], 3); + memcpy(cdf[1], cdf[0], sizeof(*cdf)); + for (int i = 0; i < 64; i++) { + unsigned c_res = call_ref(&s_c, cdf[0]); + unsigned a_res = call_new(&s_a, cdf[1]); + if (c_res != a_res || msac_cmp(&s_c, &s_a) || + memcmp(cdf[0], cdf[1], sizeof(*cdf))) + { + if (fail()) + msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 3); + break; + } + } + if (cdf_update) + bench_new(&s_a, cdf[1]); + } + } + report("decode_hi_tok"); +} + void checkasm_check_msac(void) { MsacDSPContext c; c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c; @@ -207,6 +237,7 @@ void 
checkasm_check_msac(void) { c.bool_adapt = dav1d_msac_decode_bool_adapt_c; c.bool_equi = dav1d_msac_decode_bool_equi_c; c.bool = dav1d_msac_decode_bool_c; + c.hi_tok = dav1d_msac_decode_hi_tok_c; #if ARCH_AARCH64 && HAVE_ASM if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) { @@ -225,6 +256,7 @@ void checkasm_check_msac(void) { c.bool_adapt = dav1d_msac_decode_bool_adapt_sse2; c.bool_equi = dav1d_msac_decode_bool_equi_sse2; c.bool = dav1d_msac_decode_bool_sse2; + c.hi_tok = dav1d_msac_decode_hi_tok_sse2; } #endif @@ -234,4 +266,5 @@ void checkasm_check_msac(void) { check_decode_symbol(&c, buf); check_decode_bool(&c, buf); + check_decode_hi_tok(&c, buf); } |