
gitlab.xiph.org/xiph/opus.git
author    Jean-Marc Valin <jmvalin@amazon.com>    2023-09-21 19:20:11 +0300
committer Jean-Marc Valin <jmvalin@amazon.com>    2023-10-02 08:43:44 +0300
commit    27663d364188711e5f662304357cb532e689bfe2 (patch)
tree      48462bcf3f2c85aeeedcfd534cc3b3752203f68d
parent    8e8edf71bde743c21960a97815f14dd48c86d6ad (diff)
Using a DenseNet for DRED
-rw-r--r--  dnn/dred_rdovae_dec.c                      | 100
-rw-r--r--  dnn/dred_rdovae_dec.h                      |  14
-rw-r--r--  dnn/dred_rdovae_enc.c                      |  90
-rw-r--r--  dnn/dred_rdovae_enc.h                      |  15
-rw-r--r--  dnn/nnet.c                                 |  19
-rw-r--r--  dnn/nnet.h                                 |   1
-rw-r--r--  dnn/torch/rdovae/export_rdovae_weights.py  |  54
-rw-r--r--  dnn/torch/rdovae/rdovae/rdovae.py          | 143

8 files changed, 268 insertions(+), 168 deletions(-)
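
This commit replaces the previous encoder/decoder stacks, in which alternating dense and GRU layers each fed only the next layer and selected outputs were concatenated at the end, with a DenseNet-style topology: every GRU and causal-conv layer consumes the concatenation of all preceding layer outputs, accumulated in one flat buffer. A minimal PyTorch sketch of that connectivity pattern, using the first encoder block's widths from the diff (dense to 64, GRU 64, kernel-2 conv 128 to 96); the module and its names are illustrative, not the committed code:

    import torch
    import torch.nn as nn

    class DenseBlock(nn.Module):
        """Illustrative DenseNet-style block: output = input ++ GRU ++ conv."""
        def __init__(self, in_dim):
            super().__init__()
            self.dense = nn.Linear(in_dim, 64)
            self.gru = nn.GRU(64, 64, batch_first=True)
            # kernel-2 'valid' conv, made causal by left-padding one frame
            self.conv = nn.Conv1d(128, 96, kernel_size=2, padding='valid')

        def forward(self, x):
            x = torch.tanh(self.dense(x))             # (B, T, 64)
            x = torch.cat([x, self.gru(x)[0]], -1)    # (B, T, 128), input kept
            pad = torch.zeros_like(x[:, :1, :])       # one frame of zeros
            c = self.conv(torch.cat([pad, x], 1).permute(0, 2, 1))
            x = torch.cat([x, torch.tanh(c).permute(0, 2, 1)], -1)  # (B, T, 224)
            return x

Each later GRU/conv pair in the diff follows the same rule, which is why the layer input widths keep growing (224, 288, 384, ... in the encoder) and why the C port appends each layer's output at output_index and hands every layer the whole buffer prefix.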
diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c
index c79723b3..1fef375a 100644
--- a/dnn/dred_rdovae_dec.c
+++ b/dnn/dred_rdovae_dec.c
@@ -33,16 +33,36 @@
#include "dred_rdovae_constants.h"
#include "os_support.h"
+static void conv1_cond_init(float *mem, int len, int dilation, int *init)
+{
+ if (!*init) {
+ int i;
+ for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len);
+ }
+ *init = 1;
+}
+
void dred_rdovae_dec_init_states(
RDOVAEDecState *h, /* io: state buffer handle */
const RDOVAEDec *model,
const float *initial_state /* i: initial state */
)
{
- /* initialize GRU states from initial state */
- compute_generic_dense(&model->state1, h->dense2_state, initial_state, ACTIVATION_TANH);
- compute_generic_dense(&model->state2, h->dense4_state, initial_state, ACTIVATION_TANH);
- compute_generic_dense(&model->state3, h->dense6_state, initial_state, ACTIVATION_TANH);
+ float hidden[DEC_HIDDEN_INIT_OUT_SIZE];
+ float state_init[DEC_GRU1_STATE_SIZE+DEC_GRU2_STATE_SIZE+DEC_GRU3_STATE_SIZE+DEC_GRU4_STATE_SIZE+DEC_GRU5_STATE_SIZE];
+ int counter=0;
+ compute_generic_dense(&model->dec_hidden_init, hidden, initial_state, ACTIVATION_TANH);
+ compute_generic_dense(&model->dec_gru_init, state_init, hidden, ACTIVATION_TANH);
+ OPUS_COPY(h->gru1_state, state_init, DEC_GRU1_STATE_SIZE);
+ counter += DEC_GRU1_STATE_SIZE;
+ OPUS_COPY(h->gru2_state, &state_init[counter], DEC_GRU2_STATE_SIZE);
+ counter += DEC_GRU2_STATE_SIZE;
+ OPUS_COPY(h->gru3_state, &state_init[counter], DEC_GRU3_STATE_SIZE);
+ counter += DEC_GRU3_STATE_SIZE;
+ OPUS_COPY(h->gru4_state, &state_init[counter], DEC_GRU4_STATE_SIZE);
+ counter += DEC_GRU4_STATE_SIZE;
+ OPUS_COPY(h->gru5_state, &state_init[counter], DEC_GRU5_STATE_SIZE);
+ h->initialized = 0;
}
@@ -53,44 +73,48 @@ void dred_rdovae_decode_qframe(
const float *input /* i: latent vector */
)
{
- float buffer[DEC_DENSE1_OUT_SIZE + DEC_DENSE2_OUT_SIZE + DEC_DENSE3_OUT_SIZE + DEC_DENSE4_OUT_SIZE + DEC_DENSE5_OUT_SIZE + DEC_DENSE6_OUT_SIZE + DEC_DENSE7_OUT_SIZE + DEC_DENSE8_OUT_SIZE];
+ float buffer[DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
+ + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE];
int output_index = 0;
- int input_index = 0;
/* run encoder stack and concatenate output in buffer*/
compute_generic_dense(&model->dec_dense1, &buffer[output_index], input, ACTIVATION_TANH);
- input_index = output_index;
output_index += DEC_DENSE1_OUT_SIZE;
- compute_generic_gru(&model->dec_dense2_input, &model->dec_dense2_recurrent, dec_state->dense2_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], dec_state->dense2_state, DEC_DENSE2_OUT_SIZE);
- input_index = output_index;
- output_index += DEC_DENSE2_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += DEC_DENSE3_OUT_SIZE;
-
- compute_generic_gru(&model->dec_dense4_input, &model->dec_dense4_recurrent, dec_state->dense4_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], dec_state->dense4_state, DEC_DENSE4_OUT_SIZE);
- input_index = output_index;
- output_index += DEC_DENSE4_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += DEC_DENSE5_OUT_SIZE;
-
- compute_generic_gru(&model->dec_dense6_input, &model->dec_dense6_recurrent, dec_state->dense6_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], dec_state->dense6_state, DEC_DENSE6_OUT_SIZE);
- input_index = output_index;
- output_index += DEC_DENSE6_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += DEC_DENSE7_OUT_SIZE;
-
- compute_generic_dense(&model->dec_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- output_index += DEC_DENSE8_OUT_SIZE;
-
- compute_generic_dense(&model->dec_final, qframe, buffer, ACTIVATION_LINEAR);
+ compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
+ output_index += DEC_GRU1_OUT_SIZE;
+ conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV1_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
+ output_index += DEC_GRU2_OUT_SIZE;
+ conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV2_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
+ output_index += DEC_GRU3_OUT_SIZE;
+ conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV3_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
+ output_index += DEC_GRU4_OUT_SIZE;
+ conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV4_OUT_SIZE;
+
+ compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
+ OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
+ output_index += DEC_GRU5_OUT_SIZE;
+ conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized);
+ compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += DEC_CONV5_OUT_SIZE;
+
+ compute_generic_dense(&model->dec_output, qframe, buffer, ACTIVATION_LINEAR);
}
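
On the C side the convolutions run one frame at a time: compute_generic_conv1d keeps the previous input frame in a per-layer mem buffer, and the new conv1_cond_init helper zeroes that buffer lazily on the first frame after a state reset, tracked by the new initialized flag. A NumPy sketch of the same mechanics, assuming kernel size 2 as in the model; the function names are mine, not the library's:

    import numpy as np

    def streaming_conv1d(weight, bias, mem, frame):
        """One step of a kernel-2 causal conv.
        weight: (out_dim, 2*in_dim); mem and frame: (in_dim,).
        mem holds the previous frame and is updated in place."""
        window = np.concatenate([mem, frame])  # [previous frame, current frame]
        out = np.tanh(weight @ window + bias)
        mem[:] = frame                         # slide history forward one frame
        return out

    def cond_init(mem, init):
        """Zero the conv history only on the first frame after a reset."""
        if not init[0]:
            mem[:] = 0.0
        init[0] = True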
diff --git a/dnn/dred_rdovae_dec.h b/dnn/dred_rdovae_dec.h
index 008551b5..4e039cf2 100644
--- a/dnn/dred_rdovae_dec.h
+++ b/dnn/dred_rdovae_dec.h
@@ -33,9 +33,17 @@
#include "dred_rdovae_stats_data.h"
struct RDOVAEDecStruct {
- float dense2_state[DEC_DENSE2_STATE_SIZE];
- float dense4_state[DEC_DENSE2_STATE_SIZE];
- float dense6_state[DEC_DENSE2_STATE_SIZE];
+ int initialized;
+ float gru1_state[DEC_GRU1_STATE_SIZE];
+ float gru2_state[DEC_GRU2_STATE_SIZE];
+ float gru3_state[DEC_GRU3_STATE_SIZE];
+ float gru4_state[DEC_GRU4_STATE_SIZE];
+ float gru5_state[DEC_GRU5_STATE_SIZE];
+ float conv1_state[DEC_CONV1_STATE_SIZE];
+ float conv2_state[DEC_CONV2_STATE_SIZE];
+ float conv3_state[DEC_CONV3_STATE_SIZE];
+ float conv4_state[DEC_CONV4_STATE_SIZE];
+ float conv5_state[DEC_CONV5_STATE_SIZE];
};
void dred_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state);
diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c
index 9361af17..98ffba8c 100644
--- a/dnn/dred_rdovae_enc.c
+++ b/dnn/dred_rdovae_enc.c
@@ -35,6 +35,15 @@
#include "dred_rdovae_enc.h"
#include "os_support.h"
+static void conv1_cond_init(float *mem, int len, int dilation, int *init)
+{
+ if (!*init) {
+ int i;
+ for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len);
+ }
+ *init = 1;
+}
+
void dred_rdovae_encode_dframe(
RDOVAEEncState *enc_state, /* io: encoder state */
const RDOVAEEnc *model,
@@ -43,52 +52,53 @@ void dred_rdovae_encode_dframe(
const float *input /* i: double feature frame (concatenated) */
)
{
- float buffer[ENC_DENSE1_OUT_SIZE + ENC_DENSE2_OUT_SIZE + ENC_DENSE3_OUT_SIZE + ENC_DENSE4_OUT_SIZE + ENC_DENSE5_OUT_SIZE + ENC_DENSE6_OUT_SIZE + ENC_DENSE7_OUT_SIZE + ENC_DENSE8_OUT_SIZE + GDENSE1_OUT_SIZE];
+ float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE
+ + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE];
+ float state_hidden[GDENSE1_OUT_SIZE];
int output_index = 0;
- int input_index = 0;
/* run encoder stack and concatenate output in buffer*/
compute_generic_dense(&model->enc_dense1, &buffer[output_index], input, ACTIVATION_TANH);
- input_index = output_index;
output_index += ENC_DENSE1_OUT_SIZE;
- compute_generic_gru(&model->enc_dense2_input, &model->enc_dense2_recurrent, enc_state->dense2_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], enc_state->dense2_state, ENC_DENSE2_OUT_SIZE);
- input_index = output_index;
- output_index += ENC_DENSE2_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += ENC_DENSE3_OUT_SIZE;
-
- compute_generic_gru(&model->enc_dense4_input, &model->enc_dense4_recurrent, enc_state->dense4_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], enc_state->dense4_state, ENC_DENSE4_OUT_SIZE);
- input_index = output_index;
- output_index += ENC_DENSE4_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += ENC_DENSE5_OUT_SIZE;
-
- compute_generic_gru(&model->enc_dense6_input, &model->enc_dense6_recurrent, enc_state->dense6_state, &buffer[input_index]);
- OPUS_COPY(&buffer[output_index], enc_state->dense6_state, ENC_DENSE6_OUT_SIZE);
- input_index = output_index;
- output_index += ENC_DENSE6_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- input_index = output_index;
- output_index += ENC_DENSE7_OUT_SIZE;
-
- compute_generic_dense(&model->enc_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
- output_index += ENC_DENSE8_OUT_SIZE;
-
- /* compute latents from concatenated input buffer */
- compute_generic_conv1d(&model->bits_dense, latents, enc_state->bits_dense_state, buffer, BITS_DENSE_IN_SIZE, ACTIVATION_LINEAR);
-
+ compute_generic_gru(&model->enc_gru1_input, &model->enc_gru1_recurrent, enc_state->gru1_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru1_state, ENC_GRU1_OUT_SIZE);
+ output_index += ENC_GRU1_OUT_SIZE;
+ conv1_cond_init(enc_state->conv1_state, output_index, 1, &enc_state->initialized);
+ compute_generic_conv1d(&model->enc_conv1, &buffer[output_index], enc_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
+ output_index += ENC_CONV1_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru2_input, &model->enc_gru2_recurrent, enc_state->gru2_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru2_state, ENC_GRU2_OUT_SIZE);
+ output_index += ENC_GRU2_OUT_SIZE;
+ conv1_cond_init(enc_state->conv2_state, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv2, &buffer[output_index], enc_state->conv2_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV2_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru3_input, &model->enc_gru3_recurrent, enc_state->gru3_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru3_state, ENC_GRU3_OUT_SIZE);
+ output_index += ENC_GRU3_OUT_SIZE;
+ conv1_cond_init(enc_state->conv3_state, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv3, &buffer[output_index], enc_state->conv3_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV3_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru4_input, &model->enc_gru4_recurrent, enc_state->gru4_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru4_state, ENC_GRU4_OUT_SIZE);
+ output_index += ENC_GRU4_OUT_SIZE;
+ conv1_cond_init(enc_state->conv4_state, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv4, &buffer[output_index], enc_state->conv4_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV4_OUT_SIZE;
+
+ compute_generic_gru(&model->enc_gru5_input, &model->enc_gru5_recurrent, enc_state->gru5_state, buffer);
+ OPUS_COPY(&buffer[output_index], enc_state->gru5_state, ENC_GRU5_OUT_SIZE);
+ output_index += ENC_GRU5_OUT_SIZE;
+ conv1_cond_init(enc_state->conv5_state, output_index, 2, &enc_state->initialized);
+ compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH);
+ output_index += ENC_CONV5_OUT_SIZE;
+
+ compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR);
/* next, calculate initial state */
- compute_generic_dense(&model->gdense1, &buffer[output_index], buffer, ACTIVATION_TANH);
- input_index = output_index;
- compute_generic_dense(&model->gdense2, initial_state, &buffer[input_index], ACTIVATION_TANH);
-
+ compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH);
+ compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_LINEAR);
}
diff --git a/dnn/dred_rdovae_enc.h b/dnn/dred_rdovae_enc.h
index 70ff6adc..832bd737 100644
--- a/dnn/dred_rdovae_enc.h
+++ b/dnn/dred_rdovae_enc.h
@@ -33,10 +33,17 @@
#include "dred_rdovae_enc_data.h"
struct RDOVAEEncStruct {
- float dense2_state[3 * ENC_DENSE2_STATE_SIZE];
- float dense4_state[3 * ENC_DENSE4_STATE_SIZE];
- float dense6_state[3 * ENC_DENSE6_STATE_SIZE];
- float bits_dense_state[BITS_DENSE_STATE_SIZE];
+ int initialized;
+ float gru1_state[ENC_GRU1_STATE_SIZE];
+ float gru2_state[ENC_GRU2_STATE_SIZE];
+ float gru3_state[ENC_GRU3_STATE_SIZE];
+ float gru4_state[ENC_GRU4_STATE_SIZE];
+ float gru5_state[ENC_GRU5_STATE_SIZE];
+ float conv1_state[ENC_CONV1_STATE_SIZE];
+ float conv2_state[2*ENC_CONV2_STATE_SIZE];
+ float conv3_state[2*ENC_CONV3_STATE_SIZE];
+ float conv4_state[2*ENC_CONV4_STATE_SIZE];
+ float conv5_state[2*ENC_CONV5_STATE_SIZE];
};
void dred_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input);
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 3661ba77..d5ef904e 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -366,6 +366,25 @@ void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem,
OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size);
}
+void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation)
+{
+ float tmp[MAX_CONV_INPUTS_ALL];
+ int ksize = layer->nb_inputs/input_size;
+ int i;
+ celt_assert(input != output);
+ celt_assert(layer->nb_inputs <= MAX_CONV_INPUTS_ALL);
+ if (dilation==1) OPUS_COPY(tmp, mem, layer->nb_inputs-input_size);
+ else for (i=0;i<ksize-1;i++) OPUS_COPY(&tmp[i*input_size], &mem[i*input_size*dilation], input_size);
+ OPUS_COPY(&tmp[layer->nb_inputs-input_size], input, input_size);
+ compute_linear(layer, output, tmp);
+ compute_activation(output, output, layer->nb_outputs, activation);
+ if (dilation==1) OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size);
+ else {
+ OPUS_COPY(mem, &mem[input_size], input_size*dilation*(ksize-1)-input_size);
+ OPUS_COPY(&mem[input_size*dilation*(ksize-1)-input_size], input, input_size);
+ }
+}
+
void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
{
LinearLayer matrix;
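
compute_generic_conv1d_dilation generalizes the streaming conv above: with dilation d and kernel size 2, the layer pairs the current frame with the frame from d steps back, so mem behaves as a shift register holding the last d frames, oldest first. That is also why the encoder's dilation-2 layers declare 2*ENC_CONVx_STATE_SIZE state arrays and pass dilation 2 to conv1_cond_init. A sketch of that memory discipline, under the same kernel-size-2 assumption:

    import numpy as np

    def streaming_conv1d_dilated(weight, bias, mem, frame, dilation):
        """One step of a kernel-2 conv with dilation.
        mem: (dilation, in_dim), oldest frame first; the conv pairs the
        current frame with the frame from `dilation` steps back."""
        window = np.concatenate([mem[0], frame])
        out = np.tanh(weight @ window + bias)
        mem[:-1] = mem[1:]   # drop the oldest frame
        mem[-1] = frame      # append the newest
        return out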
diff --git a/dnn/nnet.h b/dnn/nnet.h
index 16ce82ba..9ed20b02 100644
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -145,6 +145,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in);
void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation);
void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in);
void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation);
+void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation);
void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation);
void compute_activation(float *output, const float *input, int N, int activation);
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index f9c1db81..c2cc61bd 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -116,10 +116,7 @@ f"""
# encoder
encoder_dense_layers = [
('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'),
- ('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'),
- ('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'),
- ('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'),
- ('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'),
+ ('core_encoder.module.z_dense' , 'enc_zdense', 'LINEAR'),
('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'),
('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH')
]
@@ -130,9 +127,11 @@ f"""
encoder_gru_layers = [
- ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'),
- ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'),
- ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH')
+ ('core_encoder.module.gru1' , 'enc_gru1', 'TANH'),
+ ('core_encoder.module.gru2' , 'enc_gru2', 'TANH'),
+ ('core_encoder.module.gru3' , 'enc_gru3', 'TANH'),
+ ('core_encoder.module.gru4' , 'enc_gru4', 'TANH'),
+ ('core_encoder.module.gru5' , 'enc_gru5', 'TANH'),
]
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
@@ -140,7 +139,11 @@ f"""
encoder_conv_layers = [
- ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR')
+ ('core_encoder.module.conv1.conv' , 'enc_conv1', 'TANH'),
+ ('core_encoder.module.conv2.conv' , 'enc_conv2', 'TANH'),
+ ('core_encoder.module.conv3.conv' , 'enc_conv3', 'TANH'),
+ ('core_encoder.module.conv4.conv' , 'enc_conv4', 'TANH'),
+ ('core_encoder.module.conv5.conv' , 'enc_conv5', 'TANH'),
]
enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in encoder_conv_layers])
@@ -150,15 +153,10 @@ f"""
# decoder
decoder_dense_layers = [
- ('core_decoder.module.gru_1_init' , 'state1', 'TANH'),
- ('core_decoder.module.gru_2_init' , 'state2', 'TANH'),
- ('core_decoder.module.gru_3_init' , 'state3', 'TANH'),
- ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'),
- ('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'),
- ('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'),
- ('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'),
- ('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'),
- ('core_decoder.module.output' , 'dec_final', 'LINEAR')
+ ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'),
+ ('core_decoder.module.output' , 'dec_output', 'LINEAR'),
+ ('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH'),
+ ('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH'),
]
for name, export_name, _ in decoder_dense_layers:
@@ -167,14 +165,26 @@ f"""
decoder_gru_layers = [
- ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'),
- ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'),
- ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH')
+ ('core_decoder.module.gru1' , 'dec_gru1', 'TANH'),
+ ('core_decoder.module.gru2' , 'dec_gru2', 'TANH'),
+ ('core_decoder.module.gru3' , 'dec_gru3', 'TANH'),
+ ('core_decoder.module.gru4' , 'dec_gru4', 'TANH'),
+ ('core_decoder.module.gru5' , 'dec_gru5', 'TANH'),
]
dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
for name, export_name, _ in decoder_gru_layers])
+ decoder_conv_layers = [
+ ('core_decoder.module.conv1.conv' , 'dec_conv1', 'TANH'),
+ ('core_decoder.module.conv2.conv' , 'dec_conv2', 'TANH'),
+ ('core_decoder.module.conv3.conv' , 'dec_conv3', 'TANH'),
+ ('core_decoder.module.conv4.conv' , 'dec_conv4', 'TANH'),
+ ('core_decoder.module.conv5.conv' , 'dec_conv5', 'TANH'),
+ ]
+
+ dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in decoder_conv_layers])
+
del dec_writer
# statistical model
@@ -196,7 +206,7 @@ f"""
#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}
-#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}
+#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}
#define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs}
@@ -268,4 +278,4 @@ if __name__ == "__main__":
elif args.format == 'numpy':
numpy_export(args, model)
else:
- raise ValueError(f'error: unknown export format {args.format}')
\ No newline at end of file
+ raise ValueError(f'error: unknown export format {args.format}')
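
Since the decoder now has conv layers of its own, the exported DRED_MAX_CONV_INPUTS macro must cover both directions; a toy sketch of the pattern (the input sizes below are invented for illustration, the script takes the real ones from what dump_torch_weights reports per layer):

    enc_max_conv_inputs = max([256, 576, 768])   # invented example values
    dec_max_conv_inputs = max([384, 640, 896])   # invented example values
    print(f"#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}")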
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 0dc943ec..b126d4c4 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -224,6 +224,17 @@ def weight_clip_factory(max_value):
# RDOVAE module and submodules
+class MyConv(nn.Module):
+ def __init__(self, input_dim, output_dim, dilation=1):
+ super(MyConv, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.dilation=dilation
+ self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
+ def forward(self, x, state=None):
+ device = x.device
+ conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
+ return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
class CoreEncoder(nn.Module):
STATE_HIDDEN = 128
@@ -248,22 +259,28 @@ class CoreEncoder(nn.Module):
# derived parameters
self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
- self.conv_input_channels = 5 * cond_size + 3 * cond_size2
# layers
- self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
- self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
- self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
- self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
- self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
- self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
- self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
- self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
- self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')
-
- self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)
+ self.dense_1 = nn.Linear(self.input_dim, 64)
+ self.gru1 = nn.GRU(64, 64, batch_first=True)
+ self.conv1 = MyConv(128, 96)
+ self.gru2 = nn.GRU(224, 64, batch_first=True)
+ self.conv2 = MyConv(288, 96, dilation=2)
+ self.gru3 = nn.GRU(384, 64, batch_first=True)
+ self.conv3 = MyConv(448, 96, dilation=2)
+ self.gru4 = nn.GRU(544, 64, batch_first=True)
+ self.conv4 = MyConv(608, 96, dilation=2)
+ self.gru5 = nn.GRU(704, 64, batch_first=True)
+ self.conv5 = MyConv(768, 96, dilation=2)
+
+ self.z_dense = nn.Linear(864, self.output_dim)
+
+
+ self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)
self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
+ nb_params = sum(p.numel() for p in self.parameters())
+ print(f"encoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
@@ -278,25 +295,22 @@ class CoreEncoder(nn.Module):
device = x.device
# run encoding layer stack
- x1 = torch.tanh(self.dense_1(x))
- x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
- x3 = torch.tanh(self.dense_2(x2))
- x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
- x5 = torch.tanh(self.dense_3(x4))
- x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
- x7 = torch.tanh(self.dense_4(x6))
- x8 = torch.tanh(self.dense_5(x7))
-
- # concatenation of all hidden layer outputs
- x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+ x = torch.tanh(self.dense_1(x))
+ x = torch.cat([x, self.gru1(x)[0]], -1)
+ x = torch.cat([x, self.conv1(x)], -1)
+ x = torch.cat([x, self.gru2(x)[0]], -1)
+ x = torch.cat([x, self.conv2(x)], -1)
+ x = torch.cat([x, self.gru3(x)[0]], -1)
+ x = torch.cat([x, self.conv3(x)], -1)
+ x = torch.cat([x, self.gru4(x)[0]], -1)
+ x = torch.cat([x, self.conv4(x)], -1)
+ x = torch.cat([x, self.gru5(x)[0]], -1)
+ x = torch.cat([x, self.conv5(x)], -1)
+ z = self.z_dense(x)
# init state for decoder
- states = torch.tanh(self.state_dense_1(x9))
- states = torch.tanh(self.state_dense_2(states))
-
- # latent representation via convolution
- x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
- z = self.conv1(x9).permute(0, 2, 1)
+ states = torch.tanh(self.state_dense_1(x))
+ states = self.state_dense_2(states)
return z, states
@@ -325,47 +339,54 @@ class CoreDecoder(nn.Module):
self.input_size = self.input_dim
- self.concat_size = 4 * self.cond_size + 4 * self.cond_size2
-
# layers
- self.dense_1 = nn.Linear(self.input_size, cond_size2)
- self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True)
- self.dense_2 = nn.Linear(cond_size, cond_size2)
- self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True)
- self.dense_3 = nn.Linear(cond_size, cond_size2)
- self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True)
- self.dense_4 = nn.Linear(cond_size, cond_size2)
- self.dense_5 = nn.Linear(cond_size2, cond_size2)
-
- self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)
-
-
- self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
- self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
- self.gru_3_init = nn.Linear(self.state_size, self.cond_size)
-
+ self.dense_1 = nn.Linear(self.input_size, 96)
+ self.gru1 = nn.GRU(96, 96, batch_first=True)
+ self.conv1 = MyConv(192, 32)
+ self.gru2 = nn.GRU(224, 96, batch_first=True)
+ self.conv2 = MyConv(320, 32)
+ self.gru3 = nn.GRU(352, 96, batch_first=True)
+ self.conv3 = MyConv(448, 32)
+ self.gru4 = nn.GRU(480, 96, batch_first=True)
+ self.conv4 = MyConv(576, 32)
+ self.gru5 = nn.GRU(608, 96, batch_first=True)
+ self.conv5 = MyConv(704, 32)
+ self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
+
+ self.hidden_init = nn.Linear(self.state_size, 128)
+ self.gru_init = nn.Linear(128, 480)
+
+ nb_params = sum(p.numel() for p in self.parameters())
+ print(f"decoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
def forward(self, z, initial_state):
- gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
- gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
- gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))
+ hidden = torch.tanh(self.hidden_init(initial_state))
+ gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
+ h1_state = gru_state[:,:,:96].contiguous()
+ h2_state = gru_state[:,:,96:192].contiguous()
+ h3_state = gru_state[:,:,192:288].contiguous()
+ h4_state = gru_state[:,:,288:384].contiguous()
+ h5_state = gru_state[:,:,384:].contiguous()
# run decoding layer stack
- x1 = torch.tanh(self.dense_1(z))
- x2, _ = self.gru_1(x1, gru_1_state)
- x3 = torch.tanh(self.dense_2(x2))
- x4, _ = self.gru_2(x3, gru_2_state)
- x5 = torch.tanh(self.dense_3(x4))
- x6, _ = self.gru_3(x5, gru_3_state)
- x7 = torch.tanh(self.dense_4(x6))
- x8 = torch.tanh(self.dense_5(x7))
- x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+ x = torch.tanh(self.dense_1(z))
+
+ x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
+ x = torch.cat([x, self.conv1(x)], -1)
+ x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
+ x = torch.cat([x, self.conv2(x)], -1)
+ x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
+ x = torch.cat([x, self.conv3(x)], -1)
+ x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
+ x = torch.cat([x, self.conv4(x)], -1)
+ x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
+ x = torch.cat([x, self.conv5(x)], -1)
# output layer and reshaping
- x10 = self.output(x9)
+ x10 = self.output(x)
features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
return features
@@ -466,7 +487,7 @@ class RDOVAE(nn.Module):
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)
- def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
+ def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24):
enc_stride = self.enc_stride
dec_stride = self.dec_stride
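
Finally, the decoder's three per-GRU init layers collapse into a single hidden_init -> gru_init pair whose 480-wide output (five GRUs of width 96) is sliced into the individual GRU states, mirroring the dec_hidden_init/dec_gru_init path and the OPUS_COPY unpacking in dred_rdovae_dec_init_states. A sketch with the dimensions from the diff; the state size of 24 is assumed here for illustration:

    import torch
    import torch.nn as nn

    STATE_SIZE = 24                     # assumed; set by the model config
    hidden_init = nn.Linear(STATE_SIZE, 128)
    gru_init = nn.Linear(128, 5 * 96)   # one 96-wide state per GRU

    initial_state = torch.zeros(2, 1, STATE_SIZE)    # (batch, 1, state)
    h = torch.tanh(hidden_init(initial_state))
    g = torch.tanh(gru_init(h)).permute(1, 0, 2)     # (1, batch, 480) for nn.GRU
    gru_states = [g[:, :, 96*i:96*(i+1)].contiguous() for i in range(5)]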