Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/videolan/dav1d.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenrik Gramner <gramner@twoorioles.com>2019-04-12 00:20:18 +0300
committerHenrik Gramner <henrik@gramner.com>2019-04-15 22:09:27 +0300
commitfa1b265142e1409a986f01bd7abe115b308c1028 (patch)
tree81befc8a99d235c98f9dce2a63cee0dee676591f
parent44d0de41d478b6b41a1ebbf1de012caa8d75cca0 (diff)
x86-64: Add msac_decode_symbol_adapt SSE2 asm
Also make various minor optimizations/style fixes to the MSAC C functions.
-rw-r--r--src/cdf.c2
-rw-r--r--src/cdf.h20
-rw-r--r--src/decode.c68
-rw-r--r--src/meson.build1
-rw-r--r--src/msac.c90
-rw-r--r--src/msac.h33
-rw-r--r--src/recon_tmpl.c28
-rw-r--r--src/x86/msac.asm287
-rw-r--r--tests/checkasm/checkasm.c1
-rw-r--r--tests/checkasm/checkasm.h1
-rw-r--r--tests/checkasm/msac.c115
-rw-r--r--tests/meson.build5
12 files changed, 536 insertions, 115 deletions
diff --git a/src/cdf.c b/src/cdf.c
index c608619..9aff5d7 100644
--- a/src/cdf.c
+++ b/src/cdf.c
@@ -813,7 +813,7 @@ static const uint16_t default_mv_joint_cdf[N_MV_JOINTS + 1] = {
AOM_CDF4(4096, 11264, 19328)
};
-static const uint16_t default_kf_y_mode_cdf[5][5][N_INTRA_PRED_MODES + 1] = {
+static const uint16_t default_kf_y_mode_cdf[5][5][N_INTRA_PRED_MODES + 1 + 2] = {
{
{ AOM_CDF13(15588, 17027, 19338, 20218, 20682, 21110, 21825, 23244,
24189, 28165, 29093, 30466) },
diff --git a/src/cdf.h b/src/cdf.h
index 6d95771..7512408 100644
--- a/src/cdf.h
+++ b/src/cdf.h
@@ -34,11 +34,13 @@
#include "src/ref.h"
#include "src/thread_data.h"
+/* Buffers padded to [8] or [16] for SIMD where needed. */
+
typedef struct CdfModeContext {
- uint16_t y_mode[4][N_INTRA_PRED_MODES + 1];
+ uint16_t y_mode[4][N_INTRA_PRED_MODES + 1 + 2];
uint16_t use_filter_intra[N_BS_SIZES][2];
uint16_t filter_intra[5 + 1];
- uint16_t uv_mode[2][N_INTRA_PRED_MODES][N_UV_INTRA_PRED_MODES + 1];
+ uint16_t uv_mode[2][N_INTRA_PRED_MODES][N_UV_INTRA_PRED_MODES + 1 + 1];
uint16_t angle_delta[8][8];
uint16_t filter[2][8][DAV1D_N_SWITCHABLE_FILTERS + 1];
uint16_t newmv_mode[6][2];
@@ -66,7 +68,7 @@ typedef struct CdfModeContext {
uint16_t txtp_intra[3][N_TX_SIZES][N_INTRA_PRED_MODES][N_TX_TYPES + 1];
uint16_t skip[3][2];
uint16_t skip_mode[3][2];
- uint16_t partition[N_BL_LEVELS][4][N_PARTITIONS + 1];
+ uint16_t partition[N_BL_LEVELS][4][N_PARTITIONS + 1 + 5];
uint16_t seg_pred[3][2];
uint16_t seg_id[3][DAV1D_MAX_SEGMENTS + 1];
uint16_t cfl_sign[8 + 1];
@@ -88,12 +90,12 @@ typedef struct CdfModeContext {
typedef struct CdfCoefContext {
uint16_t skip[N_TX_SIZES][13][2];
uint16_t eob_bin_16[2][2][6];
- uint16_t eob_bin_32[2][2][7];
+ uint16_t eob_bin_32[2][2][7 + 1];
uint16_t eob_bin_64[2][2][8];
uint16_t eob_bin_128[2][2][9];
- uint16_t eob_bin_256[2][2][10];
- uint16_t eob_bin_512[2][2][11];
- uint16_t eob_bin_1024[2][2][12];
+ uint16_t eob_bin_256[2][2][10 + 6];
+ uint16_t eob_bin_512[2][2][11 + 5];
+ uint16_t eob_bin_1024[2][2][12 + 4];
uint16_t eob_hi_bit[N_TX_SIZES][2][11 /*22*/][2];
uint16_t eob_base_tok[N_TX_SIZES][2][4][4];
uint16_t base_tok[N_TX_SIZES][2][41][5];
@@ -102,7 +104,7 @@ typedef struct CdfCoefContext {
} CdfCoefContext;
typedef struct CdfMvComponent {
- uint16_t classes[11 + 1];
+ uint16_t classes[11 + 1 + 4];
uint16_t class0[2];
uint16_t classN[10][2];
uint16_t class0_fp[2][4 + 1];
@@ -119,7 +121,7 @@ typedef struct CdfMvContext {
typedef struct CdfContext {
CdfModeContext m;
- uint16_t kfym[5][5][N_INTRA_PRED_MODES + 1];
+ uint16_t kfym[5][5][N_INTRA_PRED_MODES + 1 + 2];
CdfCoefContext coef;
CdfMvContext mv, dmv;
} CdfContext;
diff --git a/src/decode.c b/src/decode.c
index 25deffc..bad9b83 100644
--- a/src/decode.c
+++ b/src/decode.c
@@ -80,15 +80,15 @@ static int read_mv_component_diff(Dav1dTileContext *const t,
const Dav1dFrameContext *const f = t->f;
const int have_hp = f->frame_hdr->hp;
const int sign = dav1d_msac_decode_bool_adapt(&ts->msac, mv_comp->sign);
- const int cl = dav1d_msac_decode_symbol_adapt(&ts->msac,
- mv_comp->classes, 11);
+ const int cl = dav1d_msac_decode_symbol_adapt16(&ts->msac,
+ mv_comp->classes, 11);
int up, fp, hp;
if (!cl) {
up = dav1d_msac_decode_bool_adapt(&ts->msac, mv_comp->class0);
if (have_fp) {
- fp = dav1d_msac_decode_symbol_adapt(&ts->msac,
- mv_comp->class0_fp[up], 4);
+ fp = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+ mv_comp->class0_fp[up], 4);
hp = have_hp ? dav1d_msac_decode_bool_adapt(&ts->msac,
mv_comp->class0_hp) : 1;
} else {
@@ -101,8 +101,8 @@ static int read_mv_component_diff(Dav1dTileContext *const t,
up |= dav1d_msac_decode_bool_adapt(&ts->msac,
mv_comp->classN[n]) << n;
if (have_fp) {
- fp = dav1d_msac_decode_symbol_adapt(&ts->msac,
- mv_comp->classN_fp, 4);
+ fp = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+ mv_comp->classN_fp, 4);
hp = have_hp ? dav1d_msac_decode_bool_adapt(&ts->msac,
mv_comp->classN_hp) : 1;
} else {
@@ -119,8 +119,8 @@ static int read_mv_component_diff(Dav1dTileContext *const t,
static void read_mv_residual(Dav1dTileContext *const t, mv *const ref_mv,
CdfMvContext *const mv_cdf, const int have_fp)
{
- switch (dav1d_msac_decode_symbol_adapt(&t->ts->msac, t->ts->cdf.mv.joint,
- N_MV_JOINTS))
+ switch (dav1d_msac_decode_symbol_adapt4(&t->ts->msac, t->ts->cdf.mv.joint,
+ N_MV_JOINTS))
{
case MV_JOINT_HV:
ref_mv->y += read_mv_component_diff(t, &mv_cdf->comp[0], have_fp);
@@ -379,7 +379,7 @@ static void read_pal_plane(Dav1dTileContext *const t, Av1Block *const b,
{
Dav1dTileState *const ts = t->ts;
const Dav1dFrameContext *const f = t->f;
- const int pal_sz = b->pal_sz[pl] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const int pal_sz = b->pal_sz[pl] = dav1d_msac_decode_symbol_adapt8(&ts->msac,
ts->cdf.m.pal_sz[pl][sz_ctx], 7) + 2;
uint16_t cache[16], used_cache[8];
int l_cache = pl ? t->pal_sz_uv[1][by4] : t->l.pal_sz[by4];
@@ -595,7 +595,7 @@ static void read_pal_indices(Dav1dTileContext *const t,
const int last = imax(0, i - h4 * 4 + 1);
order_palette(pal_idx, stride, i, first, last, order, ctx);
for (int j = first, m = 0; j >= last; j--, m++) {
- const int color_idx = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const int color_idx = dav1d_msac_decode_symbol_adapt8(&ts->msac,
color_map_cdf[ctx[m]], b->pal_sz[pl]);
pal_idx[(i - j) * stride + j] = order[m][color_idx];
}
@@ -811,7 +811,7 @@ static int decode_b(Dav1dTileContext *const t,
const unsigned pred_seg_id =
get_cur_frame_segid(t->by, t->bx, have_top, have_left,
&seg_ctx, f->cur_segmap, f->b4_stride);
- const unsigned diff = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const unsigned diff = dav1d_msac_decode_symbol_adapt8(&ts->msac,
ts->cdf.m.seg_id[seg_ctx],
DAV1D_MAX_SEGMENTS);
const unsigned last_active_seg_id =
@@ -883,7 +883,7 @@ static int decode_b(Dav1dTileContext *const t,
if (b->skip) {
b->seg_id = pred_seg_id;
} else {
- const unsigned diff = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const unsigned diff = dav1d_msac_decode_symbol_adapt8(&ts->msac,
ts->cdf.m.seg_id[seg_ctx],
DAV1D_MAX_SEGMENTS);
const unsigned last_active_seg_id =
@@ -932,8 +932,8 @@ static int decode_b(Dav1dTileContext *const t,
memcpy(prev_delta_lf, ts->last_delta_lf, 4);
if (have_delta_q) {
- int delta_q = dav1d_msac_decode_symbol_adapt(&ts->msac,
- ts->cdf.m.delta_q, 4);
+ int delta_q = dav1d_msac_decode_symbol_adapt4(&ts->msac,
+ ts->cdf.m.delta_q, 4);
if (delta_q == 3) {
const int n_bits = 1 + dav1d_msac_decode_bools(&ts->msac, 3);
delta_q = dav1d_msac_decode_bools(&ts->msac, n_bits) +
@@ -953,7 +953,7 @@ static int decode_b(Dav1dTileContext *const t,
f->cur.p.layout != DAV1D_PIXEL_LAYOUT_I400 ? 4 : 2 : 1;
for (int i = 0; i < n_lfs; i++) {
- int delta_lf = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ int delta_lf = dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.delta_lf[i + f->frame_hdr->delta.lf.multi], 4);
if (delta_lf == 3) {
const int n_bits = 1 + dav1d_msac_decode_bools(&ts->msac, 3);
@@ -1018,8 +1018,8 @@ static int decode_b(Dav1dTileContext *const t,
ts->cdf.m.y_mode[dav1d_ymode_size_context[bs]] :
ts->cdf.kfym[dav1d_intra_mode_context[t->a->mode[bx4]]]
[dav1d_intra_mode_context[t->l.mode[by4]]];
- b->y_mode = dav1d_msac_decode_symbol_adapt(&ts->msac, ymode_cdf,
- N_INTRA_PRED_MODES);
+ b->y_mode = dav1d_msac_decode_symbol_adapt16(&ts->msac, ymode_cdf,
+ N_INTRA_PRED_MODES);
if (DEBUG_BLOCK_INFO)
printf("Post-ymode[%d]: r=%d\n", b->y_mode, ts->msac.rng);
@@ -1028,7 +1028,7 @@ static int decode_b(Dav1dTileContext *const t,
b->y_mode <= VERT_LEFT_PRED)
{
uint16_t *const acdf = ts->cdf.m.angle_delta[b->y_mode - VERT_PRED];
- const int angle = dav1d_msac_decode_symbol_adapt(&ts->msac, acdf, 7);
+ const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 7);
b->y_angle = angle - 3;
} else {
b->y_angle = 0;
@@ -1038,20 +1038,20 @@ static int decode_b(Dav1dTileContext *const t,
const int cfl_allowed = f->frame_hdr->segmentation.lossless[b->seg_id] ?
cbw4 == 1 && cbh4 == 1 : !!(cfl_allowed_mask & (1 << bs));
uint16_t *const uvmode_cdf = ts->cdf.m.uv_mode[cfl_allowed][b->y_mode];
- b->uv_mode = dav1d_msac_decode_symbol_adapt(&ts->msac, uvmode_cdf,
+ b->uv_mode = dav1d_msac_decode_symbol_adapt16(&ts->msac, uvmode_cdf,
N_UV_INTRA_PRED_MODES - !cfl_allowed);
if (DEBUG_BLOCK_INFO)
printf("Post-uvmode[%d]: r=%d\n", b->uv_mode, ts->msac.rng);
if (b->uv_mode == CFL_PRED) {
#define SIGN(a) (!!(a) + ((a) > 0))
- const int sign = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const int sign = dav1d_msac_decode_symbol_adapt8(&ts->msac,
ts->cdf.m.cfl_sign, 8) + 1;
const int sign_u = sign * 0x56 >> 8, sign_v = sign - sign_u * 3;
assert(sign_u == sign / 3);
if (sign_u) {
const int ctx = (sign_u == 2) * 3 + sign_v;
- b->cfl_alpha[0] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->cfl_alpha[0] = dav1d_msac_decode_symbol_adapt16(&ts->msac,
ts->cdf.m.cfl_alpha[ctx], 16) + 1;
if (sign_u == 1) b->cfl_alpha[0] = -b->cfl_alpha[0];
} else {
@@ -1059,7 +1059,7 @@ static int decode_b(Dav1dTileContext *const t,
}
if (sign_v) {
const int ctx = (sign_v == 2) * 3 + sign_u;
- b->cfl_alpha[1] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->cfl_alpha[1] = dav1d_msac_decode_symbol_adapt16(&ts->msac,
ts->cdf.m.cfl_alpha[ctx], 16) + 1;
if (sign_v == 1) b->cfl_alpha[1] = -b->cfl_alpha[1];
} else {
@@ -1073,7 +1073,7 @@ static int decode_b(Dav1dTileContext *const t,
b->uv_mode <= VERT_LEFT_PRED)
{
uint16_t *const acdf = ts->cdf.m.angle_delta[b->uv_mode - VERT_PRED];
- const int angle = dav1d_msac_decode_symbol_adapt(&ts->msac, acdf, 7);
+ const int angle = dav1d_msac_decode_symbol_adapt8(&ts->msac, acdf, 7);
b->uv_angle = angle - 3;
} else {
b->uv_angle = 0;
@@ -1113,7 +1113,7 @@ static int decode_b(Dav1dTileContext *const t,
ts->cdf.m.use_filter_intra[bs]);
if (is_filter) {
b->y_mode = FILTER_PRED;
- b->y_angle = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->y_angle = dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.filter_intra, 5);
}
if (DEBUG_BLOCK_INFO)
@@ -1156,7 +1156,7 @@ static int decode_b(Dav1dTileContext *const t,
if (f->frame_hdr->txfm_mode == DAV1D_TX_SWITCHABLE && t_dim->max > TX_4X4) {
const int tctx = get_tx_ctx(t->a, &t->l, t_dim, by4, bx4);
uint16_t *const tx_cdf = ts->cdf.m.txsz[t_dim->max - 1][tctx];
- int depth = dav1d_msac_decode_symbol_adapt(&ts->msac, tx_cdf,
+ int depth = dav1d_msac_decode_symbol_adapt4(&ts->msac, tx_cdf,
imin(t_dim->max + 1, 3));
while (depth--) {
@@ -1474,7 +1474,7 @@ static int decode_b(Dav1dTileContext *const t,
ts->tiling.col_end, ts->tiling.row_start,
ts->tiling.row_end, f->libaom_cm);
- b->inter_mode = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->inter_mode = dav1d_msac_decode_symbol_adapt8(&ts->msac,
ts->cdf.m.comp_inter_mode[ctx],
N_COMP_INTER_PRED_MODES);
if (DEBUG_BLOCK_INFO)
@@ -1583,7 +1583,7 @@ static int decode_b(Dav1dTileContext *const t,
dav1d_msac_decode_bool_adapt(&ts->msac,
ts->cdf.m.wedge_comp[ctx]);
if (b->comp_type == COMP_INTER_WEDGE)
- b->wedge_idx = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->wedge_idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
ts->cdf.m.wedge_idx[ctx], 16);
} else {
b->comp_type = COMP_INTER_SEG;
@@ -1737,7 +1737,7 @@ static int decode_b(Dav1dTileContext *const t,
dav1d_msac_decode_bool_adapt(&ts->msac,
ts->cdf.m.interintra[ii_sz_grp]))
{
- b->interintra_mode = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->interintra_mode = dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.interintra_mode[ii_sz_grp],
N_INTER_INTRA_PRED_MODES);
const int wedge_ctx = dav1d_wedge_ctx_lut[bs];
@@ -1745,7 +1745,7 @@ static int decode_b(Dav1dTileContext *const t,
dav1d_msac_decode_bool_adapt(&ts->msac,
ts->cdf.m.interintra_wedge[wedge_ctx]);
if (b->interintra_type == INTER_INTRA_WEDGE)
- b->wedge_idx = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ b->wedge_idx = dav1d_msac_decode_symbol_adapt16(&ts->msac,
ts->cdf.m.wedge_idx[wedge_ctx], 16);
} else {
b->interintra_type = INTER_INTRA_NONE;
@@ -1778,7 +1778,7 @@ static int decode_b(Dav1dTileContext *const t,
f->frame_hdr->warp_motion && (mask[0] | mask[1]);
b->motion_mode = allow_warp ?
- dav1d_msac_decode_symbol_adapt(&ts->msac,
+ dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.motion_mode[bs], 3) :
dav1d_msac_decode_bool_adapt(&ts->msac, ts->cdf.m.obmc[bs]);
if (b->motion_mode == MM_WARP) {
@@ -1817,7 +1817,7 @@ static int decode_b(Dav1dTileContext *const t,
const int comp = b->comp_type != COMP_INTER_NONE;
const int ctx1 = get_filter_ctx(t->a, &t->l, comp, 0, b->ref[0],
by4, bx4);
- filter[0] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ filter[0] = dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.filter[0][ctx1],
DAV1D_N_SWITCHABLE_FILTERS);
if (f->seq_hdr->dual_filter) {
@@ -1826,7 +1826,7 @@ static int decode_b(Dav1dTileContext *const t,
if (DEBUG_BLOCK_INFO)
printf("Post-subpel_filter1[%d,ctx=%d]: r=%d\n",
filter[0], ctx1, ts->msac.rng);
- filter[1] = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ filter[1] = dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.filter[1][ctx2],
DAV1D_N_SWITCHABLE_FILTERS);
if (DEBUG_BLOCK_INFO)
@@ -2021,7 +2021,7 @@ static int decode_sb(Dav1dTileContext *const t, const enum BlockLevel bl,
} else {
const unsigned n_part = bl == BL_8X8 ? N_SUB8X8_PARTITIONS :
bl == BL_128X128 ? N_PARTITIONS - 2 : N_PARTITIONS;
- bp = dav1d_msac_decode_symbol_adapt(&t->ts->msac, pc, n_part);
+ bp = dav1d_msac_decode_symbol_adapt16(&t->ts->msac, pc, n_part);
if (f->cur.p.layout == DAV1D_PIXEL_LAYOUT_I422 &&
(bp == PARTITION_V || bp == PARTITION_V4 ||
bp == PARTITION_T_LEFT_SPLIT || bp == PARTITION_T_RIGHT_SPLIT))
@@ -2365,7 +2365,7 @@ static void read_restoration_info(Dav1dTileContext *const t,
Dav1dTileState *const ts = t->ts;
if (frame_type == DAV1D_RESTORATION_SWITCHABLE) {
- const int filter = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const int filter = dav1d_msac_decode_symbol_adapt4(&ts->msac,
ts->cdf.m.restore_switchable, 3);
lr->type = filter ? filter == 2 ? DAV1D_RESTORATION_SGRPROJ :
DAV1D_RESTORATION_WIENER :
diff --git a/src/meson.build b/src/meson.build
index 38ab0f1..acd01ba 100644
--- a/src/meson.build
+++ b/src/meson.build
@@ -119,6 +119,7 @@ if is_asm_enabled
# NASM source files
libdav1d_sources_asm = files(
'x86/cpuid.asm',
+ 'x86/msac.asm',
)
if dav1d_bitdepths.contains('8')
diff --git a/src/msac.c b/src/msac.c
index 9e6d32b..104d2d6 100644
--- a/src/msac.c
+++ b/src/msac.c
@@ -58,8 +58,8 @@ static inline void ctx_refill(MsacContext *s) {
* necessary), and stores them back in the decoder context.
* dif: The new value of dif.
* rng: The new value of the range. */
-static inline void ctx_norm(MsacContext *s, ec_win dif, uint32_t rng) {
- const uint16_t d = 15 - (31 ^ clz(rng));
+static inline void ctx_norm(MsacContext *s, ec_win dif, unsigned rng) {
+ const int d = 15 ^ (31 ^ clz(rng));
assert(rng <= 65535U);
s->cnt -= d;
s->dif = ((dif + 1) << d) - 1; /* Shift in 1s in the LSBs */
@@ -69,18 +69,17 @@ static inline void ctx_norm(MsacContext *s, ec_win dif, uint32_t rng) {
}
unsigned dav1d_msac_decode_bool_equi(MsacContext *const s) {
- ec_win v, vw, dif = s->dif;
- uint16_t r = s->rng;
- unsigned ret;
+ ec_win vw, dif = s->dif;
+ unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r);
// When the probability is 1/2, f = 16384 >> EC_PROB_SHIFT = 256 and we can
// replace the multiply with a simple shift.
v = ((r >> 8) << 7) + EC_MIN_PROB;
- vw = v << (EC_WIN_SIZE - 16);
+ vw = (ec_win)v << (EC_WIN_SIZE - 16);
ret = dif >= vw;
dif -= ret*vw;
v += ret*(r - 2*v);
- ctx_norm(s, dif, (unsigned) v);
+ ctx_norm(s, dif, v);
return !ret;
}
@@ -88,59 +87,57 @@ unsigned dav1d_msac_decode_bool_equi(MsacContext *const s) {
* f: The probability that the bit is one
* Return: The value decoded (0 or 1). */
unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) {
- ec_win v, vw, dif = s->dif;
- uint16_t r = s->rng;
- unsigned ret;
+ ec_win vw, dif = s->dif;
+ unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r);
v = ((r >> 8) * (f >> EC_PROB_SHIFT) >> (7 - EC_PROB_SHIFT)) + EC_MIN_PROB;
- vw = v << (EC_WIN_SIZE - 16);
+ vw = (ec_win)v << (EC_WIN_SIZE - 16);
ret = dif >= vw;
dif -= ret*vw;
v += ret*(r - 2*v);
- ctx_norm(s, dif, (unsigned) v);
+ ctx_norm(s, dif, v);
return !ret;
}
-unsigned dav1d_msac_decode_bools(MsacContext *const c, const unsigned l) {
- int v = 0;
- for (int n = (int) l - 1; n >= 0; n--)
- v = (v << 1) | dav1d_msac_decode_bool_equi(c);
+unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
+ unsigned v = 0;
+ while (n--)
+ v = (v << 1) | dav1d_msac_decode_bool_equi(s);
return v;
}
-int dav1d_msac_decode_subexp(MsacContext *const c, const int ref,
+int dav1d_msac_decode_subexp(MsacContext *const s, const int ref,
const int n, const unsigned k)
{
int i = 0;
int a = 0;
int b = k;
while ((2 << b) < n) {
- if (!dav1d_msac_decode_bool_equi(c)) break;
+ if (!dav1d_msac_decode_bool_equi(s)) break;
b = k + i++;
a = (1 << b);
}
- const unsigned v = dav1d_msac_decode_bools(c, b) + a;
+ const unsigned v = dav1d_msac_decode_bools(s, b) + a;
return ref * 2 <= n ? inv_recenter(ref, v) :
n - 1 - inv_recenter(n - 1 - ref, v);
}
-int dav1d_msac_decode_uniform(MsacContext *const c, const unsigned n) {
+int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
assert(n > 0);
const int l = ulog2(n) + 1;
assert(l > 1);
const unsigned m = (1 << l) - n;
- const unsigned v = dav1d_msac_decode_bools(c, l - 1);
- return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(c);
+ const unsigned v = dav1d_msac_decode_bools(s, l - 1);
+ return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s);
}
/* Decodes a symbol given an inverse cumulative distribution function (CDF)
* table in Q15. */
static unsigned decode_symbol(MsacContext *const s, const uint16_t *const cdf,
- const unsigned n_symbols)
+ const size_t n_symbols)
{
- ec_win u, v = s->rng, r = s->rng >> 8;
- const ec_win c = s->dif >> (EC_WIN_SIZE - 16);
- unsigned ret = 0;
+ const unsigned c = s->dif >> (EC_WIN_SIZE - 16);
+ unsigned u, v = s->rng, r = s->rng >> 8, ret = 0;
assert(!cdf[n_symbols - 1]);
@@ -153,39 +150,34 @@ static unsigned decode_symbol(MsacContext *const s, const uint16_t *const cdf,
assert(u <= s->rng);
- ctx_norm(s, s->dif - (v << (EC_WIN_SIZE - 16)), (unsigned) (u - v));
+ ctx_norm(s, s->dif - ((ec_win)v << (EC_WIN_SIZE - 16)), u - v);
return ret - 1;
}
-static void update_cdf(uint16_t *const cdf, const unsigned val,
- const unsigned n_symbols)
+unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s,
+ uint16_t *const cdf,
+ const size_t n_symbols)
{
- const unsigned count = cdf[n_symbols];
- const int rate = ((count >> 4) | 4) + (n_symbols > 3);
- unsigned i;
- for (i = 0; i < val; i++)
- cdf[i] += (32768 - cdf[i]) >> rate;
- for (; i < n_symbols - 1; i++)
- cdf[i] -= cdf[i] >> rate;
- cdf[n_symbols] = count + (count < 32);
-}
-
-unsigned dav1d_msac_decode_symbol_adapt(MsacContext *const c,
- uint16_t *const cdf,
- const unsigned n_symbols)
-{
- const unsigned val = decode_symbol(c, cdf, n_symbols);
- if(c->allow_update_cdf)
- update_cdf(cdf, val, n_symbols);
+ const unsigned val = decode_symbol(s, cdf, n_symbols);
+ if (s->allow_update_cdf) {
+ const unsigned count = cdf[n_symbols];
+ const int rate = ((count >> 4) | 4) + (n_symbols > 3);
+ unsigned i;
+ for (i = 0; i < val; i++)
+ cdf[i] += (32768 - cdf[i]) >> rate;
+ for (; i < n_symbols - 1; i++)
+ cdf[i] -= cdf[i] >> rate;
+ cdf[n_symbols] = count + (count < 32);
+ }
return val;
}
-unsigned dav1d_msac_decode_bool_adapt(MsacContext *const c,
+unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
uint16_t *const cdf)
{
- const unsigned bit = dav1d_msac_decode_bool(c, *cdf);
+ const unsigned bit = dav1d_msac_decode_bool(s, *cdf);
- if(c->allow_update_cdf){
+ if (s->allow_update_cdf) {
// update_cdf() specialized for boolean CDFs
const unsigned count = cdf[1];
const int rate = (count >> 4) | 4;
diff --git a/src/msac.h b/src/msac.h
index 91556fc..244d86f 100644
--- a/src/msac.h
+++ b/src/msac.h
@@ -38,20 +38,37 @@ typedef struct MsacContext {
const uint8_t *buf_pos;
const uint8_t *buf_end;
ec_win dif;
- uint16_t rng;
+ unsigned rng;
int cnt;
int allow_update_cdf;
} MsacContext;
-void dav1d_msac_init(MsacContext *c, const uint8_t *data, size_t sz,
+void dav1d_msac_init(MsacContext *s, const uint8_t *data, size_t sz,
int disable_cdf_update_flag);
-unsigned dav1d_msac_decode_symbol_adapt(MsacContext *s, uint16_t *cdf,
- const unsigned n_symbols);
-unsigned dav1d_msac_decode_bool_equi(MsacContext *const s);
+unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+unsigned dav1d_msac_decode_bool_equi(MsacContext *s);
unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
-unsigned dav1d_msac_decode_bools(MsacContext *c, unsigned l);
-int dav1d_msac_decode_subexp(MsacContext *c, int ref, int n, unsigned k);
-int dav1d_msac_decode_uniform(MsacContext *c, unsigned n);
+unsigned dav1d_msac_decode_bools(MsacContext *s, unsigned n);
+int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
+int dav1d_msac_decode_uniform(MsacContext *s, unsigned n);
+
+/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
+#if ARCH_X86_64 && HAVE_ASM
+unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
+#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
+#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#else
+#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
+#endif
#endif /* DAV1D_SRC_MSAC_H */
diff --git a/src/recon_tmpl.c b/src/recon_tmpl.c
index de2e0d3..0e78cca 100644
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -107,7 +107,9 @@ static int decode_coefs(Dav1dTileContext *const t,
uint16_t *const txtp_cdf = intra ?
ts->cdf.m.txtp_intra[set_idx][t_dim->min][y_mode_nofilt] :
ts->cdf.m.txtp_inter[set_idx][t_dim->min];
- idx = dav1d_msac_decode_symbol_adapt(&ts->msac, txtp_cdf, set_cnt);
+ idx = (set_cnt <= 8 ? dav1d_msac_decode_symbol_adapt8 :
+ dav1d_msac_decode_symbol_adapt16)(&ts->msac, txtp_cdf, set_cnt);
+
if (dbg)
printf("Post-txtp[%d->%d][%d->%d][%d][%d->%d]: r=%d\n",
set, set_idx, tx, t_dim->min, intra ? (int)y_mode_nofilt : -1,
@@ -122,19 +124,19 @@ static int decode_coefs(Dav1dTileContext *const t,
const enum TxClass tx_class = dav1d_tx_type_class[*txtp];
const int is_1d = tx_class != TX_CLASS_2D;
switch (tx2dszctx) {
-#define case_sz(sz, bin) \
+#define case_sz(sz, bin, ns) \
case sz: { \
uint16_t *const eob_bin_cdf = ts->cdf.coef.eob_bin_##bin[chroma][is_1d]; \
- eob_bin = dav1d_msac_decode_symbol_adapt(&ts->msac, eob_bin_cdf, 5 + sz); \
+ eob_bin = dav1d_msac_decode_symbol_adapt##ns(&ts->msac, eob_bin_cdf, 5 + sz); \
break; \
}
- case_sz(0, 16);
- case_sz(1, 32);
- case_sz(2, 64);
- case_sz(3, 128);
- case_sz(4, 256);
- case_sz(5, 512);
- case_sz(6, 1024);
+ case_sz(0, 16, 4);
+ case_sz(1, 32, 8);
+ case_sz(2, 64, 8);
+ case_sz(3, 128, 8);
+ case_sz(4, 256, 16);
+ case_sz(5, 512, 16);
+ case_sz(6, 1024, 16);
#undef case_sz
}
if (dbg)
@@ -179,8 +181,8 @@ static int decode_coefs(Dav1dTileContext *const t,
uint16_t *const lo_cdf = is_last ?
ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][ctx] :
ts->cdf.coef.base_tok[t_dim->ctx][chroma][ctx];
- int tok = dav1d_msac_decode_symbol_adapt(&ts->msac, lo_cdf,
- 4 - is_last) + is_last;
+ int tok = dav1d_msac_decode_symbol_adapt4(&ts->msac, lo_cdf,
+ 4 - is_last) + is_last;
if (dbg)
printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
@@ -190,7 +192,7 @@ static int decode_coefs(Dav1dTileContext *const t,
if (tok == 3) {
const int br_ctx = get_br_ctx(levels, rc, tx, tx_class);
do {
- const int tok_br = dav1d_msac_decode_symbol_adapt(&ts->msac,
+ const int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx], 4);
if (dbg)
printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",
diff --git a/src/x86/msac.asm b/src/x86/msac.asm
new file mode 100644
index 0000000..9f3a820
--- /dev/null
+++ b/src/x86/msac.asm
@@ -0,0 +1,287 @@
+; Copyright © 2019, VideoLAN and dav1d authors
+; Copyright © 2019, Two Orioles, LLC
+; All rights reserved.
+;
+; Redistribution and use in source and binary forms, with or without
+; modification, are permitted provided that the following conditions are met:
+;
+; 1. Redistributions of source code must retain the above copyright notice, this
+; list of conditions and the following disclaimer.
+;
+; 2. Redistributions in binary form must reproduce the above copyright notice,
+; this list of conditions and the following disclaimer in the documentation
+; and/or other materials provided with the distribution.
+;
+; THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+; ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+; WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+; DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+; ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+; (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+; ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+; (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+; SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+%include "config.asm"
+%include "ext/x86/x86inc.asm"
+
+%if ARCH_X86_64
+
+SECTION_RODATA 64 ; avoids cacheline splits
+
+dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
+pw_0xff00: times 8 dw 0xff00
+pw_32: times 8 dw 32
+
+struc msac
+ .buf: resq 1
+ .end: resq 1
+ .dif: resq 1
+ .rng: resd 1
+ .cnt: resd 1
+ .update_cdf: resd 1
+endstruc
+
+%define m(x) mangle(private_prefix %+ _ %+ x %+ SUFFIX)
+
+SECTION .text
+
+%if WIN64
+DECLARE_REG_TMP 3
+%define buf rsp+8 ; shadow space
+%else
+DECLARE_REG_TMP 0
+%define buf rsp-40 ; red zone
+%endif
+
+INIT_XMM sse2
+cglobal msac_decode_symbol_adapt4, 3, 7, 6, s, cdf, ns
+ movd m2, [sq+msac.rng]
+ movq m1, [cdfq]
+ lea rax, [pw_0xff00]
+ movq m3, [sq+msac.dif]
+ mov r3d, [sq+msac.update_cdf]
+ mov r4d, nsd
+ neg nsq
+ pshuflw m2, m2, q0000
+ movd [buf+12], m2
+ pand m2, [rax]
+ mova m0, m1
+ psrlw m1, 6
+ psllw m1, 7
+ pmulhuw m1, m2
+ movq m2, [rax+nsq*2]
+ pshuflw m3, m3, q3333
+ paddw m1, m2
+ mova [buf+16], m1
+ psubusw m1, m3
+ pxor m2, m2
+ pcmpeqw m1, m2 ; c >= v
+ pmovmskb eax, m1
+ test r3d, r3d
+ jz .renorm ; !allow_update_cdf
+
+; update_cdf:
+ movzx r3d, word [cdfq+r4*2] ; count
+ pcmpeqw m2, m2
+ mov r2d, r3d
+ shr r3d, 4
+ cmp r4d, 4
+ sbb r3d, -5 ; (count >> 4) + (n_symbols > 3) + 4
+ cmp r2d, 32
+ adc r2d, 0 ; count + (count < 32)
+ movd m3, r3d
+ pavgw m2, m1 ; i >= val ? -1 : 32768
+ psubw m2, m0 ; for (i = 0; i < val; i++)
+ psubw m0, m1 ; cdf[i] += (32768 - cdf[i]) >> rate;
+ psraw m2, m3 ; for (; i < n_symbols - 1; i++)
+ paddw m0, m2 ; cdf[i] += (( -1 - cdf[i]) >> rate) + 1;
+ movq [cdfq], m0
+ mov [cdfq+r4*2], r2w
+
+.renorm:
+ tzcnt eax, eax
+ mov r4, [sq+msac.dif]
+ movzx r1d, word [buf+rax+16] ; v
+ movzx r2d, word [buf+rax+14] ; u
+ shr eax, 1
+.renorm2:
+ not r4
+ sub r2d, r1d ; rng
+ shl r1, 48
+ add r4, r1 ; ~dif
+ mov r1d, [sq+msac.cnt]
+ movifnidn t0, sq
+ bsr ecx, r2d
+ xor ecx, 15 ; d
+ shl r2d, cl
+ shl r4, cl
+ mov [t0+msac.rng], r2d
+ not r4
+ sub r1d, ecx
+ jge .end ; no refill required
+
+; refill:
+ mov r2, [t0+msac.buf]
+ mov rcx, [t0+msac.end]
+ lea r5, [r2+8]
+ cmp r5, rcx
+ jg .refill_eob
+ mov r2, [r2]
+ lea ecx, [r1+23]
+ add r1d, 16
+ shr ecx, 3 ; shift_bytes
+ bswap r2
+ sub r5, rcx
+ shl ecx, 3 ; shift_bits
+ shr r2, cl
+ sub ecx, r1d ; shift_bits - 16 - cnt
+ mov r1d, 48
+ shl r2, cl
+ mov [t0+msac.buf], r5
+ sub r1d, ecx ; cnt + 64 - shift_bits
+ xor r4, r2
+.end:
+ mov [t0+msac.cnt], r1d
+ mov [t0+msac.dif], r4
+ RET
+.refill_eob: ; avoid overreading the input buffer
+ mov r5, rcx
+ mov ecx, 40
+ sub ecx, r1d ; c
+.refill_eob_loop:
+ cmp r2, r5
+ jge .refill_eob_end ; eob reached
+ movzx r1d, byte [r2]
+ inc r2
+ shl r1, cl
+ xor r4, r1
+ sub ecx, 8
+ jge .refill_eob_loop
+.refill_eob_end:
+ mov r1d, 40
+ sub r1d, ecx
+ mov [t0+msac.buf], r2
+ mov [t0+msac.dif], r4
+ mov [t0+msac.cnt], r1d
+ RET
+
+cglobal msac_decode_symbol_adapt8, 3, 7, 6, s, cdf, ns
+ movd m2, [sq+msac.rng]
+ movu m1, [cdfq]
+ lea rax, [pw_0xff00]
+ movq m3, [sq+msac.dif]
+ mov r3d, [sq+msac.update_cdf]
+ mov r4d, nsd
+ neg nsq
+ pshuflw m2, m2, q0000
+ movd [buf+12], m2
+ punpcklqdq m2, m2
+ mova m0, m1
+ psrlw m1, 6
+ pand m2, [rax]
+ psllw m1, 7
+ pmulhuw m1, m2
+ movu m2, [rax+nsq*2]
+ pshuflw m3, m3, q3333
+ paddw m1, m2
+ punpcklqdq m3, m3
+ mova [buf+16], m1
+ psubusw m1, m3
+ pxor m2, m2
+ pcmpeqw m1, m2
+ pmovmskb eax, m1
+ test r3d, r3d
+ jz m(msac_decode_symbol_adapt4).renorm
+ movzx r3d, word [cdfq+r4*2]
+ pcmpeqw m2, m2
+ mov r2d, r3d
+ shr r3d, 4
+ cmp r4d, 4 ; may be called with n_symbols < 4
+ sbb r3d, -5
+ cmp r2d, 32
+ adc r2d, 0
+ movd m3, r3d
+ pavgw m2, m1
+ psubw m2, m0
+ psubw m0, m1
+ psraw m2, m3
+ paddw m0, m2
+ movu [cdfq], m0
+ mov [cdfq+r4*2], r2w
+ jmp m(msac_decode_symbol_adapt4).renorm
+
+cglobal msac_decode_symbol_adapt16, 3, 7, 6, s, cdf, ns
+ movd m4, [sq+msac.rng]
+ movu m2, [cdfq]
+ lea rax, [pw_0xff00]
+ movu m3, [cdfq+16]
+ movq m5, [sq+msac.dif]
+ mov r3d, [sq+msac.update_cdf]
+ mov r4d, nsd
+ neg nsq
+%if WIN64
+ sub rsp, 48 ; need 36 bytes, shadow space is only 32
+%endif
+ pshuflw m4, m4, q0000
+ movd [buf-4], m4
+ punpcklqdq m4, m4
+ mova m0, m2
+ psrlw m2, 6
+ mova m1, m3
+ psrlw m3, 6
+ pand m4, [rax]
+ psllw m2, 7
+ psllw m3, 7
+ pmulhuw m2, m4
+ pmulhuw m3, m4
+ movu m4, [rax+nsq*2]
+ pshuflw m5, m5, q3333
+ paddw m2, m4
+ psubw m4, [rax-pw_0xff00+pw_32]
+ punpcklqdq m5, m5
+ paddw m3, m4
+ mova [buf], m2
+ mova [buf+16], m3
+ psubusw m2, m5
+ psubusw m3, m5
+ pxor m4, m4
+ pcmpeqw m2, m4
+ pcmpeqw m3, m4
+ packsswb m5, m2, m3
+ pmovmskb eax, m5
+ test r3d, r3d
+ jz .renorm
+ movzx r3d, word [cdfq+r4*2]
+ pcmpeqw m4, m4
+ mova m5, m4
+ lea r2d, [r3+80] ; only support n_symbols >= 4
+ shr r2d, 4
+ cmp r3d, 32
+ adc r3d, 0
+ pavgw m4, m2
+ pavgw m5, m3
+ psubw m4, m0
+ psubw m0, m2
+ movd m2, r2d
+ psubw m5, m1
+ psubw m1, m3
+ psraw m4, m2
+ psraw m5, m2
+ paddw m0, m4
+ paddw m1, m5
+ movu [cdfq], m0
+ movu [cdfq+16], m1
+ mov [cdfq+r4*2], r3w
+.renorm:
+ tzcnt eax, eax
+ mov r4, [sq+msac.dif]
+ movzx r1d, word [buf+rax*2]
+ movzx r2d, word [buf+rax*2-2]
+%if WIN64
+ add rsp, 48
+%endif
+ jmp m(msac_decode_symbol_adapt4).renorm2
+
+%endif
diff --git a/tests/checkasm/checkasm.c b/tests/checkasm/checkasm.c
index c9908b8..f852f4f 100644
--- a/tests/checkasm/checkasm.c
+++ b/tests/checkasm/checkasm.c
@@ -62,6 +62,7 @@ static const struct {
const char *name;
void (*func)(void);
} tests[] = {
+ { "msac", checkasm_check_msac },
#if CONFIG_8BPC
{ "cdef_8bpc", checkasm_check_cdef_8bpc },
{ "ipred_8bpc", checkasm_check_ipred_8bpc },
diff --git a/tests/checkasm/checkasm.h b/tests/checkasm/checkasm.h
index 7adc40c..a2e53ea 100644
--- a/tests/checkasm/checkasm.h
+++ b/tests/checkasm/checkasm.h
@@ -57,6 +57,7 @@ int xor128_rand(void);
name##_8bpc(void); \
name##_16bpc(void)
+void checkasm_check_msac(void);
decl_check_bitfns(void checkasm_check_cdef);
decl_check_bitfns(void checkasm_check_ipred);
decl_check_bitfns(void checkasm_check_itx);
diff --git a/tests/checkasm/msac.c b/tests/checkasm/msac.c
new file mode 100644
index 0000000..49551bb
--- /dev/null
+++ b/tests/checkasm/msac.c
@@ -0,0 +1,115 @@
+/*
+ * Copyright © 2019, VideoLAN and dav1d authors
+ * Copyright © 2019, Two Orioles, LLC
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "tests/checkasm/checkasm.h"
+
+#include "src/cpu.h"
+#include "src/msac.h"
+
+#include <string.h>
+
+/* The normal code doesn't use function pointers */
+typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+
+typedef struct {
+ decode_symbol_adapt_fn symbol_adapt4;
+ decode_symbol_adapt_fn symbol_adapt8;
+ decode_symbol_adapt_fn symbol_adapt16;
+} MsacDSPContext;
+
+static void randomize_cdf(uint16_t *const cdf, int n) {
+ for (int i = 16; i > n; i--)
+ cdf[i] = rnd(); /* randomize padding */
+ cdf[n] = cdf[n-1] = 0;
+ while (--n > 0)
+ cdf[n-1] = cdf[n] + rnd() % (32768 - cdf[n] - n) + 1;
+}
+
+/* memcmp() on structs can have weird behavior due to padding etc. */
+static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
+ return a->buf_pos != b->buf_pos || a->buf_end != b->buf_end ||
+ a->dif != b->dif || a->rng != b->rng || a->cnt != b->cnt ||
+ a->allow_update_cdf != b->allow_update_cdf;
+}
+
+#define CHECK_SYMBOL_ADAPT(n, n_min, n_max) do { \
+ if (check_func(c->symbol_adapt##n, "msac_decode_symbol_adapt%d", n)) { \
+ for (int cdf_update = 0; cdf_update <= 1; cdf_update++) { \
+ for (int ns = n_min; ns <= n_max; ns++) { \
+ dav1d_msac_init(&s_c, buf, sizeof(buf), !cdf_update); \
+ s_a = s_c; \
+ randomize_cdf(cdf[0], ns); \
+ memcpy(cdf[1], cdf[0], sizeof(*cdf)); \
+ for (int i = 0; i < 64; i++) { \
+ unsigned c_res = call_ref(&s_c, cdf[0], ns); \
+ unsigned a_res = call_new(&s_a, cdf[1], ns); \
+ if (c_res != a_res || msac_cmp(&s_c, &s_a) || \
+ memcmp(cdf[0], cdf[1], sizeof(**cdf) * (ns + 1))) \
+ { \
+ fail(); \
+ } \
+ } \
+ if (cdf_update && ns == n) \
+ bench_new(&s_a, cdf[0], n); \
+ } \
+ } \
+ } \
+} while (0)
+
+static void check_decode_symbol_adapt(MsacDSPContext *const c) {
+ /* Use an aligned CDF buffer for more consistent benchmark
+ * results, and a misaligned one for checking correctness. */
+ ALIGN_STK_16(uint16_t, cdf, 2, [17]);
+ MsacContext s_c, s_a;
+ uint8_t buf[1024];
+ for (int i = 0; i < 1024; i++)
+ buf[i] = rnd();
+
+ declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
+ CHECK_SYMBOL_ADAPT( 4, 1, 5);
+ CHECK_SYMBOL_ADAPT( 8, 1, 8);
+ CHECK_SYMBOL_ADAPT(16, 4, 16);
+ report("decode_symbol_adapt");
+}
+
+void checkasm_check_msac(void) {
+ MsacDSPContext c;
+ c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
+ c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
+ c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
+
+#if ARCH_X86_64 && HAVE_ASM
+ if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {
+ c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
+ c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
+ c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+ }
+#endif
+
+ check_decode_symbol_adapt(&c);
+}
diff --git a/tests/meson.build b/tests/meson.build
index dc0cc10..b1dce29 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -34,7 +34,10 @@ endif
libdav1d_nasm_objs_if_needed = []
if is_asm_enabled
- checkasm_sources = files('checkasm/checkasm.c')
+ checkasm_sources = files(
+ 'checkasm/checkasm.c',
+ 'checkasm/msac.c',
+ )
checkasm_tmpl_sources = files(
'checkasm/cdef.c',