Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-09-21 19:20:11 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-09-30 10:57:37 +0300
commite6149735687558945c295728fc0fa3a44c108e5c (patch)
treea3043e8a23ab080fb3423b6cb1c408e7558f6f89
parentf3b86f941408b37b0c0236eb5b8b09605b8a713b (diff)
Using a DenseNet for DRED
-rw-r--r--dnn/README.md126
-rw-r--r--dnn/dred_rdovae_dec.c100
-rw-r--r--dnn/dred_rdovae_dec.h14
-rw-r--r--dnn/dred_rdovae_enc.c90
-rw-r--r--dnn/dred_rdovae_enc.h15
-rw-r--r--dnn/nnet.c19
-rw-r--r--dnn/nnet.h1
-rw-r--r--dnn/torch/rdovae/export_rdovae_weights.py54
-rw-r--r--dnn/torch/rdovae/rdovae/rdovae.py139
9 files changed, 266 insertions, 292 deletions
diff --git a/dnn/README.md b/dnn/README.md
deleted file mode 100644
index ad4a6724..00000000
--- a/dnn/README.md
+++ /dev/null
@@ -1,126 +0,0 @@
-# LPCNet
-
-Low complexity implementation of the WaveRNN-based LPCNet algorithm, as described in:
-
-- J.-M. Valin, J. Skoglund, [LPCNet: Improving Neural Speech Synthesis Through Linear Prediction](https://jmvalin.ca/papers/lpcnet_icassp2019.pdf), *Proc. International Conference on Acoustics, Speech and Signal Processing (ICASSP)*, arXiv:1810.11846, 2019.
-- J.-M. Valin, U. Isik, P. Smaragdis, A. Krishnaswamy, [Neural Speech Synthesis on a Shoestring: Improving the Efficiency of LPCNet](https://jmvalin.ca/papers/improved_lpcnet.pdf), *Proc. ICASSP*, arxiv:2106.04129, 2022.
-- K. Subramani, J.-M. Valin, U. Isik, P. Smaragdis, A. Krishnaswamy, [End-to-end LPCNet: A Neural Vocoder With Fully-Differentiable LPC Estimation](https://jmvalin.ca/papers/lpcnet_end2end.pdf), *Proc. INTERSPEECH*, arxiv:2106.04129, 2022.
-
-For coding/PLC applications of LPCNet, see:
-
-- J.-M. Valin, J. Skoglund, [A Real-Time Wideband Neural Vocoder at 1.6 kb/s Using LPCNet](https://jmvalin.ca/papers/lpcnet_codec.pdf), *Proc. INTERSPEECH*, arxiv:1903.12087, 2019.
-- J. Skoglund, J.-M. Valin, [Improving Opus Low Bit Rate Quality with Neural Speech Synthesis](https://jmvalin.ca/papers/opusnet.pdf), *Proc. INTERSPEECH*, arxiv:1905.04628, 2020.
-- J.-M. Valin, A. Mustafa, C. Montgomery, T.B. Terriberry, M. Klingbeil, P. Smaragdis, A. Krishnaswamy, [Real-Time Packet Loss Concealment With Mixed Generative and Predictive Model](https://jmvalin.ca/papers/lpcnet_plc.pdf), *Proc. INTERSPEECH*, arxiv:2205.05785, 2022.
-- J.-M. Valin, J. Büthe, A. Mustafa, [Low-Bitrate Redundancy Coding of Speech Using a Rate-Distortion-Optimized Variational Autoencoder](https://jmvalin.ca/papers/valin_dred.pdf), *Proc. ICASSP*, arXiv:2212.04453, 2023. ([blog post](https://www.amazon.science/blog/neural-encoding-enables-more-efficient-recovery-of-lost-audio-packets))
-
-# Introduction
-
-Work in progress software for researching low CPU complexity algorithms for speech synthesis and compression by applying Linear Prediction techniques to WaveRNN. High quality speech can be synthesised on regular CPUs (around 3 GFLOP) with SIMD support (SSE2, SSSE3, AVX, AVX2/FMA, NEON currently supported). The code also supports very low bitrate compression at 1.6 kb/s.
-
-The BSD licensed software is written in C and Python/Keras. For training, a GTX 1080 Ti or better is recommended.
-
-This software is an open source starting point for LPCNet/WaveRNN-based speech synthesis and coding.
-
-# Using the existing software
-
-You can build the code using:
-
-```
-./autogen.sh
-./configure
-make
-```
-Note that the autogen.sh script is used when building from Git and will automatically download the latest model
-(models are too large to put in Git). By default, LPCNet will attempt to use 8-bit dot product instructions on AVX\*/Neon to
-speed up inference. To disable that (e.g. to avoid quantization effects when retraining), add --disable-dot-product to the
-configure script. LPCNet does not yet have a complete implementation for some of the integer operations on the ARMv7
-architecture so for now you will also need --disable-dot-product to successfully compile on 32-bit ARM.
-
-It is highly recommended to set the CFLAGS environment variable to enable AVX or NEON *prior* to running configure, otherwise
-no vectorization will take place and the code will be very slow. On a recent x86 CPU, something like
-```
-export CFLAGS='-Ofast -g -march=native'
-```
-should work. On ARM, you can enable Neon with:
-```
-export CFLAGS='-Ofast -g -mfpu=neon'
-```
-While not strictly required, the -Ofast flag will help with auto-vectorization, especially for dot products that
-cannot be optimized without -ffast-math (which -Ofast enables). Additionally, -falign-loops=32 has been shown to
-help on x86.
-
-You can test the capabilities of LPCNet using the lpcnet\_demo application. To encode a file:
-```
-./lpcnet_demo -encode input.pcm compressed.bin
-```
-where input.pcm is a 16-bit (machine endian) PCM file sampled at 16 kHz. The raw compressed data (no header)
-is written to compressed.bin and consists of 8 bytes per 40-ms packet.
-
-To decode:
-```
-./lpcnet_demo -decode compressed.bin output.pcm
-```
-where output.pcm is also 16-bit, 16 kHz PCM.
-
-Alternatively, you can run the uncompressed analysis/synthesis using -features
-instead of -encode and -synthesis instead of -decode.
-The same functionality is available in the form of a library. See include/lpcnet.h for the API.
-
-To try packet loss concealment (PLC), you first need a PLC model, which you can get with:
-```
-./download_model.sh plc-3b1eab4
-```
-or (for the PLC challenge submission):
-```
-./download_model.sh plc_challenge
-```
-PLC can be tested with:
-```
-./lpcnet_demo -plc_file noncausal_dc error_pattern.txt input.pcm output.pcm
-```
-where error_pattern.txt is a text file with one entry per 20-ms packet, with 1 meaning "packet lost" and 0 meaning "packet not lost".
-noncausal_dc is the non-causal (5-ms look-ahead) with special handling for DC offsets. It's also possible to use "noncausal", "causal",
-or "causal_dc".
-
-# Training a new model
-
-This codebase is also meant for research and it is possible to train new models. These are the steps to do that:
-
-1. Set up a Keras system with GPU.
-
-1. Generate training data:
- ```
- ./dump_data -train input.s16 features.f32 data.s16
- ```
- where the first file contains 16 kHz 16-bit raw PCM audio (no header) and the other files are output files. This program makes several passes over the data with different filters to generate a large amount of training data.
-
-1. Now that you have your files, train with:
- ```
- python3 training_tf2/train_lpcnet.py features.f32 data.s16 model_name
- ```
- and it will generate an h5 file for each iteration, with model\_name as prefix. If it stops with a
- "Failed to allocate RNN reserve space" message try specifying a smaller --batch-size for train\_lpcnet.py.
-
-1. You can synthesise speech with Python and your GPU card (very slow):
- ```
- ./dump_data -test test_input.s16 test_features.f32
- ./training_tf2/test_lpcnet.py lpcnet_model_name.h5 test_features.f32 test.s16
- ```
-
-1. Or with C on a CPU (C inference is much faster):
- First extract the model files nnet\_data.h and nnet\_data.c
- ```
- ./training_tf2/dump_lpcnet.py lpcnet_model_name.h5
- ```
- and move the generated nnet\_data.\* files to the src/ directory.
- Then you just need to rebuild the software and use lpcnet\_demo as explained above.
-
-# Speech Material for Training
-
-Suitable training material can be obtained from [Open Speech and Language Resources](https://www.openslr.org/). See the datasets.txt file for details on suitable training data.
-
-# Reading Further
-
-1. [LPCNet: DSP-Boosted Neural Speech Synthesis](https://people.xiph.org/~jm/demo/lpcnet/)
-1. [A Real-Time Wideband Neural Vocoder at 1.6 kb/s Using LPCNet](https://people.xiph.org/~jm/demo/lpcnet_codec/)
-1. Sample model files (check compatibility): https://media.xiph.org/lpcnet/data/
diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c
index c79723b3..6cfd6577 100644
--- a/dnn/dred_rdovae_dec.c
+++ b/dnn/dred_rdovae_dec.c
@@ -33,16 +33,36 @@
#include "dred_rdovae_constants.h"
#include "os_support.h"
+static void conv1_cond_init(float *mem, const float *input, int len, int dilation, int *init)
+{
+ if (!*init) {
+ int i;
+ for (i=0;i<dilation;i++) OPUS_COPY(&mem[i*len], input, len);
+ }
+ *init = 1;
+}
+
void dred_rdovae_dec_init_states(
RDOVAEDecState *h, /* io: state buffer handle */
const RDOVAEDec *model,
const float *initial_state /* i: initial state */
)
{
- /* initialize GRU states from initial state */
- compute_generic_dense(&model->state1, h->dense2_state, initial_state, ACTIVATION_TANH);
- compute_generic_dense(&model->state2, h->dense4_state, initial_state, ACTIVATION_TANH);
- compute_generic_dense(&model->state3, h->dense6_state, initial_state, ACTIVATION_TANH);
+ float hidden[DEC_HIDDEN_INIT_OUT_SIZE];
+ float state_init[DEC_GRU1_STATE_SIZE+DEC_GRU2_STATE_SIZE+DEC_GRU3_STATE_SIZE+DEC_GRU4_STATE_SIZE+DEC_GRU5_STATE_SIZE];
+ int counter=0;
+ compute_generic_dense(&model->dec_hidden_init, hidden, initial_state, ACTIVATION_TANH);
+ compute_generic_dense(&model->dec_gru_init, state_init, hidden, ACTIVATION_TANH);
+ OPUS_COPY(h->gru1_state, state_init, DEC_GRU1_STATE_SIZE);
+ counter += DEC_GRU1_STATE_SIZE;
+ OPUS_COPY(h->gru2_state, &state_init[counter], DEC_GRU2_STATE_SIZE);
+ counter += DEC_GRU2_STATE_SIZE;
+ OPUS_COPY(h->gru3_state, &state_init[counter], DEC_GRU3_STATE_SIZE);
+ counter += DEC_GRU3_STATE_SIZE;
+ OPUS_COPY(h->gru4_state, &state_init[counter], DEC_GRU4_STATE_SIZE);
+ counter += DEC_GRU4_STATE_SIZE;
+ OPUS_COPY(h->gru5_state, &state_init[counter], DEC_GRU5_STATE_SIZE);
+ h->initialized = 0;
}
@@ -53,44 +73,48 @@ void dred_rdovae_decode_qframe(
const float *input /* i: latent vector */
)
{
- float buffer[DEC_DENSE1_OUT_SIZE + DEC_DENSE2_OUT_SIZE + DEC_DENSE3_OUT_SIZE + DEC_DENSE4_OUT_SIZE + DEC_DENSE5_OUT_SIZE + DEC_DENSE6_OUT_SIZE + DEC_DENSE7_OUT_SIZE + DEC_DENSE8_OUT_SIZE];
+ float buffer[DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
+ + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE];
int output_index = 0;
- int input_index = 0;
/* run encoder stack and concatenate output in buffer*/
compute_generic_dense(&model->dec_dense1, &buffer[output_index], input, ACTIVATION_TANH);
- input_index = output_index;
output_index += DEC_DENSE1_OUT_SIZE;
- compute_generic_gru(&model->dec_dense2_input, &model->dec_dense2_recurrent, dec_state->dense2_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], dec_state->dense2_state, DEC_DENSE2_OUT_SIZE);
- input_index = output_index;
- output_index += DEC_DENSE2_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += DEC_DENSE3_OUT_SIZE;
-
- compute_generic_gru(&model->dec_dense4_input, &model->dec_dense4_recurrent, dec_state->dense4_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], dec_state->dense4_state, DEC_DENSE4_OUT_SIZE);
- input_index = output_index;
- output_index += DEC_DENSE4_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += DEC_DENSE5_OUT_SIZE;
-
- compute_generic_gru(&model->dec_dense6_input, &model->dec_dense6_recurrent, dec_state->dense6_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], dec_state->dense6_state, DEC_DENSE6_OUT_SIZE);
- input_index = output_index;
- output_index += DEC_DENSE6_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += DEC_DENSE7_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- output_index += DEC_DENSE8_OUT_SIZE;
-
- compute_generic_dense(&model->dec_final, qframe, buffer, ACTIVATION_LINEAR);
+ compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
+ output_index += DEC_GRU1_OUT_SIZE;
+ conv1_cond_init(dec_state->conv1_state, buffer, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV1_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
+ output_index += DEC_GRU2_OUT_SIZE;
+ conv1_cond_init(dec_state->conv2_state, buffer, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV2_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
+ output_index += DEC_GRU3_OUT_SIZE;
+ conv1_cond_init(dec_state->conv3_state, buffer, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV3_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
+ output_index += DEC_GRU4_OUT_SIZE;
+ conv1_cond_init(dec_state->conv4_state, buffer, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV4_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
+ output_index += DEC_GRU5_OUT_SIZE;
+ conv1_cond_init(dec_state->conv5_state, buffer, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV5_OUT_SIZE;
+
+ compute_generic_dense(&model->dec_output, qframe, buffer, ACTIVATION_LINEAR);
}
diff --git a/dnn/dred_rdovae_dec.h b/dnn/dred_rdovae_dec.h
index 008551b5..4e039cf2 100644
--- a/dnn/dred_rdovae_dec.h
+++ b/dnn/dred_rdovae_dec.h
@@ -33,9 +33,17 @@
#include "dred_rdovae_stats_data.h"
struct RDOVAEDecStruct {
- float dense2_state[DEC_DENSE2_STATE_SIZE];
- float dense4_state[DEC_DENSE2_STATE_SIZE];
- float dense6_state[DEC_DENSE2_STATE_SIZE];
+ int initialized;
+ float gru1_state[DEC_GRU1_STATE_SIZE];
+ float gru2_state[DEC_GRU2_STATE_SIZE];
+ float gru3_state[DEC_GRU3_STATE_SIZE];
+ float gru4_state[DEC_GRU4_STATE_SIZE];
+ float gru5_state[DEC_GRU5_STATE_SIZE];
+ float conv1_state[DEC_CONV1_STATE_SIZE];
+ float conv2_state[DEC_CONV2_STATE_SIZE];
+ float conv3_state[DEC_CONV3_STATE_SIZE];
+ float conv4_state[DEC_CONV4_STATE_SIZE];
+ float conv5_state[DEC_CONV5_STATE_SIZE];
};
void dred_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state);
diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c
index 9361af17..b34ac40e 100644
--- a/dnn/dred_rdovae_enc.c
+++ b/dnn/dred_rdovae_enc.c
@@ -35,6 +35,15 @@
#include "dred_rdovae_enc.h"
#include "os_support.h"
+static void conv1_cond_init(float *mem, const float *input, int len, int dilation, int *init)
+{
+ if (!*init) {
+ int i;
+ for (i=0;i<dilation;i++) OPUS_COPY(&mem[i*len], input, len);
+ }
+ *init = 1;
+}
+
void dred_rdovae_encode_dframe(
RDOVAEEncState *enc_state, /* io: encoder state */
const RDOVAEEnc *model,
@@ -43,52 +52,53 @@ void dred_rdovae_encode_dframe(
const float *input /* i: double feature frame (concatenated) */
)
{
- float buffer[ENC_DENSE1_OUT_SIZE + ENC_DENSE2_OUT_SIZE + ENC_DENSE3_OUT_SIZE + ENC_DENSE4_OUT_SIZE + ENC_DENSE5_OUT_SIZE + ENC_DENSE6_OUT_SIZE + ENC_DENSE7_OUT_SIZE + ENC_DENSE8_OUT_SIZE + GDENSE1_OUT_SIZE];
+ float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE
+ + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE];
+ float state_hidden[GDENSE1_OUT_SIZE];
int output_index = 0;
- int input_index = 0;
/* run encoder stack and concatenate output in buffer*/
compute_generic_dense(&model->enc_dense1, &buffer[output_index], input, ACTIVATION_TANH);
- input_index = output_index;
output_index += ENC_DENSE1_OUT_SIZE;
- compute_generic_gru(&model->enc_dense2_input, &model->enc_dense2_recurrent, enc_state->dense2_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], enc_state->dense2_state, ENC_DENSE2_OUT_SIZE);
- input_index = output_index;
- output_index += ENC_DENSE2_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += ENC_DENSE3_OUT_SIZE;
-
- compute_generic_gru(&model->enc_dense4_input, &model->enc_dense4_recurrent, enc_state->dense4_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], enc_state->dense4_state, ENC_DENSE4_OUT_SIZE);
- input_index = output_index;
- output_index += ENC_DENSE4_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += ENC_DENSE5_OUT_SIZE;
-
- compute_generic_gru(&model->enc_dense6_input, &model->enc_dense6_recurrent, enc_state->dense6_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], enc_state->dense6_state, ENC_DENSE6_OUT_SIZE);
- input_index = output_index;
- output_index += ENC_DENSE6_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += ENC_DENSE7_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- output_index += ENC_DENSE8_OUT_SIZE;
-
- /* compute latents from concatenated input buffer */
- compute_generic_conv1d(&model->bits_dense, latents, enc_state->bits_dense_state, buffer, BITS_DENSE_IN_SIZE, ACTIVATION_LINEAR);
-
+ compute_generic_gru(&model->enc_gru1_input, &model->enc_gru1_recurrent, enc_state->gru1_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru1_state, ENC_GRU1_OUT_SIZE);
+ output_index += ENC_GRU1_OUT_SIZE;
+ conv1_cond_init(enc_state->conv1_state, buffer, output_index, 1, &enc_state->initialized);
+ compute_generic_conv1d(&model->enc_conv1, &buffer[output_index], enc_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += ENC_CONV1_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru2_input, &model->enc_gru2_recurrent, enc_state->gru2_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru2_state, ENC_GRU2_OUT_SIZE);
+ output_index += ENC_GRU2_OUT_SIZE;
+ conv1_cond_init(enc_state->conv2_state, buffer, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv2, &buffer[output_index], enc_state->conv2_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV2_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru3_input, &model->enc_gru3_recurrent, enc_state->gru3_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru3_state, ENC_GRU3_OUT_SIZE);
+ output_index += ENC_GRU3_OUT_SIZE;
+ conv1_cond_init(enc_state->conv3_state, buffer, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv3, &buffer[output_index], enc_state->conv3_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV3_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru4_input, &model->enc_gru4_recurrent, enc_state->gru4_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru4_state, ENC_GRU4_OUT_SIZE);
+ output_index += ENC_GRU4_OUT_SIZE;
+ conv1_cond_init(enc_state->conv4_state, buffer, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv4, &buffer[output_index], enc_state->conv4_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV4_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru5_input, &model->enc_gru5_recurrent, enc_state->gru5_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru5_state, ENC_GRU5_OUT_SIZE);
+ output_index += ENC_GRU5_OUT_SIZE;
+ conv1_cond_init(enc_state->conv5_state, buffer, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV5_OUT_SIZE;
+
+ compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR);
/* next, calculate initial state */
- compute_generic_dense(&model->gdense1, &buffer[output_index], buffer, ACTIVATION_TANH);
- input_index = output_index;
- compute_generic_dense(&model->gdense2, initial_state, &buffer[input_index], ACTIVATION_TANH);
-
+ compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH);
+ compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_TANH);
}
diff --git a/dnn/dred_rdovae_enc.h b/dnn/dred_rdovae_enc.h
index 70ff6adc..832bd737 100644
--- a/dnn/dred_rdovae_enc.h
+++ b/dnn/dred_rdovae_enc.h
@@ -33,10 +33,17 @@
#include "dred_rdovae_enc_data.h"
struct RDOVAEEncStruct {
- float dense2_state[3 * ENC_DENSE2_STATE_SIZE];
- float dense4_state[3 * ENC_DENSE4_STATE_SIZE];
- float dense6_state[3 * ENC_DENSE6_STATE_SIZE];
- float bits_dense_state[BITS_DENSE_STATE_SIZE];
+ int initialized;
+ float gru1_state[ENC_GRU1_STATE_SIZE];
+ float gru2_state[ENC_GRU2_STATE_SIZE];
+ float gru3_state[ENC_GRU3_STATE_SIZE];
+ float gru4_state[ENC_GRU4_STATE_SIZE];
+ float gru5_state[ENC_GRU5_STATE_SIZE];
+ float conv1_state[ENC_CONV1_STATE_SIZE];
+ float conv2_state[2*ENC_CONV2_STATE_SIZE];
+ float conv3_state[2*ENC_CONV3_STATE_SIZE];
+ float conv4_state[2*ENC_CONV4_STATE_SIZE];
+ float conv5_state[2*ENC_CONV5_STATE_SIZE];
};
void dred_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input);
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 73c49fc3..12f3747d 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -366,6 +366,25 @@ void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem,
OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size);
}
+void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation)
+{
+ float tmp[MAX_CONV_INPUTS_ALL];
+ int ksize = layer->nb_inputs/input_size;
+ int i;
+ celt_assert(input != output);
+ celt_assert(layer->nb_inputs <= MAX_CONV_INPUTS_ALL);
+ if (dilation==1) OPUS_COPY(tmp, mem, layer->nb_inputs-input_size);
+ else for (i=0;i<ksize-1;i++) OPUS_COPY(&tmp[i*input_size], &mem[i*input_size*dilation], input_size);
+ OPUS_COPY(&tmp[layer->nb_inputs-input_size], input, input_size);
+ compute_linear(layer, output, tmp);
+ compute_activation(output, output, layer->nb_outputs, activation);
+ if (dilation==1) OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size);
+ else {
+ OPUS_COPY(mem, &mem[input_size], input_size*dilation*(ksize-1)-input_size);
+ OPUS_COPY(&mem[input_size*dilation*(ksize-1)-input_size], input, input_size);
+ }
+}
+
void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
{
LinearLayer matrix;
diff --git a/dnn/nnet.h b/dnn/nnet.h
index 2b43308a..f054a92d 100644
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -135,6 +135,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in);
void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation);
void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in);
void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation);
+void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation);
void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation);
void compute_activation(float *output, const float *input, int N, int activation);
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index f9c1db81..c2cc61bd 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -116,10 +116,7 @@ f"""
# encoder
encoder_dense_layers = [
('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'),
- ('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'),
- ('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'),
- ('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'),
- ('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'),
+ ('core_encoder.module.z_dense' , 'enc_zdense', 'LINEAR'),
('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'),
('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH')
]
@@ -130,9 +127,11 @@ f"""
encoder_gru_layers = [
- ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'),
- ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'),
- ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH')
+ ('core_encoder.module.gru1' , 'enc_gru1', 'TANH'),
+ ('core_encoder.module.gru2' , 'enc_gru2', 'TANH'),
+ ('core_encoder.module.gru3' , 'enc_gru3', 'TANH'),
+ ('core_encoder.module.gru4' , 'enc_gru4', 'TANH'),
+ ('core_encoder.module.gru5' , 'enc_gru5', 'TANH'),
]
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
@@ -140,7 +139,11 @@ f"""
encoder_conv_layers = [
- ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR')
+ ('core_encoder.module.conv1.conv' , 'enc_conv1', 'TANH'),
+ ('core_encoder.module.conv2.conv' , 'enc_conv2', 'TANH'),
+ ('core_encoder.module.conv3.conv' , 'enc_conv3', 'TANH'),
+ ('core_encoder.module.conv4.conv' , 'enc_conv4', 'TANH'),
+ ('core_encoder.module.conv5.conv' , 'enc_conv5', 'TANH'),
]
enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in encoder_conv_layers])
@@ -150,15 +153,10 @@ f"""
# decoder
decoder_dense_layers = [
- ('core_decoder.module.gru_1_init' , 'state1', 'TANH'),
- ('core_decoder.module.gru_2_init' , 'state2', 'TANH'),
- ('core_decoder.module.gru_3_init' , 'state3', 'TANH'),
- ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'),
- ('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'),
- ('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'),
- ('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'),
- ('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'),
- ('core_decoder.module.output' , 'dec_final', 'LINEAR')
+ ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'),
+ ('core_decoder.module.output' , 'dec_output', 'LINEAR'),
+ ('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH'),
+ ('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH'),
]
for name, export_name, _ in decoder_dense_layers:
@@ -167,14 +165,26 @@ f"""
decoder_gru_layers = [
- ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'),
- ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'),
- ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH')
+ ('core_decoder.module.gru1' , 'dec_gru1', 'TANH'),
+ ('core_decoder.module.gru2' , 'dec_gru2', 'TANH'),
+ ('core_decoder.module.gru3' , 'dec_gru3', 'TANH'),
+ ('core_decoder.module.gru4' , 'dec_gru4', 'TANH'),
+ ('core_decoder.module.gru5' , 'dec_gru5', 'TANH'),
]
dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
for name, export_name, _ in decoder_gru_layers])
+ decoder_conv_layers = [
+ ('core_decoder.module.conv1.conv' , 'dec_conv1', 'TANH'),
+ ('core_decoder.module.conv2.conv' , 'dec_conv2', 'TANH'),
+ ('core_decoder.module.conv3.conv' , 'dec_conv3', 'TANH'),
+ ('core_decoder.module.conv4.conv' , 'dec_conv4', 'TANH'),
+ ('core_decoder.module.conv5.conv' , 'dec_conv5', 'TANH'),
+ ]
+
+ dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in decoder_conv_layers])
+
del dec_writer
# statistical model
@@ -196,7 +206,7 @@ f"""
#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}
-#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}
+#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}
#define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs}
@@ -268,4 +278,4 @@ if __name__ == "__main__":
elif args.format == 'numpy':
numpy_export(args, model)
else:
- raise ValueError(f'error: unknown export format {args.format}') \ No newline at end of file
+ raise ValueError(f'error: unknown export format {args.format}')
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 0dc943ec..ad118b99 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -224,6 +224,17 @@ def weight_clip_factory(max_value):
# RDOVAE module and submodules
+class MyConv(nn.Module):
+ def __init__(self, input_dim, output_dim, dilation=1):
+ super(MyConv, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.dilation=dilation
+ self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
+ def forward(self, x, state=None):
+ device = x.device
+ conv_in = torch.cat([x[:,0:1,:].repeat(1,self.dilation,1), x], -2).permute(0, 2, 1)
+ return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
class CoreEncoder(nn.Module):
STATE_HIDDEN = 128
@@ -248,22 +259,28 @@ class CoreEncoder(nn.Module):
# derived parameters
self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
- self.conv_input_channels = 5 * cond_size + 3 * cond_size2
# layers
- self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
- self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
- self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
- self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
- self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
- self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
- self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
- self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
- self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')
-
- self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)
+ self.dense_1 = nn.Linear(self.input_dim, 64)
+ self.gru1 = nn.GRU(64, 64, batch_first=True)
+ self.conv1 = MyConv(128, 32)
+ self.gru2 = nn.GRU(160, 64, batch_first=True)
+ self.conv2 = MyConv(224, 32, dilation=2)
+ self.gru3 = nn.GRU(256, 64, batch_first=True)
+ self.conv3 = MyConv(320, 32, dilation=2)
+ self.gru4 = nn.GRU(352, 64, batch_first=True)
+ self.conv4 = MyConv(416, 32, dilation=2)
+ self.gru5 = nn.GRU(448, 64, batch_first=True)
+ self.conv5 = MyConv(512, 32, dilation=2)
+
+ self.z_dense = nn.Linear(544, self.output_dim)
+
+
+ self.state_dense_1 = nn.Linear(544, self.STATE_HIDDEN)
self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
+ nb_params = sum(p.numel() for p in self.parameters())
+ print(f"encoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
@@ -278,26 +295,23 @@ class CoreEncoder(nn.Module):
device = x.device
# run encoding layer stack
- x1 = torch.tanh(self.dense_1(x))
- x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
- x3 = torch.tanh(self.dense_2(x2))
- x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
- x5 = torch.tanh(self.dense_3(x4))
- x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
- x7 = torch.tanh(self.dense_4(x6))
- x8 = torch.tanh(self.dense_5(x7))
-
- # concatenation of all hidden layer outputs
- x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+ x = torch.tanh(self.dense_1(x))
+ x = torch.cat([x, self.gru1(x)[0]], -1)
+ x = torch.cat([x, self.conv1(x)], -1)
+ x = torch.cat([x, self.gru2(x)[0]], -1)
+ x = torch.cat([x, self.conv2(x)], -1)
+ x = torch.cat([x, self.gru3(x)[0]], -1)
+ x = torch.cat([x, self.conv3(x)], -1)
+ x = torch.cat([x, self.gru4(x)[0]], -1)
+ x = torch.cat([x, self.conv4(x)], -1)
+ x = torch.cat([x, self.gru5(x)[0]], -1)
+ x = torch.cat([x, self.conv5(x)], -1)
+ z = self.z_dense(x)
# init state for decoder
- states = torch.tanh(self.state_dense_1(x9))
+ states = torch.tanh(self.state_dense_1(x))
states = torch.tanh(self.state_dense_2(states))
- # latent representation via convolution
- x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
- z = self.conv1(x9).permute(0, 2, 1)
-
return z, states
@@ -325,47 +339,54 @@ class CoreDecoder(nn.Module):
self.input_size = self.input_dim
- self.concat_size = 4 * self.cond_size + 4 * self.cond_size2
-
# layers
- self.dense_1 = nn.Linear(self.input_size, cond_size2)
- self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True)
- self.dense_2 = nn.Linear(cond_size, cond_size2)
- self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True)
- self.dense_3 = nn.Linear(cond_size, cond_size2)
- self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True)
- self.dense_4 = nn.Linear(cond_size, cond_size2)
- self.dense_5 = nn.Linear(cond_size2, cond_size2)
-
- self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)
-
-
- self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
- self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
- self.gru_3_init = nn.Linear(self.state_size, self.cond_size)
-
+ self.dense_1 = nn.Linear(self.input_size, 96)
+ self.gru1 = nn.GRU(96, 64, batch_first=True)
+ self.conv1 = MyConv(160, 32)
+ self.gru2 = nn.GRU(192, 64, batch_first=True)
+ self.conv2 = MyConv(256, 32)
+ self.gru3 = nn.GRU(288, 64, batch_first=True)
+ self.conv3 = MyConv(352, 32)
+ self.gru4 = nn.GRU(384, 64, batch_first=True)
+ self.conv4 = MyConv(448, 32)
+ self.gru5 = nn.GRU(480, 64, batch_first=True)
+ self.conv5 = MyConv(544, 32)
+ self.output = nn.Linear(576, self.FRAMES_PER_STEP * self.output_dim)
+
+ self.hidden_init = nn.Linear(self.state_size, 128)
+ self.gru_init = nn.Linear(128, 320)
+
+ nb_params = sum(p.numel() for p in self.parameters())
+ print(f"decoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
def forward(self, z, initial_state):
- gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
- gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
- gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))
+ hidden = torch.tanh(self.hidden_init(initial_state))
+ gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
+ h1_state = gru_state[:,:,:64].contiguous()
+ h2_state = gru_state[:,:,64:128].contiguous()
+ h3_state = gru_state[:,:,128:192].contiguous()
+ h4_state = gru_state[:,:,192:256].contiguous()
+ h5_state = gru_state[:,:,256:].contiguous()
# run decoding layer stack
- x1 = torch.tanh(self.dense_1(z))
- x2, _ = self.gru_1(x1, gru_1_state)
- x3 = torch.tanh(self.dense_2(x2))
- x4, _ = self.gru_2(x3, gru_2_state)
- x5 = torch.tanh(self.dense_3(x4))
- x6, _ = self.gru_3(x5, gru_3_state)
- x7 = torch.tanh(self.dense_4(x6))
- x8 = torch.tanh(self.dense_5(x7))
- x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+ x = torch.tanh(self.dense_1(z))
+
+ x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
+ x = torch.cat([x, self.conv1(x)], -1)
+ x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
+ x = torch.cat([x, self.conv2(x)], -1)
+ x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
+ x = torch.cat([x, self.conv3(x)], -1)
+ x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
+ x = torch.cat([x, self.conv4(x)], -1)
+ x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
+ x = torch.cat([x, self.conv5(x)], -1)
# output layer and reshaping
- x10 = self.output(x9)
+ x10 = self.output(x)
features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
return features