gitlab.xiph.org/xiph/opus.git
author    Jean-Marc Valin <jmvalin@amazon.com>    2023-11-15 12:08:50 +0300
committer Jean-Marc Valin <jmvalin@amazon.com>    2023-11-15 12:08:50 +0300
commit    b0620c0bf9864d9b18ead6b4bb6e0800542a931d (patch)
tree      ea669c2226514d7f91989080fe2954623cbe83a8
parent    58923f61c26ac0f5d8284d427344466e3bc2c674 (diff)
Using sparse GRUs in DRED decoder
Saves ~270 kB of weights in the decoder
-rwxr-xr-x  autogen.sh                                        2
-rw-r--r--  dnn/dred_rdovae_dec.c                            10
-rw-r--r--  dnn/torch/lpcnet/utils/sparsification/common.py   4
-rw-r--r--  dnn/torch/rdovae/export_rdovae_weights.py        18
-rw-r--r--  dnn/torch/rdovae/rdovae/rdovae.py                85
-rw-r--r--  dnn/torch/rdovae/train_rdovae.py                  3
-rw-r--r--  silk/dred_config.h                                2
7 files changed, 105 insertions, 19 deletions
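
Before the per-file diffs, a rough back-of-envelope on the commit message's ~270 kB figure (an editorial sketch, not part of the commit): the input weight matrices of the decoder's five GRUs are pruned to the densities listed in rdovae.py below, while five small GLU gate matrices are added. Assuming the GRU input sizes implied by the layer definitions and 8-bit quantized weights:

```python
# Rough check of the "~270 kB" claim (assumptions: GRU input sizes inferred
# from the layer definitions in rdovae.py, 1 byte per quantized weight,
# sparse-index overhead ignored).
gru_inputs = [96, 224, 352, 480, 608]      # gru1..gru5 input sizes
hidden = 96
dens_1_3 = (0.6, 0.4, 0.8)                 # W_ir, W_iz, W_in for gru1-gru3
dens_4_5 = (0.3, 0.2, 0.4)                 # same gates for gru4-gru5
removed = sum(n_in * hidden * sum(1 - d for d in (dens_1_3 if i < 3 else dens_4_5))
              for i, n_in in enumerate(gru_inputs))
added = 5 * hidden * hidden                # five new 96x96 GLU gate matrices
print(f"net ~{(removed - added) / 1024:.0f} kB")   # prints: net ~245 kB
```

The exact number depends on quantization and block-index overhead, but the order of magnitude matches the commit message.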
diff --git a/autogen.sh b/autogen.sh
index 47b6fe5e..03fd9495 100755
--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@ set -e
srcdir=`dirname $0`
test -n "$srcdir" && cd "$srcdir"
-dnn/download_model.sh b6095cf
+dnn/download_model.sh 58923f6
echo "Updating build configuration files, please wait...."
diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c
index 59cc8943..e2b19b14 100644
--- a/dnn/dred_rdovae_dec.c
+++ b/dnn/dred_rdovae_dec.c
@@ -98,35 +98,35 @@ void dred_rdovae_decode_qframe(
output_index += DEC_DENSE1_OUT_SIZE;
compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
- OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
+ compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state);
output_index += DEC_GRU1_OUT_SIZE;
conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV1_OUT_SIZE;
compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
- OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
+ compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state);
output_index += DEC_GRU2_OUT_SIZE;
conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV2_OUT_SIZE;
compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
- OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
+ compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state);
output_index += DEC_GRU3_OUT_SIZE;
conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV3_OUT_SIZE;
compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
- OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
+ compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state);
output_index += DEC_GRU4_OUT_SIZE;
conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
output_index += DEC_CONV4_OUT_SIZE;
compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
- OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
+ compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state);
output_index += DEC_GRU5_OUT_SIZE;
conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized);
compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
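
In each hunk above, a plain OPUS_COPY of the GRU state is replaced by compute_glu(), which writes a gated copy of the state instead. A minimal PyTorch sketch of the operation, mirroring the GLU module added to rdovae.py later in this commit (not the C code itself):

```python
import torch

def glu(h: torch.Tensor, gate_w: torch.Tensor) -> torch.Tensor:
    # Elementwise gate computed from the state itself; the gate layer
    # is a bias-free square Linear, so this is h * sigmoid(W h).
    return h * torch.sigmoid(h @ gate_w.T)
```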
diff --git a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py
index 43fb28d4..2600cd01 100644
--- a/dnn/torch/lpcnet/utils/sparsification/common.py
+++ b/dnn/torch/lpcnet/utils/sparsification/common.py
@@ -29,7 +29,7 @@
import torch
-def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
""" sparsifies matrix with specified block size
Parameters:
@@ -118,4 +118,4 @@ def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=Fal
# activations estimated by 10 flops per activation
flops += 30 * hidden_size
- return flops
\ No newline at end of file
+ return flops
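
The signature change above drops the list[int, int] annotation, which raises a TypeError at import time on Python versions below 3.9 (and is malformed typing in any case); behavior is unchanged. For orientation, a hypothetical re-implementation of the block-wise magnitude pruning that sparsify_matrix performs, under the semantics implied by its density and block_size parameters (keep_diagonal omitted for brevity):

```python
import torch

def block_sparsify(matrix: torch.Tensor, density: float, block_size) -> torch.Tensor:
    # Hypothetical sketch; the real implementation is sparsify_matrix()
    # in dnn/torch/lpcnet/utils/sparsification/common.py.
    bh, bw = block_size
    rows, cols = matrix.shape
    blocks = matrix.reshape(rows // bh, bh, cols // bw, bw)
    norms = blocks.pow(2).sum(dim=(1, 3))      # energy of each bh x bw block
    k = int(density * norms.numel())           # number of blocks to keep
    if k >= norms.numel():
        return matrix
    thresh = norms.flatten().kthvalue(norms.numel() - k).values
    mask = (norms > thresh).to(matrix.dtype)[:, None, :, None]
    return (blocks * mask).reshape(rows, cols) # zero out low-energy blocks
```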
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index a7585c9d..3ef9fabd 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -225,10 +225,15 @@ f"""
# decoder
decoder_dense_layers = [
- ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
- ('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
+ ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False),
+ ('core_decoder.module.glu1.gate' , 'dec_glu1', 'TANH', True),
+ ('core_decoder.module.glu2.gate' , 'dec_glu2', 'TANH', True),
+ ('core_decoder.module.glu3.gate' , 'dec_glu3', 'TANH', True),
+ ('core_decoder.module.glu4.gate' , 'dec_glu4', 'TANH', True),
+ ('core_decoder.module.glu5.gate' , 'dec_glu5', 'TANH', True),
+ ('core_decoder.module.output' , 'dec_output', 'LINEAR', True),
('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False),
- ('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True),
+ ('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True),
]
for name, export_name, _, quantize in decoder_dense_layers:
@@ -338,6 +343,13 @@ if __name__ == "__main__":
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+ model.apply(_remove_weight_norm)
+
if len(missing_keys) > 0:
raise ValueError(f"error: missing keys in state dict")
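
The new _remove_weight_norm helper strips the weight-norm parametrization from every submodule that has one (the GLU gates are wrapped in weight_norm, see rdovae.py below), folding the magnitude/direction pair back into a single weight tensor so the export code sees plain matrices. For illustration:

```python
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm

layer = weight_norm(nn.Linear(96, 96, bias=False))
# During training, the effective weight is recomputed from weight_g
# (magnitude) and weight_v (direction) on every forward pass.
remove_weight_norm(layer)              # folds weight_g/weight_v into .weight
assert not hasattr(layer, "weight_g")  # only the plain tensor remains
```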
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 09b6801a..cdb07b46 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -34,6 +34,12 @@ import math as m
import torch
from torch import nn
import torch.nn.functional as F
+import sys
+import os
+source_dir = os.path.split(os.path.abspath(__file__))[0]
+sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
+from utils.sparsification import GRUSparsifier
+from torch.nn.utils import weight_norm
# Quantization and rate related utily functions
@@ -227,6 +233,32 @@ def n(x):
# RDOVAE module and submodules
+sparsify_start = 12000
+sparsify_stop = 24000
+sparsify_interval = 100
+sparsify_exponent = 3
+#sparsify_start = 0
+#sparsify_stop = 0
+
+sparse_params1 = {
+# 'W_hr' : (1.0, [8, 4], True),
+# 'W_hz' : (1.0, [8, 4], True),
+# 'W_hn' : (1.0, [8, 4], True),
+ 'W_ir' : (0.6, [8, 4], False),
+ 'W_iz' : (0.4, [8, 4], False),
+ 'W_in' : (0.8, [8, 4], False)
+ }
+
+sparse_params2 = {
+# 'W_hr' : (1.0, [8, 4], True),
+# 'W_hz' : (1.0, [8, 4], True),
+# 'W_hn' : (1.0, [8, 4], True),
+ 'W_ir' : (0.3, [8, 4], False),
+ 'W_iz' : (0.2, [8, 4], False),
+ 'W_in' : (0.4, [8, 4], False)
+ }
+
+
class MyConv(nn.Module):
def __init__(self, input_dim, output_dim, dilation=1):
super(MyConv, self).__init__()
@@ -239,6 +271,29 @@ class MyConv(nn.Module):
conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
+class GLU(nn.Module):
+ def __init__(self, feat_size):
+ super(GLU, self).__init__()
+
+ torch.manual_seed(5)
+
+ self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+ self.init_weights()
+
+ def init_weights(self):
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+ or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+ nn.init.orthogonal_(m.weight.data)
+
+ def forward(self, x):
+
+ out = x * torch.sigmoid(self.gate(x))
+
+ return out
+
class CoreEncoder(nn.Module):
STATE_HIDDEN = 128
FRAMES_PER_STEP = 2
@@ -355,7 +410,11 @@ class CoreDecoder(nn.Module):
self.gru5 = nn.GRU(608, 96, batch_first=True)
self.conv5 = MyConv(704, 32)
self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
-
+ self.glu1 = GLU(96)
+ self.glu2 = GLU(96)
+ self.glu3 = GLU(96)
+ self.glu4 = GLU(96)
+ self.glu5 = GLU(96)
self.hidden_init = nn.Linear(self.state_size, 128)
self.gru_init = nn.Linear(128, 480)
@@ -363,6 +422,16 @@ class CoreDecoder(nn.Module):
print(f"decoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
+ self.sparsifier = []
+ self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+ self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+ self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+ self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+ self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+
+ def sparsify(self):
+ for sparsifier in self.sparsifier:
+ sparsifier.step()
def forward(self, z, initial_state):
@@ -377,15 +446,15 @@ class CoreDecoder(nn.Module):
# run decoding layer stack
x = n(torch.tanh(self.dense_1(z)))
- x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
+ x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
x = torch.cat([x, n(self.conv1(x))], -1)
- x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
+ x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
x = torch.cat([x, n(self.conv2(x))], -1)
- x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
+ x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
x = torch.cat([x, n(self.conv3(x))], -1)
- x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
+ x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
x = torch.cat([x, n(self.conv4(x))], -1)
- x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
+ x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
x = torch.cat([x, n(self.conv5(x))], -1)
# output layer and reshaping
@@ -490,6 +559,10 @@ class RDOVAE(nn.Module):
if not type(self.weight_clip_fn) == type(None):
self.apply(self.weight_clip_fn)
+ def sparsify(self):
+ #self.core_encoder.module.sparsify()
+ self.core_decoder.module.sparsify()
+
def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
enc_stride = self.enc_stride
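
The constants near the top of the file (sparsify_start=12000, sparsify_stop=24000, interval=100, exponent=3) drive GRUSparsifier, imported from the lpcnet tree, which re-masks the GRU input matrices every 100 optimizer steps. A hypothetical sketch of the kind of polynomial ramp such a schedule typically applies (the actual curve is defined in dnn/torch/lpcnet/utils/sparsification, not reproduced here):

```python
def density_at(step, target, start=12000, stop=24000, exponent=3):
    # Hypothetical schedule: fully dense before `start`, polynomial
    # decay toward `target` between `start` and `stop`, constant after.
    if step < start:
        return 1.0
    if step >= stop:
        return target
    t = (step - start) / (stop - start)
    return target + (1.0 - target) * (1.0 - t) ** exponent
```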
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
index 35c0861c..d9a43b33 100644
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -84,7 +84,7 @@ sequence_length = args.sequence_length
lr_decay_factor = args.lr_decay_factor
split_mode = args.split_mode
# not exposed
-adam_betas = [0.9, 0.99]
+adam_betas = [0.8, 0.95]
adam_eps = 1e-8
checkpoint['batch_size'] = batch_size
@@ -239,6 +239,7 @@ if __name__ == '__main__':
optimizer.step()
model.clip_weights()
+ model.sparsify()
scheduler.step()
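
With this change, mask re-application happens immediately after each optimizer step and weight clip, so pruned blocks cannot drift back to nonzero values between sparsifier updates. Schematically (a sketch of the loop ordering with placeholder names, not the full training script):

```python
for batch in dataloader:                # hypothetical loop skeleton
    optimizer.zero_grad()
    loss = compute_loss(model, batch)   # placeholder for the real loss
    loss.backward()
    optimizer.step()
    model.clip_weights()                # keep weights inside quantization range
    model.sparsify()                    # re-apply the GRU sparsity masks
    scheduler.step()
```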
diff --git a/silk/dred_config.h b/silk/dred_config.h
index 207908fc..86da3b00 100644
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
#define DRED_EXTENSION_ID 126
/* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 7
+#define DRED_EXPERIMENTAL_VERSION 8
#define DRED_EXPERIMENTAL_BYTES 2