diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-06 11:10:59 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-06 11:16:43 +0300 |
commit | 544b3e576c8edd1785914c988882b62d60652f26 (patch) | |
tree | d1dd9e1a90ce85f3fbb462228c7adea29fe10d70 | |
parent | 98b8be09d56d03d220fff3536842c0703bae865c (diff) |
DRED: quantize r and p0 parameters with 8 bits
Only code non-degenerate symbols, which makes the encoder faster
-rwxr-xr-x | autogen.sh | 2 | ||||
-rw-r--r-- | dnn/dred_rdovae.c | 8 | ||||
-rw-r--r-- | dnn/dred_rdovae.h | 4 | ||||
-rw-r--r-- | dnn/torch/rdovae/export_rdovae_weights.py | 12 | ||||
-rw-r--r-- | dnn/torch/rdovae/rdovae/rdovae.py | 53 | ||||
-rw-r--r-- | silk/dred_config.h | 2 | ||||
-rw-r--r-- | silk/dred_decoder.c | 9 | ||||
-rw-r--r-- | silk/dred_encoder.c | 10 |
8 files changed, 52 insertions, 48 deletions
@@ -9,7 +9,7 @@ set -e srcdir=`dirname $0` test -n "$srcdir" && cd "$srcdir" -dnn/download_model.sh c99054d +dnn/download_model.sh 98b8be0 echo "Updating build configuration files, please wait...." diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c index a17b957b..b4797b5e 100644 --- a/dnn/dred_rdovae.c +++ b/dnn/dred_rdovae.c @@ -79,9 +79,9 @@ void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float } -const opus_uint16 * DRED_rdovae_get_p0_pointer(void) +const opus_uint8 * DRED_rdovae_get_p0_pointer(void) { - return &dred_p0_q15[0]; + return &dred_p0_q8[0]; } const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void) @@ -89,9 +89,9 @@ const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void) return &dred_dead_zone_q10[0]; } -const opus_uint16 * DRED_rdovae_get_r_pointer(void) +const opus_uint8 * DRED_rdovae_get_r_pointer(void) { - return &dred_r_q15[0]; + return &dred_r_q8[0]; } const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void) diff --git a/dnn/dred_rdovae.h b/dnn/dred_rdovae.h index f2c3235e..05da69ba 100644 --- a/dnn/dred_rdovae.h +++ b/dnn/dred_rdovae.h @@ -58,9 +58,9 @@ void DRED_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, cons void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z); -const opus_uint16 * DRED_rdovae_get_p0_pointer(void); +const opus_uint8 * DRED_rdovae_get_p0_pointer(void); const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void); -const opus_uint16 * DRED_rdovae_get_r_pointer(void); +const opus_uint8 * DRED_rdovae_get_r_pointer(void); const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void); #endif diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index 9a35c17a..3bcc7712 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -63,20 +63,20 @@ def dump_statistical_model(writer, qembedding): quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16) dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16) - r_q15 = np.round(r * 2**15).astype(np.uint16) - p0_q15 = np.round(p0 * 2**15).astype(np.uint16) + r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8) + p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16) print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False) print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False) - print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False) - print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False) + print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False) + print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False) writer.header.write( f""" extern const opus_uint16 dred_quant_scales_q8[{levels * N}]; extern const opus_uint16 dred_dead_zone_q10[{levels * N}]; -extern const opus_uint16 dred_r_q15[{levels * N}]; -extern const opus_uint16 dred_p0_q15[{levels * N}]; +extern const opus_uint8 dred_r_q8[{levels * N}]; +extern const opus_uint8 dred_p0_q8[{levels * N}]; """ ) diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index 3552cf90..09b6801a 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -222,6 +222,9 @@ def weight_clip_factory(max_value): return clip_weights +def n(x): + return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.) + # RDOVAE module and submodules class MyConv(nn.Module): @@ -295,17 +298,17 @@ class CoreEncoder(nn.Module): device = x.device # run encoding layer stack - x = torch.tanh(self.dense_1(x)) - x = torch.cat([x, self.gru1(x)[0]], -1) - x = torch.cat([x, self.conv1(x)], -1) - x = torch.cat([x, self.gru2(x)[0]], -1) - x = torch.cat([x, self.conv2(x)], -1) - x = torch.cat([x, self.gru3(x)[0]], -1) - x = torch.cat([x, self.conv3(x)], -1) - x = torch.cat([x, self.gru4(x)[0]], -1) - x = torch.cat([x, self.conv4(x)], -1) - x = torch.cat([x, self.gru5(x)[0]], -1) - x = torch.cat([x, self.conv5(x)], -1) + x = n(torch.tanh(self.dense_1(x))) + x = torch.cat([x, n(self.gru1(x)[0])], -1) + x = torch.cat([x, n(self.conv1(x))], -1) + x = torch.cat([x, n(self.gru2(x)[0])], -1) + x = torch.cat([x, n(self.conv2(x))], -1) + x = torch.cat([x, n(self.gru3(x)[0])], -1) + x = torch.cat([x, n(self.conv3(x))], -1) + x = torch.cat([x, n(self.gru4(x)[0])], -1) + x = torch.cat([x, n(self.conv4(x))], -1) + x = torch.cat([x, n(self.gru5(x)[0])], -1) + x = torch.cat([x, n(self.conv5(x))], -1) z = self.z_dense(x) # init state for decoder @@ -372,18 +375,18 @@ class CoreDecoder(nn.Module): h5_state = gru_state[:,:,384:].contiguous() # run decoding layer stack - x = torch.tanh(self.dense_1(z)) - - x = torch.cat([x, self.gru1(x, h1_state)[0]], -1) - x = torch.cat([x, self.conv1(x)], -1) - x = torch.cat([x, self.gru2(x, h2_state)[0]], -1) - x = torch.cat([x, self.conv2(x)], -1) - x = torch.cat([x, self.gru3(x, h3_state)[0]], -1) - x = torch.cat([x, self.conv3(x)], -1) - x = torch.cat([x, self.gru4(x, h4_state)[0]], -1) - x = torch.cat([x, self.conv4(x)], -1) - x = torch.cat([x, self.gru5(x, h5_state)[0]], -1) - x = torch.cat([x, self.conv5(x)], -1) + x = n(torch.tanh(self.dense_1(z))) + + x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1) + x = torch.cat([x, n(self.conv1(x))], -1) + x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1) + x = torch.cat([x, n(self.conv2(x))], -1) + x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1) + x = torch.cat([x, n(self.conv3(x))], -1) + x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1) + x = torch.cat([x, n(self.conv4(x))], -1) + x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1) + x = torch.cat([x, n(self.conv5(x))], -1) # output layer and reshaping x10 = self.output(x) @@ -451,7 +454,7 @@ class RDOVAE(nn.Module): cond_size2, state_dim=24, split_mode='split', - clip_weights=True, + clip_weights=False, pvq_num_pulses=82, state_dropout_rate=0): @@ -487,7 +490,7 @@ class RDOVAE(nn.Module): if not type(self.weight_clip_fn) == type(None): self.apply(self.weight_clip_fn) - def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24): + def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4): enc_stride = self.enc_stride dec_stride = self.dec_stride diff --git a/silk/dred_config.h b/silk/dred_config.h index 06ff397b..5e3e74a3 100644 --- a/silk/dred_config.h +++ b/silk/dred_config.h @@ -32,7 +32,7 @@ #define DRED_EXTENSION_ID 126 /* Remove these two completely once DRED gets an extension number assigned. */ -#define DRED_EXPERIMENTAL_VERSION 6 +#define DRED_EXPERIMENTAL_VERSION 7 #define DRED_EXPERIMENTAL_BYTES 2 diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c index 658d340f..68ba8559 100644 --- a/silk/dred_decoder.c +++ b/silk/dred_decoder.c @@ -43,20 +43,21 @@ static int sign_extend(int x, int b) { return (x ^ m) - m; } -static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0, int dim) { +static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) { int i; for (i=0;i<dim;i++) { int q; - q = ec_laplace_decode_p0(dec, p0[i], r[i]); + if (r[i] == 0 || p0[i] == 255) q = 0; + else q = ec_laplace_decode_p0(dec, p0[i]<<7, r[i]<<7); x[i] = q*256.f/(scale[i] == 0 ? 1 : scale[i]); } } int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames) { - const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer(); + const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer(); const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer(); - const opus_uint16 *r = DRED_rdovae_get_r_pointer(); + const opus_uint8 *r = DRED_rdovae_get_r_pointer(); ec_dec ec; int q_level; int i; diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index 6b585416..3f842af0 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -217,7 +217,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex } } -static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0, int dim) { +static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) { int i; int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)]; float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)]; @@ -238,16 +238,16 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 * } for (i=0;i<dim;i++) { /* Make the impossible actually impossible. */ - if (r[i] == 0 || p0[i] >= 32767) q[i] = 0; - ec_laplace_encode_p0(enc, q[i], p0[i], r[i]); + if (r[i] == 0 || p0[i] == 255) q[i] = 0; + else ec_laplace_encode_p0(enc, q[i], p0[i]<<7, r[i]<<7); } } int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) { const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer(); - const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer(); + const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer(); const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer(); - const opus_uint16 *r = DRED_rdovae_get_r_pointer(); + const opus_uint8 *r = DRED_rdovae_get_r_pointer(); ec_enc ec_encoder; int q_level; |