diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-07 01:49:18 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-07 07:51:25 +0300 |
commit | 8d851c25135d777c344307e3b8ffd6a4551e2f0c (patch) | |
tree | ae82cfc9368bb6fd5f595b6fe6c532977a88da5d | |
parent | 2386a60ec644fadc437155cd6e5f6d4c561940d4 (diff) |
Split RDOVAE stats in two (latents + states)
-rw-r--r-- | dnn/dred_rdovae.c | 21 | ||||
-rw-r--r-- | dnn/torch/rdovae/export_rdovae_weights.py | 48 | ||||
-rw-r--r-- | silk/dred_decoder.c | 20 | ||||
-rw-r--r-- | silk/dred_encoder.c | 25 |
4 files changed, 48 insertions, 66 deletions
diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c
index b4797b5e..748a463a 100644
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -77,24 +77,3 @@ void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float
 {
     dred_rdovae_decode_qframe(h, model, qframe, z);
 }
-
-
-const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
-{
-    return &dred_p0_q8[0];
-}
-
-const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
-{
-    return &dred_dead_zone_q10[0];
-}
-
-const opus_uint8 * DRED_rdovae_get_r_pointer(void)
-{
-    return &dred_r_q8[0];
-}
-
-const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
-{
-    return &dred_quant_scales_q8[0];
-}
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index 3bcc7712..8571f8f9 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -49,16 +49,15 @@ from wexchange.torch import dump_torch_weights
 from wexchange.c_export import CWriter, print_vector
 
 
-def dump_statistical_model(writer, qembedding):
-    w = qembedding.weight.detach()
-    levels, dim = w.shape
-    N = dim // 6
+def dump_statistical_model(writer, w, name):
+    levels = w.shape[0]
+    N = w.shape[-1]
 
     print("printing statistical model")
-    quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy()
-    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
-    r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
-    p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+    quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy()
+    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
+    r = torch.sigmoid(w[:, 5 , :]).numpy()
+    p0 = torch.sigmoid(w[:, 4 , :]).numpy()
     p0 = 1 - r ** (0.5 + 0.5 * p0)
 
     quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
@@ -66,17 +65,17 @@ def dump_statistical_model(writer, qembedding):
     r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
     p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
 
-    print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
-    print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
-    print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
-    print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
+    print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
+    print_vector(writer.source, r_q15, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, p0_q15, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
 
     writer.header.write(
 f"""
-extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
-extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint8 dred_r_q8[{levels * N}];
-extern const opus_uint8 dred_p0_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
+extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
+extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];
 
 """
     )
@@ -113,6 +112,19 @@ f"""
 """
     )
 
+    latent_out = model.get_submodule('core_encoder.module.z_dense').weight
+    states_out = model.get_submodule('core_encoder.module.state_dense_2').weight
+    nb_latents = latent_out.shape[0]
+    nb_states = states_out.shape[0]
+    # statistical model
+    qembedding = model.statistical_model.quant_embedding.weight.detach()
+    levels = qembedding.shape[0]
+    qembedding = torch.reshape(qembedding, (levels, 6, -1))
+
+    dump_statistical_model(stats_writer, qembedding[:, :, :nb_latents], 'latents')
+    dump_statistical_model(stats_writer, qembedding[:, :, nb_latents:], 'states')
+
+
     # encoder
     encoder_dense_layers = [
         ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH', False,),
@@ -187,10 +199,6 @@ f"""
 
     del dec_writer
 
-    # statistical model
-    qembedding = model.statistical_model.quant_embedding
-    dump_statistical_model(stats_writer, qembedding)
-
     del stats_writer
 
     # constants
diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c
index 68ba8559..2bedde2f 100644
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -36,6 +36,7 @@
 #include "dred_coding.h"
 #include "celt/entdec.h"
 #include "celt/laplace.h"
+#include "dred_rdovae_stats_data.h"
 
 /* From http://graphics.stanford.edu/~seander/bithacks.html#FixedSignExtend */
 static int sign_extend(int x, int b) {
@@ -55,9 +56,6 @@ static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale,
 
 int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
 {
-  const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
-  const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
-  const opus_uint8 *r = DRED_rdovae_get_r_pointer();
   ec_dec ec;
   int q_level;
   int i;
@@ -78,13 +76,13 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
   /*printf("%d %d %d\n", dred_offset, q0, dQ);*/
 
   //dred_decode_state(&ec, dec->state);
-  state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM;
+  state_qoffset = q0*DRED_STATE_DIM;
   dred_decode_latents(
       &ec,
       dec->state,
-      quant_scales + state_qoffset,
-      r + state_qoffset,
-      p0 + state_qoffset,
+      dred_states_quant_scales_q8 + state_qoffset,
+      dred_states_r_q8 + state_qoffset,
+      dred_states_p0_q8 + state_qoffset,
       DRED_STATE_DIM);
 
   /* decode newest to oldest and store oldest to newest */
@@ -94,13 +92,13 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
     if (8*num_bytes - ec_tell(&ec) <= 7)
       break;
     q_level = compute_quantizer(q0, dQ, i/2);
-    offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
+    offset = q_level*DRED_LATENT_DIM;
     dred_decode_latents(
         &ec,
         &dec->latents[(i/2)*DRED_LATENT_DIM],
-        quant_scales + offset,
-        r + offset,
-        p0 + offset,
+        dred_latents_quant_scales_q8 + offset,
+        dred_latents_r_q8 + offset,
+        dred_latents_p0_q8 + offset,
         DRED_LATENT_DIM
         );
 
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index 3f842af0..6cda6334 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -44,6 +44,7 @@
 #include "float_cast.h"
 #include "os_support.h"
 #include "celt/laplace.h"
+#include "dred_rdovae_stats_data.h"
 
 
 int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
@@ -244,10 +245,6 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *
 }
 
 int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
-    const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer();
-    const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
-    const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
-    const opus_uint8 *r = DRED_rdovae_get_r_pointer();
     ec_enc ec_encoder;
 
     int q_level;
@@ -265,14 +262,14 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
     ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
     ec_enc_uint(&ec_encoder, q0, 16);
     ec_enc_uint(&ec_encoder, dQ, 8);
-    state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM;
+    state_qoffset = q0*DRED_STATE_DIM;
     dred_encode_latents(
         &ec_encoder,
         enc->initial_state,
-        quant_scales + state_qoffset,
-        dead_zone + state_qoffset,
-        r + state_qoffset,
-        p0 + state_qoffset,
+        dred_states_quant_scales_q8 + state_qoffset,
+        dred_states_dead_zone_q10 + state_qoffset,
+        dred_states_r_q8 + state_qoffset,
+        dred_states_p0_q8 + state_qoffset,
         DRED_STATE_DIM);
     if (ec_tell(&ec_encoder) > 8*max_bytes) {
       return 0;
@@ -283,15 +280,15 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
         ec_bak = ec_encoder;
 
         q_level = compute_quantizer(q0, dQ, i/2);
-        offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
+        offset = q_level * DRED_LATENT_DIM;
 
         dred_encode_latents(
             &ec_encoder,
             enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM,
-            quant_scales + offset,
-            dead_zone + offset,
-            r + offset,
-            p0 + offset,
+            dred_latents_quant_scales_q8 + offset,
+            dred_latents_dead_zone_q10 + offset,
+            dred_latents_r_q8 + offset,
+            dred_latents_p0_q8 + offset,
             DRED_LATENT_DIM
         );
         if (ec_tell(&ec_encoder) > 8*max_bytes) {
|