diff options
author | Rupert Swarbrick <rupert.swarbrick@argondesign.com> | 2018-10-17 19:49:35 +0300 |
---|---|---|
committer | Ronald S. Bultje <rsbultje@gmail.com> | 2018-11-14 17:39:21 +0300 |
commit | c59f19405362091741f441ff1a98810955a56a3f (patch) | |
tree | 8245b36af185c15953c3bca3ecc7466858984aca | |
parent | 2532642bbbdfcc77140846e1403a6b393eaba974 (diff) |
Correctly flush at the end of OBUs
This fixes failures when an OBU has more than a byte's worth of
trailing zeros.
As part of this work, it also rejigs the dav1d_flush_get_bits function
slightly. This worked before, but it wasn't very obvious why (it
worked because bits_left was never more than 7). This patch renames it
to dav1d_bytealign_get_bits, which makes it clearer what it does and
adds a comment explaining why it works properly.
The new dav1d_bytealign_get_bits is also now void (rather than
returning the next byte to read). The patch defines
dav1d_get_bits_pos, which returns the current bit position. This feels
a little easier to reason about.
We also add a new check to make sure that we haven't fallen off the
end of the OBU. This can happen when a byte buffer contains more than
one OBU: the GetBits might not have got to EOF, but we might now be
half-way through the next OBU.
-rw-r--r-- | src/getbits.c | 12 | ||||
-rw-r--r-- | src/getbits.h | 9 | ||||
-rw-r--r-- | src/obu.c | 148 |
3 files changed, 120 insertions, 49 deletions
diff --git a/src/getbits.c b/src/getbits.c index fe7d4b5..0a34601 100644 --- a/src/getbits.c +++ b/src/getbits.c @@ -126,8 +126,16 @@ int dav1d_get_bits_subexp(GetBits *const c, const int ref, const unsigned n) { return (int) get_bits_subexp_u(c, ref + (1 << n), 2 << n) - (1 << n); } -const uint8_t *dav1d_flush_get_bits(GetBits *c) { +void dav1d_bytealign_get_bits(GetBits *c) { + // bits_left is never more than 7, because it is only incremented + // by refill(), called by dav1d_get_bits and that never reads more + // than 7 bits more than it needs. + // + // If this wasn't true, we would need to work out how many bits to + // discard (bits_left % 8), subtract that from bits_left and then + // shift state right by that amount. + assert(c->bits_left <= 7); + c->bits_left = 0; c->state = 0; - return c->ptr; } diff --git a/src/getbits.h b/src/getbits.h index 6a59db2..d96810a 100644 --- a/src/getbits.h +++ b/src/getbits.h @@ -46,6 +46,13 @@ int dav1d_get_sbits(GetBits *c, unsigned n); unsigned dav1d_get_uniform(GetBits *c, unsigned max); unsigned dav1d_get_vlc(GetBits *c); int dav1d_get_bits_subexp(GetBits *c, int ref, unsigned n); -const uint8_t *dav1d_flush_get_bits(GetBits *c); + +// Discard bits from the buffer until we're next byte-aligned. +void dav1d_bytealign_get_bits(GetBits *c); + +// Return the current bit position relative to the start of the buffer. +static inline unsigned dav1d_get_bits_pos(const GetBits *c) { + return (c->ptr - c->ptr_start) * 8 - c->bits_left; +} #endif /* __DAV1D_SRC_GETBITS_H__ */ @@ -46,15 +46,17 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, Av1SequenceHeader *const hdr) { - const uint8_t *const init_ptr = gb->ptr; - #define DEBUG_SEQ_HDR 0 +#if DEBUG_SEQ_HDR + const unsigned init_bit_pos = dav1d_get_bits_pos(gb); +#endif + hdr->profile = dav1d_get_bits(gb, 3); if (hdr->profile > 2) goto error; #if DEBUG_SEQ_HDR printf("SEQHDR: post-profile: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif hdr->still_picture = dav1d_get_bits(gb, 1); @@ -62,7 +64,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, if (hdr->reduced_still_picture_header && !hdr->still_picture) goto error; #if DEBUG_SEQ_HDR printf("SEQHDR: post-stillpicture_flags: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif if (hdr->reduced_still_picture_header) { @@ -97,7 +99,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, } #if DEBUG_SEQ_HDR printf("SEQHDR: post-timinginfo: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif hdr->display_model_info_present = dav1d_get_bits(gb, 1); @@ -126,7 +128,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, } #if DEBUG_SEQ_HDR printf("SEQHDR: post-operating-points: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif } @@ -136,7 +138,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, hdr->max_height = dav1d_get_bits(gb, hdr->height_n_bits) + 1; #if DEBUG_SEQ_HDR printf("SEQHDR: post-size: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif hdr->frame_id_numbers_present = hdr->reduced_still_picture_header ? 0 : dav1d_get_bits(gb, 1); @@ -146,7 +148,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, } #if DEBUG_SEQ_HDR printf("SEQHDR: post-frame-id-numbers-present: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif hdr->sb128 = dav1d_get_bits(gb, 1); @@ -180,7 +182,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, hdr->screen_content_tools = dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1); #if DEBUG_SEQ_HDR printf("SEQHDR: post-screentools: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif hdr->force_integer_mv = hdr->screen_content_tools ? dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1) : 2; @@ -192,7 +194,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, hdr->restoration = dav1d_get_bits(gb, 1); #if DEBUG_SEQ_HDR printf("SEQHDR: post-featurebits: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif const int hbd = dav1d_get_bits(gb, 1); @@ -243,18 +245,22 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, } #if DEBUG_SEQ_HDR printf("SEQHDR: post-colorinfo: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif hdr->film_grain_present = dav1d_get_bits(gb, 1); #if DEBUG_SEQ_HDR printf("SEQHDR: post-filmgrain: off=%ld\n", - (gb->ptr - init_ptr) * 8 - gb->bits_left); + dav1d_get_bits_pos(gb) - init_bit_pos); #endif dav1d_get_bits(gb, 1); // dummy bit - return dav1d_flush_get_bits(gb) - init_ptr; + // We needn't bother flushing the OBU here: we'll check we didn't + // overrun in the caller and will then discard gb, so there's no + // point in setting its position properly. + + return 0; error: fprintf(stderr, "Error parsing sequence header\n"); @@ -313,16 +319,16 @@ static const Av1LoopfilterModeRefDeltas default_mode_ref_deltas = { .ref_delta = { 1, 0, 0, 0, -1, 0, -1, -1 }, }; -static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb, - const int have_trailing_bit) -{ +static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb) { +#define DEBUG_FRAME_HDR 0 + +#if DEBUG_FRAME_HDR const uint8_t *const init_ptr = gb->ptr; +#endif const Av1SequenceHeader *const seqhdr = &c->seq_hdr; Av1FrameHeader *const hdr = &c->frame_hdr; int res; -#define DEBUG_FRAME_HDR 0 - hdr->show_existing_frame = !seqhdr->reduced_still_picture_header && dav1d_get_bits(gb, 1); #if DEBUG_FRAME_HDR @@ -335,7 +341,7 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb, hdr->frame_presentation_delay = dav1d_get_bits(gb, seqhdr->frame_presentation_delay_length); if (seqhdr->frame_id_numbers_present) hdr->frame_id = dav1d_get_bits(gb, seqhdr->frame_id_n_bits); - goto end; + return 0; } hdr->frame_type = seqhdr->reduced_still_picture_header ? DAV1D_FRAME_TYPE_KEY : dav1d_get_bits(gb, 2); @@ -976,21 +982,14 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb, (gb->ptr - init_ptr) * 8 - gb->bits_left); #endif -end: - - if (have_trailing_bit) - dav1d_get_bits(gb, 1); // dummy bit - - return dav1d_flush_get_bits(gb) - init_ptr; + return 0; error: fprintf(stderr, "Error parsing frame header\n"); return -EINVAL; } -static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) { - const uint8_t *const init_ptr = gb->ptr; - +static void parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) { int have_tile_pos = 0; const int n_tiles = c->frame_hdr.tiling.cols * c->frame_hdr.tiling.rows; if (n_tiles > 1) @@ -1005,8 +1004,31 @@ static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) { c->tile[c->n_tile_data].start = 0; c->tile[c->n_tile_data].end = n_tiles - 1; } +} + +// Check that we haven't read more than obu_len bytes from the buffer +// since init_bit_pos. +static int +check_for_overrun(GetBits *const gb, unsigned init_bit_pos, unsigned obu_len) +{ + // Make sure we haven't actually read past the end of the gb buffer + if (gb->error) { + fprintf(stderr, "Overrun in OBU bit buffer\n"); + return 1; + } - return dav1d_flush_get_bits(gb) - init_ptr; + unsigned pos = dav1d_get_bits_pos(gb); + + // We assume that init_bit_pos was the bit position of the buffer + // at some point in the past, so cannot be smaller than pos. + assert (init_bit_pos <= pos); + + if (pos - init_bit_pos > 8 * obu_len) { + fprintf(stderr, "Overrun in OBU bit buffer into next OBU\n"); + return 1; + } + + return 0; } int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { @@ -1041,9 +1063,23 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { } while (more); if (gb.error) goto error; - unsigned off = dav1d_flush_get_bits(&gb) - in->data; - const unsigned init_off = off; - if (len > in->sz - off) goto error; + const unsigned init_bit_pos = dav1d_get_bits_pos(&gb); + const unsigned init_byte_pos = init_bit_pos >> 3; + const unsigned pkt_bytelen = init_byte_pos + len; + + // We must have read a whole number of bytes at this point (1 byte + // for the header and whole bytes at a time when reading the + // leb128 length field). + assert(init_bit_pos & 7 == 0); + + // We also know that we haven't tried to read more than in->sz + // bytes yet (otherwise the error flag would have been set by the + // code in getbits.c) + assert(in->sz >= init_byte_pos); + + // Make sure that there are enough bits left in the buffer for the + // rest of the OBU. + if (len > in->sz - init_byte_pos) goto error; switch (type) { case OBU_SEQ_HDR: { @@ -1052,8 +1088,8 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { c->have_frame_hdr = 0; if ((res = parse_seq_hdr(c, &gb, hdr_ptr)) < 0) return res; - if ((unsigned)res != len) - goto error; + if (check_for_overrun(&gb, init_bit_pos, len)) + return -EINVAL; if (!c->have_frame_hdr || memcmp(&hdr, &c->seq_hdr, sizeof(hdr))) { for (int i = 0; i < 8; i++) { if (c->refs[i].p.p.data[0]) @@ -1076,29 +1112,48 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { case OBU_FRAME_HDR: c->have_frame_hdr = 0; if (!c->have_seq_hdr) goto error; - if ((res = parse_frame_hdr(c, &gb, type != OBU_FRAME)) < 0) + if ((res = parse_frame_hdr(c, &gb)) < 0) return res; c->have_frame_hdr = 1; for (int n = 0; n < c->n_tile_data; n++) dav1d_data_unref(&c->tile[n].data); c->n_tile_data = 0; c->n_tiles = 0; - if (type != OBU_FRAME) break; + if (type != OBU_FRAME) { + // This is actually a frame header OBU so read the + // trailing bit and check for overrun. + dav1d_get_bits(&gb, 1); + if (check_for_overrun(&gb, init_bit_pos, len)) + return -EINVAL; + + break; + } + // OBU_FRAMEs shouldn't be signalled with show_existing_frame if (c->frame_hdr.show_existing_frame) goto error; - off += res; + + // This is the frame header at the start of a frame OBU. + // There's no trailing bit at the end to skip, but we do need + // to align to the next byte. + dav1d_bytealign_get_bits(&gb); // fall-through - case OBU_TILE_GRP: + case OBU_TILE_GRP: { if (!c->have_frame_hdr) goto error; if (c->n_tile_data >= 256) goto error; - if ((res = parse_tile_hdr(c, &gb)) < 0) - return res; - off += res; - if (off > len + init_off) - goto error; + parse_tile_hdr(c, &gb); + // Align to the next byte boundary and check for overrun. + dav1d_bytealign_get_bits(&gb); + if (check_for_overrun(&gb, init_bit_pos, len)) + return -EINVAL; + // The current bit position is a multiple of 8 (because we + // just aligned it) and less than 8*pkt_bytelen because + // otherwise the overrun check would have fired. + const unsigned bit_pos = dav1d_get_bits_pos(&gb); + assert(bit_pos & 7 == 0); + assert(pkt_bytelen > (bit_pos >> 3)); dav1d_ref_inc(in->ref); c->tile[c->n_tile_data].data.ref = in->ref; - c->tile[c->n_tile_data].data.data = in->data + off; - c->tile[c->n_tile_data].data.sz = len + init_off - off; + c->tile[c->n_tile_data].data.data = in->data + (bit_pos >> 3); + c->tile[c->n_tile_data].data.sz = pkt_bytelen - (bit_pos >> 3); // ensure tile groups are in order and sane, see 6.10.1 if (c->tile[c->n_tile_data].start > c->tile[c->n_tile_data].end || c->tile[c->n_tile_data].start != c->n_tiles) @@ -1113,6 +1168,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { c->tile[c->n_tile_data].start; c->n_tile_data++; break; + } case OBU_PADDING: case OBU_TD: case OBU_METADATA: @@ -1192,7 +1248,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { } } - return len + init_off; + return len + init_byte_pos; error: fprintf(stderr, "Error parsing OBU data\n"); |