Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-11-06 11:10:59 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-11-06 11:16:43 +0300
commit544b3e576c8edd1785914c988882b62d60652f26 (patch)
treed1dd9e1a90ce85f3fbb462228c7adea29fe10d70
parent98b8be09d56d03d220fff3536842c0703bae865c (diff)
DRED: quantize r and p0 parameters with 8 bits
Only code non-degenerate symbols, which makes the encoder faster
-rwxr-xr-xautogen.sh2
-rw-r--r--dnn/dred_rdovae.c8
-rw-r--r--dnn/dred_rdovae.h4
-rw-r--r--dnn/torch/rdovae/export_rdovae_weights.py12
-rw-r--r--dnn/torch/rdovae/rdovae/rdovae.py53
-rw-r--r--silk/dred_config.h2
-rw-r--r--silk/dred_decoder.c9
-rw-r--r--silk/dred_encoder.c10
8 files changed, 52 insertions, 48 deletions
diff --git a/autogen.sh b/autogen.sh
index cc70a11a..efd3ef3d 100755
--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh c99054d
+dnn/download_model.sh 98b8be0
echo "Updating build configuration files, please wait...."
diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c
index a17b957b..b4797b5e 100644
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -79,9 +79,9 @@ void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float
}
-const opus_uint16 * DRED_rdovae_get_p0_pointer(void)
+const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
{
- return &dred_p0_q15[0];
+ return &dred_p0_q8[0];
}
const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
@@ -89,9 +89,9 @@ const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
return &dred_dead_zone_q10[0];
}
-const opus_uint16 * DRED_rdovae_get_r_pointer(void)
+const opus_uint8 * DRED_rdovae_get_r_pointer(void)
{
- return &dred_r_q15[0];
+ return &dred_r_q8[0];
}
const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
diff --git a/dnn/dred_rdovae.h b/dnn/dred_rdovae.h
index f2c3235e..05da69ba 100644
--- a/dnn/dred_rdovae.h
+++ b/dnn/dred_rdovae.h
@@ -58,9 +58,9 @@ void DRED_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, cons
void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z);
-const opus_uint16 * DRED_rdovae_get_p0_pointer(void);
+const opus_uint8 * DRED_rdovae_get_p0_pointer(void);
const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void);
-const opus_uint16 * DRED_rdovae_get_r_pointer(void);
+const opus_uint8 * DRED_rdovae_get_r_pointer(void);
const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void);
#endif
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index 9a35c17a..3bcc7712 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -63,20 +63,20 @@ def dump_statistical_model(writer, qembedding):
quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16)
- r_q15 = np.round(r * 2**15).astype(np.uint16)
- p0_q15 = np.round(p0 * 2**15).astype(np.uint16)
+ r_q15 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
+ p0_q15 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
- print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
- print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
+ print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
+ print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
writer.header.write(
f"""
extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint16 dred_r_q15[{levels * N}];
-extern const opus_uint16 dred_p0_q15[{levels * N}];
+extern const opus_uint8 dred_r_q8[{levels * N}];
+extern const opus_uint8 dred_p0_q8[{levels * N}];
"""
)
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 3552cf90..09b6801a 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -222,6 +222,9 @@ def weight_clip_factory(max_value):
return clip_weights
+def n(x):
+ return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
+
# RDOVAE module and submodules
class MyConv(nn.Module):
@@ -295,17 +298,17 @@ class CoreEncoder(nn.Module):
device = x.device
# run encoding layer stack
- x = torch.tanh(self.dense_1(x))
- x = torch.cat([x, self.gru1(x)[0]], -1)
- x = torch.cat([x, self.conv1(x)], -1)
- x = torch.cat([x, self.gru2(x)[0]], -1)
- x = torch.cat([x, self.conv2(x)], -1)
- x = torch.cat([x, self.gru3(x)[0]], -1)
- x = torch.cat([x, self.conv3(x)], -1)
- x = torch.cat([x, self.gru4(x)[0]], -1)
- x = torch.cat([x, self.conv4(x)], -1)
- x = torch.cat([x, self.gru5(x)[0]], -1)
- x = torch.cat([x, self.conv5(x)], -1)
+ x = n(torch.tanh(self.dense_1(x)))
+ x = torch.cat([x, n(self.gru1(x)[0])], -1)
+ x = torch.cat([x, n(self.conv1(x))], -1)
+ x = torch.cat([x, n(self.gru2(x)[0])], -1)
+ x = torch.cat([x, n(self.conv2(x))], -1)
+ x = torch.cat([x, n(self.gru3(x)[0])], -1)
+ x = torch.cat([x, n(self.conv3(x))], -1)
+ x = torch.cat([x, n(self.gru4(x)[0])], -1)
+ x = torch.cat([x, n(self.conv4(x))], -1)
+ x = torch.cat([x, n(self.gru5(x)[0])], -1)
+ x = torch.cat([x, n(self.conv5(x))], -1)
z = self.z_dense(x)
# init state for decoder
@@ -372,18 +375,18 @@ class CoreDecoder(nn.Module):
h5_state = gru_state[:,:,384:].contiguous()
# run decoding layer stack
- x = torch.tanh(self.dense_1(z))
-
- x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
- x = torch.cat([x, self.conv1(x)], -1)
- x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
- x = torch.cat([x, self.conv2(x)], -1)
- x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
- x = torch.cat([x, self.conv3(x)], -1)
- x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
- x = torch.cat([x, self.conv4(x)], -1)
- x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
- x = torch.cat([x, self.conv5(x)], -1)
+ x = n(torch.tanh(self.dense_1(z)))
+
+ x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
+ x = torch.cat([x, n(self.conv1(x))], -1)
+ x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
+ x = torch.cat([x, n(self.conv2(x))], -1)
+ x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
+ x = torch.cat([x, n(self.conv3(x))], -1)
+ x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
+ x = torch.cat([x, n(self.conv4(x))], -1)
+ x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
+ x = torch.cat([x, n(self.conv5(x))], -1)
# output layer and reshaping
x10 = self.output(x)
@@ -451,7 +454,7 @@ class RDOVAE(nn.Module):
cond_size2,
state_dim=24,
split_mode='split',
- clip_weights=True,
+ clip_weights=False,
pvq_num_pulses=82,
state_dropout_rate=0):
@@ -487,7 +490,7 @@ class RDOVAE(nn.Module):
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)
- def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24):
+ def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
enc_stride = self.enc_stride
dec_stride = self.dec_stride
diff --git a/silk/dred_config.h b/silk/dred_config.h
index 06ff397b..5e3e74a3 100644
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126
/* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 6
+#define DRED_EXPERIMENTAL_VERSION 7
#define DRED_EXPERIMENTAL_BYTES 2
diff --git a/silk/dred_decoder.c b/silk/dred_decoder.c
index 658d340f..68ba8559 100644
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -43,20 +43,21 @@ static int sign_extend(int x, int b) {
return (x ^ m) - m;
}
-static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
+static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
int i;
for (i=0;i<dim;i++) {
int q;
- q = ec_laplace_decode_p0(dec, p0[i], r[i]);
+ if (r[i] == 0 || p0[i] == 255) q = 0;
+ else q = ec_laplace_decode_p0(dec, p0[i]<<7, r[i]<<7);
x[i] = q*256.f/(scale[i] == 0 ? 1 : scale[i]);
}
}
int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
{
- const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
+ const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
- const opus_uint16 *r = DRED_rdovae_get_r_pointer();
+ const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_dec ec;
int q_level;
int i;
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index 6b585416..3f842af0 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -217,7 +217,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
}
}
-static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint16 *r, const opus_uint16 *p0, int dim) {
+static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
int i;
int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
@@ -238,16 +238,16 @@ static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *
}
for (i=0;i<dim;i++) {
/* Make the impossible actually impossible. */
- if (r[i] == 0 || p0[i] >= 32767) q[i] = 0;
- ec_laplace_encode_p0(enc, q[i], p0[i], r[i]);
+ if (r[i] == 0 || p0[i] == 255) q[i] = 0;
+ else ec_laplace_encode_p0(enc, q[i], p0[i]<<7, r[i]<<7);
}
}
int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
const opus_uint16 *dead_zone = DRED_rdovae_get_dead_zone_pointer();
- const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
+ const opus_uint8 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
- const opus_uint16 *r = DRED_rdovae_get_r_pointer();
+ const opus_uint8 *r = DRED_rdovae_get_r_pointer();
ec_enc ec_encoder;
int q_level;