Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Jean-Marc Valin <jmvalin@amazon.com> 2023-11-07 01:49:18 +0300
committer Jean-Marc Valin <jmvalin@amazon.com> 2023-11-07 07:51:25 +0300
commit 8d851c25135d777c344307e3b8ffd6a4551e2f0c (patch)
tree ae82cfc9368bb6fd5f595b6fe6c532977a88da5d
parent 2386a60ec644fadc437155cd6e5f6d4c561940d4 (diff)
Split RDOVAE stats in two (latents + states)
-rw-r--r-- dnn/dred_rdovae.c 21
-rw-r--r-- dnn/torch/rdovae/export_rdovae_weights.py 48
-rw-r--r-- silk/dred_decoder.c 20
-rw-r--r-- silk/dred_encoder.c 25
4 files changed, 48 insertions, 66 deletions
diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c
index b4797b5e..748a463a 100644
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -77,24 +77,3 @@ void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float
{
dred_rdovae_decode_qframe(h, model, qframe, z);
}
-
-
-const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
-{
- return &dred_p0_q8[0];
-}
-
-const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
-{
- return &dred_dead_zone_q10[0];
-}
-
-const opus_uint8 * DRED_rdovae_get_r_pointer(void)
-{
- return &dred_r_q8[0];
-}
-
-const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
-{
- return &dred_quant_scales_q8[0];
-}
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index 3bcc7712..8571f8f9 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -49,16 +49,15 @@ from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
-def dump_statistical_model(writer, qembedding):
- w = qembedding.weight.detach()
- levels, dim = w.shape
- N = dim // 6
+def dump_statistical_model(writer, w, name):
+ levels = w.shape[0]
+ N = w.shape[-1]
print("printing statistical model")
- quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy()
- dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
- r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
- p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+ quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy()
+ dead_zone = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
+ r = torch.sigmoid(w[:, 5 , :]).numpy()
+ p0 = torch.sigmoid(w[:, 4 , :]).numpy()
p0 = 1 - r ** (0.5 + 0.5 * p0)
quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
@@ -66,17 +65,17 @@ def dump_statistical_model(writer, qembedding):
r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
- print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
- print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
- print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
- print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
+ print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
+ print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
+ print_vector(writer.source, r_q15, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
+ print_vector(writer.source, p0_q15, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
writer.header.write(
f"""
-extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
-extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint8 dred_r_q8[{levels * N}];
-extern const opus_uint8 dred_p0_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
+extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
+extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];
"""
)
@@ -113,6 +112,19 @@ f"""
"""
)
+ latent_out = model.get_submodule('core_encoder.module.z_dense').weight
+ states_out = model.get_submodule('core_encoder.module.state_dense_2').weight
+ nb_latents = latent_out.shape[0]
+ nb_states = states_out.shape[0]
+ # statistical model
+ qembedding = model.statistical_model.quant_embedding.weight.detach()
+ levels = qembedding.shape[0]
+ qembedding = torch.reshape(qembedding, (levels, 6, -1))
+
+ dump_statistical_model(stats_writer, qembedding[:, :, :nb_latents], 'latents')
+ dump_statistical_model(stats_writer, qembedding[:, :, nb_latents:], 'states')
+
+
# encoder
encoder_dense_layers = [
('core_encoder.module.dense_1' , 'enc_dense1', 'TANH', False,),
@@ -187,10 +199,6 @@ f"""
del dec_writer
- # statistical model
- qembedding = model.statistical_model.quant_embedding
- dump_statistical_model(stats_writer, qembedding)
-
del stats_writer
# constants
diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c
index 68ba8559..2bedde2f 100644
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -36,6 +36,7 @@
#include "dred_coding.h"
#include "celt/entdec.h"
#include "celt/laplace.h"
+#include "dred_rdovae_stats_data.h"
/* From http://graphics.stanford.edu/~seander/bithacks.html#FixedSignExtend */
static int sign_extend(int x, int b) {
@@ -55,9 +56,6 @@ static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale,
int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
{
- const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
- const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
- const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_dec ec;
int q_level;
int i;
@@ -78,13 +76,13 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
/*printf("%d %d %d\n", dred_offset, q0, dQ);*/
//dred_decode_state(&ec, dec->state);
- state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM;
+ state_qoffset = q0*DRED_STATE_DIM;
dred_decode_latents(
&ec,
dec->state,
- quant_scales + state_qoffset,
- r + state_qoffset,
- p0 + state_qoffset,
+ dred_states_quant_scales_q8 + state_qoffset,
+ dred_states_r_q8 + state_qoffset,
+ dred_states_p0_q8 + state_qoffset,
DRED_STATE_DIM);
/* decode newest to oldest and store oldest to newest */
@@ -94,13 +92,13 @@ int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int mi
if (8*num_bytes - ec_tell(&ec) <= 7)
break;
q_level = compute_quantizer(q0, dQ, i/2);
- offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
+ offset = q_level*DRED_LATENT_DIM;
dred_decode_latents(
&ec,
&dec->latents[(i/2)*DRED_LATENT_DIM],
- quant_scales + offset,
- r + offset,
- p0 + offset,
+ dred_latents_quant_scales_q8 + offset,
+ dred_latents_r_q8 + offset,
+ dred_latents_p0_q8 + offset,
DRED_LATENT_DIM
);
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index 3f842af0..6cda6334 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -44,6 +44,7 @@
#include "float_cast.h"
#include "os_support.h"
#include "celt/laplace.h"
+#include "dred_rdovae_stats_data.h"
int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
@@ -244,10 +245,6 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *
}
int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
- const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer();
- const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
- const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
- const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_enc ec_encoder;
int q_level;
@@ -265,14 +262,14 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
ec_enc_uint(&ec_encoder, q0, 16);
ec_enc_uint(&ec_encoder, dQ, 8);
- state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM;
+ state_qoffset = q0*DRED_STATE_DIM;
dred_encode_latents(
&ec_encoder,
enc->initial_state,
- quant_scales + state_qoffset,
- dead_zone + state_qoffset,
- r + state_qoffset,
- p0 + state_qoffset,
+ dred_states_quant_scales_q8 + state_qoffset,
+ dred_states_dead_zone_q10 + state_qoffset,
+ dred_states_r_q8 + state_qoffset,
+ dred_states_p0_q8 + state_qoffset,
DRED_STATE_DIM);
if (ec_tell(&ec_encoder) > 8*max_bytes) {
return 0;
@@ -283,15 +280,15 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
ec_bak = ec_encoder;
q_level = compute_quantizer(q0, dQ, i/2);
- offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
+ offset = q_level * DRED_LATENT_DIM;
dred_encode_latents(
&ec_encoder,
enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM,
- quant_scales + offset,
- dead_zone + offset,
- r + offset,
- p0 + offset,
+ dred_latents_quant_scales_q8 + offset,
+ dred_latents_dead_zone_q10 + offset,
+ dred_latents_r_q8 + offset,
+ dred_latents_p0_q8 + offset,
DRED_LATENT_DIM
);
if (ec_tell(&ec_encoder) > 8*max_bytes) {