diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-16 00:27:44 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-21 01:04:08 +0300 |
commit | b88644b9c7547f09c7e313c2ae5ad3085fba2d4e (patch) | |
tree | e4e47c2509b941daff61acc59415e19831a4996c | |
parent | 2ec31cc5ccb92e1b472cf3da1d26490412805a79 (diff) |
Quantize the decoder initial state with the RDOVAE quantizer too
This replaces the PVQ coding of the state and is more efficient than PVQ.
-rw-r--r-- | dnn/torch/rdovae/rdovae/rdovae.py | 45 | ||||
-rw-r--r-- | dnn/torch/rdovae/train_rdovae.py | 15 | ||||
-rw-r--r-- | silk/dred_coding.c | 77 | ||||
-rw-r--r-- | silk/dred_config.h | 4 | ||||
-rw-r--r-- | silk/dred_decoder.c | 12 | ||||
-rw-r--r-- | silk/dred_encoder.c | 15 | ||||
-rw-r--r-- | silk/dred_encoder.h | 4 |
7 files changed, 64 insertions, 108 deletions
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index 1eec42c1..0dc943ec 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -372,7 +372,7 @@ class CoreDecoder(nn.Module): class StatisticalModel(nn.Module): - def __init__(self, quant_levels, latent_dim): + def __init__(self, quant_levels, latent_dim, state_dim): """ Statistical model for latent space Computes scaling, deadzone, r, and theta @@ -383,8 +383,10 @@ class StatisticalModel(nn.Module): # copy parameters self.latent_dim = latent_dim + self.state_dim = state_dim + self.total_dim = latent_dim + state_dim self.quant_levels = quant_levels - self.embedding_dim = 6 * latent_dim + self.embedding_dim = 6 * self.total_dim # quantization embedding self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim) @@ -400,12 +402,12 @@ class StatisticalModel(nn.Module): x = self.quant_embedding(quant_ids) # CAVE: theta_soft is not used anymore. Kick it out? - quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim]) - dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim]) - theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim]) - r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim]) - theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim]) - r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim]) + quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim]) + dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim]) + theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim]) + r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim]) + theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim]) + r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim]) return { @@ -445,7 +447,7 @@ class RDOVAE(nn.Module): self.state_dropout_rate = 
state_dropout_rate # submodules encoder and decoder share the statistical model - self.statistical_model = StatisticalModel(quant_levels, latent_dim) + self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim) self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim)) self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim)) @@ -522,13 +524,18 @@ class RDOVAE(nn.Module): z, states = self.core_encoder(features) # scaling, dead-zone and quantization - z = z * statistical_model['quant_scale'] - z = soft_dead_zone(z, statistical_model['dead_zone']) + z = z * statistical_model['quant_scale'][:,:,:self.latent_dim] + z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim]) # quantization - z_q = hard_quantize(z) / statistical_model['quant_scale'] - z_n = noise_quantize(z) / statistical_model['quant_scale'] - states_q = soft_pvq(states, self.pvq_num_pulses) + z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim] + z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim] + #states_q = soft_pvq(states, self.pvq_num_pulses) + states = states * statistical_model['quant_scale'][:,:,self.latent_dim:] + states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:]) + + states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:] + states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:] if self.state_dropout_rate > 0: drop = torch.rand(states_q.size(0)) < self.state_dropout_rate @@ -551,6 +558,7 @@ class RDOVAE(nn.Module): # decoder with soft quantized input z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1]) + dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :] features_reverse = self.core_decoder(z_dec_reverse, 
dec_initial_state) outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop'])) @@ -558,6 +566,7 @@ class RDOVAE(nn.Module): 'outputs_hard_quant' : outputs_hq, 'outputs_soft_quant' : outputs_sq, 'z' : z, + 'states' : states, 'statistical_model' : statistical_model } @@ -586,11 +595,11 @@ class RDOVAE(nn.Module): stats = self.statistical_model(q_ids) - zq = z * stats['quant_scale'] - zq = soft_dead_zone(zq, stats['dead_zone']) + zq = z * stats['quant_scale'][:self.latent_dim] + zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim]) zq = torch.round(zq) - sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False) + sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False) return zq, sizes @@ -599,7 +608,7 @@ class RDOVAE(nn.Module): stats = self.statistical_model(q_ids) - z = zq / stats['quant_scale'] + z = zq / stats['quant_scale'][:,:,:self.latent_dim] return z diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py index f29ed98f..3f8484e1 100644 --- a/dnn/torch/rdovae/train_rdovae.py +++ b/dnn/torch/rdovae/train_rdovae.py @@ -172,6 +172,7 @@ if __name__ == '__main__': running_soft_rate_loss = 0 running_total_loss = 0 running_rate_metric = 0 + running_states_rate_metric = 0 previous_total_loss = 0 running_first_frame_loss = 0 @@ -194,17 +195,21 @@ if __name__ == '__main__': # collect outputs z = model_output['z'] + states = model_output['states'] outputs_hard_quant = model_output['outputs_hard_quant'] outputs_soft_quant = model_output['outputs_soft_quant'] statistical_model = model_output['statistical_model'] # rate loss - hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False) - soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False) - soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate) - hard_rate_loss = 
torch.mean(torch.sqrt(rate_lambda) * hard_rate) + hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False) + soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False) + states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False) + states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False) + soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate)) + hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate)) rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss) hard_rate_metric = torch.mean(hard_rate) + states_rate_metric = torch.mean(states_hard_rate) ## distortion losses @@ -242,6 +247,7 @@ if __name__ == '__main__': running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu()) running_rate_loss += float(rate_loss.detach().cpu()) running_rate_metric += float(hard_rate_metric.detach().cpu()) + running_states_rate_metric += float(states_rate_metric.detach().cpu()) running_total_loss += float(total_loss.detach().cpu()) running_first_frame_loss += float(first_frame_loss.detach().cpu()) running_soft_rate_loss += float(soft_rate_loss.detach().cpu()) @@ -256,6 +262,7 @@ if __name__ == '__main__': dist_sq=running_soft_dist_loss / (i + 1), rate_loss=running_rate_loss / (i + 1), rate=running_rate_metric / (i + 1), + states_rate=running_states_rate_metric / (i + 1), ffloss=running_first_frame_loss / (i + 1), rateloss_hard=running_hard_rate_loss / (i + 1), rateloss_soft=running_soft_rate_loss / (i + 1) diff --git a/silk/dred_coding.c b/silk/dred_coding.c index 24e109ef..f8d2f070 100644 --- a/silk/dred_coding.c +++ b/silk/dred_coding.c @@ -33,16 +33,13 @@ #include <stdio.h> #include "celt/entenc.h" -#include "celt/vq.h" -#include "celt/cwrs.h" 
#include "celt/laplace.h" #include "os_support.h" #include "dred_config.h" #include "dred_coding.h" #define LATENT_DIM 80 -#define PVQ_DIM 24 -#define PVQ_K 82 +#define STATE_DIM 80 int compute_quantizer(int q0, int dQ, int i) { int quant; @@ -53,37 +50,6 @@ int compute_quantizer(int q0, int dQ, int i) { return (int) floor(0.5f + DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2)); } -static void encode_pvq(const int *iy, int N, int K, ec_enc *enc) { - int fits; - celt_assert(N==24 || N==12 || N==6); - fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6); - /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/ - if (fits) { - if (K > 0) - encode_pulses(iy, N, K, enc); - } - else { - int N2 = N/2; - int K0=0; - int i; - for (i=0;i<N2;i++) K0 += abs(iy[i]); - /* FIXME: Don't use uniform probability for K0. */ - ec_enc_uint(enc, K0, K+1); - /*printf("K0 = %d\n", K0);*/ - encode_pvq(iy, N2, K0, enc); - encode_pvq(&iy[N2], N2, K-K0, enc); - } -} - -void dred_encode_state(ec_enc *enc, const float *x) { - int iy[PVQ_DIM]; - float x0[PVQ_DIM]; - /* Copy state because the PVQ search will trash it. */ - OPUS_COPY(x0, x, PVQ_DIM); - op_pvq_search_c(x0, iy, PVQ_K, PVQ_DIM, 0); - encode_pvq(iy, PVQ_DIM, PVQ_K, enc); -} - void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0) { int i; float eps = .1f; @@ -101,47 +67,6 @@ void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, } } - - -static void decode_pvq(int *iy, int N, int K, ec_dec *dec) { - int fits; - celt_assert(N==24 || N==12 || N==6); - fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6); - /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/ - if (fits) { - if (K > 0) - decode_pulses(iy, N, K, dec); - else - OPUS_CLEAR(iy, N); - } - else { - int N2 = N/2; - int K0; - /* FIXME: Don't use uniform probability for K0. 
*/ - K0 = ec_dec_uint(dec, K+1); - /*printf("K0 = %d\n", K0);*/ - decode_pvq(iy, N2, K0, dec); - decode_pvq(&iy[N2], N2, K-K0, dec); - } -} - -void dred_decode_state(ec_enc *dec, float *x) { - int k; - int iy[PVQ_DIM]; - float norm = 0; - decode_pvq(iy, PVQ_DIM, PVQ_K, dec); - /*printf("tell: %d\n", ec_tell(dec)-tell1);*/ - for (k = 0; k < PVQ_DIM; k++) - { - norm += (float) iy[k] * iy[k]; - } - norm = 1.f / sqrt(norm); - for (k = 0; k < PVQ_DIM; k++) - { - x[k] = iy[k] * norm; - } -} - void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0) { int i; for (i=0;i<LATENT_DIM;i++) { diff --git a/silk/dred_config.h b/silk/dred_config.h index d8342f8e..891eb1cf 100644 --- a/silk/dred_config.h +++ b/silk/dred_config.h @@ -32,7 +32,7 @@ #define DRED_EXTENSION_ID 126 /* Remove these two completely once DRED gets an extension number assigned. */ -#define DRED_EXPERIMENTAL_VERSION 1 +#define DRED_EXPERIMENTAL_VERSION 2 #define DRED_EXPERIMENTAL_BYTES 2 @@ -41,7 +41,7 @@ /* these are inpart duplicates to the values defined in dred_rdovae_constants.h */ #define DRED_NUM_FEATURES 20 #define DRED_LATENT_DIM 80 -#define DRED_STATE_DIM 24 +#define DRED_STATE_DIM 80 #define DRED_SILK_ENCODER_DELAY (79+12-80) #define DRED_FRAME_SIZE 160 #define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE)) diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c index 04ba1ef3..500f33b8 100644 --- a/silk/dred_decoder.c +++ b/silk/dred_decoder.c @@ -54,6 +54,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi int offset; int q0; int dQ; + int state_qoffset; /* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */ @@ -66,7 +67,14 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi dQ = ec_dec_uint(&ec, 8); /*printf("%d %d %d\n", dred_offset, q0, dQ);*/ - dred_decode_state(&ec, dec->state); + //dred_decode_state(&ec, 
dec->state); + state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM; + dred_decode_latents( + &ec, + dec->state, + quant_scales + state_qoffset, + r + state_qoffset, + p0 + state_qoffset); /* decode newest to oldest and store oldest to newest */ for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2) @@ -75,7 +83,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi if (8*num_bytes - ec_tell(&ec) <= 7) break; q_level = compute_quantizer(q0, dQ, i/2); - offset = q_level * DRED_LATENT_DIM; + offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM); dred_decode_latents( &ec, &dec->latents[(i/2)*DRED_LATENT_DIM], diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index 5bae39e9..7b34cefe 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -197,7 +197,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex /* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. 
*/ if (enc->dred_offset < 6) { enc->dred_offset += 8; - OPUS_COPY(enc->initial_state, enc->state_buffer, 24); + OPUS_COPY(enc->initial_state, enc->state_buffer, DRED_STATE_DIM); } else { enc->latent_offset++; } @@ -221,6 +221,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk int ec_buffer_fill; int q0; int dQ; + int state_qoffset; /* entropy coding of state and latents */ ec_enc_init(&ec_encoder, buf, max_bytes); @@ -229,15 +230,21 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk ec_enc_uint(&ec_encoder, enc->dred_offset, 32); ec_enc_uint(&ec_encoder, q0, 16); ec_enc_uint(&ec_encoder, dQ, 8); - dred_encode_state(&ec_encoder, enc->initial_state); - + state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM; + dred_encode_latents( + &ec_encoder, + enc->initial_state, + quant_scales + state_qoffset, + dead_zone + state_qoffset, + r + state_qoffset, + p0 + state_qoffset); for (i = 0; i < IMIN(2*max_chunks, enc->latents_buffer_fill-enc->latent_offset-1); i += 2) { ec_enc ec_bak; ec_bak = ec_encoder; q_level = compute_quantizer(q0, dQ, i/2); - offset = q_level * DRED_LATENT_DIM; + offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM); dred_encode_latents( &ec_encoder, diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h index 8ed323d3..2b77d581 100644 --- a/silk/dred_encoder.h +++ b/silk/dred_encoder.h @@ -50,8 +50,8 @@ typedef struct { int latent_offset; float latents_buffer[DRED_MAX_FRAMES * DRED_LATENT_DIM]; int latents_buffer_fill; - float state_buffer[24]; - float initial_state[24]; + float state_buffer[DRED_STATE_DIM]; + float initial_state[DRED_STATE_DIM]; float resample_mem[RESAMPLING_ORDER + 1]; LPCNetEncState lpcnet_enc_state; RDOVAEEncState rdovae_enc; |