Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-09-16 00:27:44 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-09-21 01:04:08 +0300
commitb88644b9c7547f09c7e313c2ae5ad3085fba2d4e (patch)
treee4e47c2509b941daff61acc59415e19831a4996c
parent2ec31cc5ccb92e1b472cf3da1d26490412805a79 (diff)
Quantizing initial state with rdovae too
More efficient than PVQ
-rw-r--r--dnn/torch/rdovae/rdovae/rdovae.py45
-rw-r--r--dnn/torch/rdovae/train_rdovae.py15
-rw-r--r--silk/dred_coding.c77
-rw-r--r--silk/dred_config.h4
-rw-r--r--silk/dred_decoder.c12
-rw-r--r--silk/dred_encoder.c15
-rw-r--r--silk/dred_encoder.h4
7 files changed, 64 insertions, 108 deletions
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 1eec42c1..0dc943ec 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -372,7 +372,7 @@ class CoreDecoder(nn.Module):
class StatisticalModel(nn.Module):
- def __init__(self, quant_levels, latent_dim):
+ def __init__(self, quant_levels, latent_dim, state_dim):
""" Statistical model for latent space
Computes scaling, deadzone, r, and theta
@@ -383,8 +383,10 @@ class StatisticalModel(nn.Module):
# copy parameters
self.latent_dim = latent_dim
+ self.state_dim = state_dim
+ self.total_dim = latent_dim + state_dim
self.quant_levels = quant_levels
- self.embedding_dim = 6 * latent_dim
+ self.embedding_dim = 6 * self.total_dim
# quantization embedding
self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
@@ -400,12 +402,12 @@ class StatisticalModel(nn.Module):
x = self.quant_embedding(quant_ids)
# CAVE: theta_soft is not used anymore. Kick it out?
- quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
- dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
- theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
- r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
- theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
- r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+ quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
+ dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
+ theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
+ r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
+ theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
+ r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
return {
@@ -445,7 +447,7 @@ class RDOVAE(nn.Module):
self.state_dropout_rate = state_dropout_rate
# submodules encoder and decoder share the statistical model
- self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+ self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
@@ -522,13 +524,18 @@ class RDOVAE(nn.Module):
z, states = self.core_encoder(features)
# scaling, dead-zone and quantization
- z = z * statistical_model['quant_scale']
- z = soft_dead_zone(z, statistical_model['dead_zone'])
+ z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
+ z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
# quantization
- z_q = hard_quantize(z) / statistical_model['quant_scale']
- z_n = noise_quantize(z) / statistical_model['quant_scale']
- states_q = soft_pvq(states, self.pvq_num_pulses)
+ z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+ z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
+ #states_q = soft_pvq(states, self.pvq_num_pulses)
+ states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
+ states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
+
+ states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
+ states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
if self.state_dropout_rate > 0:
drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
@@ -551,6 +558,7 @@ class RDOVAE(nn.Module):
# decoder with soft quantized input
z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+ dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
@@ -558,6 +566,7 @@ class RDOVAE(nn.Module):
'outputs_hard_quant' : outputs_hq,
'outputs_soft_quant' : outputs_sq,
'z' : z,
+ 'states' : states,
'statistical_model' : statistical_model
}
@@ -586,11 +595,11 @@ class RDOVAE(nn.Module):
stats = self.statistical_model(q_ids)
- zq = z * stats['quant_scale']
- zq = soft_dead_zone(zq, stats['dead_zone'])
+ zq = z * stats['quant_scale'][:self.latent_dim]
+ zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
zq = torch.round(zq)
- sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+ sizes = hard_rate_estimate(zq, stats['r_hard'][:,:,:self.latent_dim], stats['theta_hard'][:,:,:self.latent_dim], reduce=False)
return zq, sizes
@@ -599,7 +608,7 @@ class RDOVAE(nn.Module):
stats = self.statistical_model(q_ids)
- z = zq / stats['quant_scale']
+ z = zq / stats['quant_scale'][:,:,:self.latent_dim]
return z
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
index f29ed98f..3f8484e1 100644
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -172,6 +172,7 @@ if __name__ == '__main__':
running_soft_rate_loss = 0
running_total_loss = 0
running_rate_metric = 0
+ running_states_rate_metric = 0
previous_total_loss = 0
running_first_frame_loss = 0
@@ -194,17 +195,21 @@ if __name__ == '__main__':
# collect outputs
z = model_output['z']
+ states = model_output['states']
outputs_hard_quant = model_output['outputs_hard_quant']
outputs_soft_quant = model_output['outputs_soft_quant']
statistical_model = model_output['statistical_model']
# rate loss
- hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
- soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
- soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
- hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+ hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
+ soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
+ states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
+ states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
+ soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
+ hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
hard_rate_metric = torch.mean(hard_rate)
+ states_rate_metric = torch.mean(states_hard_rate)
## distortion losses
@@ -242,6 +247,7 @@ if __name__ == '__main__':
running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu())
running_rate_loss += float(rate_loss.detach().cpu())
running_rate_metric += float(hard_rate_metric.detach().cpu())
+ running_states_rate_metric += float(states_rate_metric.detach().cpu())
running_total_loss += float(total_loss.detach().cpu())
running_first_frame_loss += float(first_frame_loss.detach().cpu())
running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
@@ -256,6 +262,7 @@ if __name__ == '__main__':
dist_sq=running_soft_dist_loss / (i + 1),
rate_loss=running_rate_loss / (i + 1),
rate=running_rate_metric / (i + 1),
+ states_rate=running_states_rate_metric / (i + 1),
ffloss=running_first_frame_loss / (i + 1),
rateloss_hard=running_hard_rate_loss / (i + 1),
rateloss_soft=running_soft_rate_loss / (i + 1)
diff --git a/silk/dred_coding.c b/silk/dred_coding.c
index 24e109ef..f8d2f070 100644
--- a/silk/dred_coding.c
+++ b/silk/dred_coding.c
@@ -33,16 +33,13 @@
#include <stdio.h>
#include "celt/entenc.h"
-#include "celt/vq.h"
-#include "celt/cwrs.h"
#include "celt/laplace.h"
#include "os_support.h"
#include "dred_config.h"
#include "dred_coding.h"
#define LATENT_DIM 80
-#define PVQ_DIM 24
-#define PVQ_K 82
+#define STATE_DIM 80
int compute_quantizer(int q0, int dQ, int i) {
int quant;
@@ -53,37 +50,6 @@ int compute_quantizer(int q0, int dQ, int i) {
return (int) floor(0.5f + DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2));
}
-static void encode_pvq(const int *iy, int N, int K, ec_enc *enc) {
- int fits;
- celt_assert(N==24 || N==12 || N==6);
- fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
- /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
- if (fits) {
- if (K > 0)
- encode_pulses(iy, N, K, enc);
- }
- else {
- int N2 = N/2;
- int K0=0;
- int i;
- for (i=0;i<N2;i++) K0 += abs(iy[i]);
- /* FIXME: Don't use uniform probability for K0. */
- ec_enc_uint(enc, K0, K+1);
- /*printf("K0 = %d\n", K0);*/
- encode_pvq(iy, N2, K0, enc);
- encode_pvq(&iy[N2], N2, K-K0, enc);
- }
-}
-
-void dred_encode_state(ec_enc *enc, const float *x) {
- int iy[PVQ_DIM];
- float x0[PVQ_DIM];
- /* Copy state because the PVQ search will trash it. */
- OPUS_COPY(x0, x, PVQ_DIM);
- op_pvq_search_c(x0, iy, PVQ_K, PVQ_DIM, 0);
- encode_pvq(iy, PVQ_DIM, PVQ_K, enc);
-}
-
void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0) {
int i;
float eps = .1f;
@@ -101,47 +67,6 @@ void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale,
}
}
-
-
-static void decode_pvq(int *iy, int N, int K, ec_dec *dec) {
- int fits;
- celt_assert(N==24 || N==12 || N==6);
- fits = (N==24 && K<=9) || (N==12 && K<=16) || (N==6);
- /*printf("encode(%d,%d), fits=%d\n", N, K, fits);*/
- if (fits) {
- if (K > 0)
- decode_pulses(iy, N, K, dec);
- else
- OPUS_CLEAR(iy, N);
- }
- else {
- int N2 = N/2;
- int K0;
- /* FIXME: Don't use uniform probability for K0. */
- K0 = ec_dec_uint(dec, K+1);
- /*printf("K0 = %d\n", K0);*/
- decode_pvq(iy, N2, K0, dec);
- decode_pvq(&iy[N2], N2, K-K0, dec);
- }
-}
-
-void dred_decode_state(ec_enc *dec, float *x) {
- int k;
- int iy[PVQ_DIM];
- float norm = 0;
- decode_pvq(iy, PVQ_DIM, PVQ_K, dec);
- /*printf("tell: %d\n", ec_tell(dec)-tell1);*/
- for (k = 0; k < PVQ_DIM; k++)
- {
- norm += (float) iy[k] * iy[k];
- }
- norm = 1.f / sqrt(norm);
- for (k = 0; k < PVQ_DIM; k++)
- {
- x[k] = iy[k] * norm;
- }
-}
-
void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0) {
int i;
for (i=0;i<LATENT_DIM;i++) {
diff --git a/silk/dred_config.h b/silk/dred_config.h
index d8342f8e..891eb1cf 100644
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126
/* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 1
+#define DRED_EXPERIMENTAL_VERSION 2
#define DRED_EXPERIMENTAL_BYTES 2
@@ -41,7 +41,7 @@
/* these are inpart duplicates to the values defined in dred_rdovae_constants.h */
#define DRED_NUM_FEATURES 20
#define DRED_LATENT_DIM 80
-#define DRED_STATE_DIM 24
+#define DRED_STATE_DIM 80
#define DRED_SILK_ENCODER_DELAY (79+12-80)
#define DRED_FRAME_SIZE 160
#define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE))
diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c
index 04ba1ef3..500f33b8 100644
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -54,6 +54,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
int offset;
int q0;
int dQ;
+ int state_qoffset;
/* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */
@@ -66,7 +67,14 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
dQ = ec_dec_uint(&ec, 8);
/*printf("%d %d %d\n", dred_offset, q0, dQ);*/
- dred_decode_state(&ec, dec->state);
+ //dred_decode_state(&ec, dec->state);
+ state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+ dred_decode_latents(
+ &ec,
+ dec->state,
+ quant_scales + state_qoffset,
+ r + state_qoffset,
+ p0 + state_qoffset);
/* decode newest to oldest and store oldest to newest */
for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2)
@@ -75,7 +83,7 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
if (8*num_bytes - ec_tell(&ec) <= 7)
break;
q_level = compute_quantizer(q0, dQ, i/2);
- offset = q_level * DRED_LATENT_DIM;
+ offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
dred_decode_latents(
&ec,
&dec->latents[(i/2)*DRED_LATENT_DIM],
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index 5bae39e9..7b34cefe 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -197,7 +197,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
/* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. */
if (enc->dred_offset < 6) {
enc->dred_offset += 8;
- OPUS_COPY(enc->initial_state, enc->state_buffer, 24);
+ OPUS_COPY(enc->initial_state, enc->state_buffer, DRED_STATE_DIM);
} else {
enc->latent_offset++;
}
@@ -221,6 +221,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
int ec_buffer_fill;
int q0;
int dQ;
+ int state_qoffset;
/* entropy coding of state and latents */
ec_enc_init(&ec_encoder, buf, max_bytes);
@@ -229,15 +230,21 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
ec_enc_uint(&ec_encoder, q0, 16);
ec_enc_uint(&ec_encoder, dQ, 8);
- dred_encode_state(&ec_encoder, enc->initial_state);
-
+ state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_STATE_DIM;
+ dred_encode_latents(
+ &ec_encoder,
+ enc->initial_state,
+ quant_scales + state_qoffset,
+ dead_zone + state_qoffset,
+ r + state_qoffset,
+ p0 + state_qoffset);
for (i = 0; i < IMIN(2*max_chunks, enc->latents_buffer_fill-enc->latent_offset-1); i += 2)
{
ec_enc ec_bak;
ec_bak = ec_encoder;
q_level = compute_quantizer(q0, dQ, i/2);
- offset = q_level * DRED_LATENT_DIM;
+ offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
dred_encode_latents(
&ec_encoder,
diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h
index 8ed323d3..2b77d581 100644
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -50,8 +50,8 @@ typedef struct {
int latent_offset;
float latents_buffer[DRED_MAX_FRAMES * DRED_LATENT_DIM];
int latents_buffer_fill;
- float state_buffer[24];
- float initial_state[24];
+ float state_buffer[DRED_STATE_DIM];
+ float initial_state[DRED_STATE_DIM];
float resample_mem[RESAMPLING_ORDER + 1];
LPCNetEncState lpcnet_enc_state;
RDOVAEEncState rdovae_enc;