diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-15 12:08:50 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-11-15 12:08:50 +0300 |
commit | b0620c0bf9864d9b18ead6b4bb6e0800542a931d (patch) | |
tree | ea669c2226514d7f91989080fe2954623cbe83a8 | |
parent | 58923f61c26ac0f5d8284d427344466e3bc2c674 (diff) |
Using sparse GRUs in DRED decoder
Saves ~270 kB of weights in the decoder
-rwxr-xr-x | autogen.sh | 2 | ||||
-rw-r--r-- | dnn/dred_rdovae_dec.c | 10 | ||||
-rw-r--r-- | dnn/torch/lpcnet/utils/sparsification/common.py | 4 | ||||
-rw-r--r-- | dnn/torch/rdovae/export_rdovae_weights.py | 18 | ||||
-rw-r--r-- | dnn/torch/rdovae/rdovae/rdovae.py | 85 | ||||
-rw-r--r-- | dnn/torch/rdovae/train_rdovae.py | 3 | ||||
-rw-r--r-- | silk/dred_config.h | 2 |
7 files changed, 105 insertions, 19 deletions
@@ -9,7 +9,7 @@ set -e srcdir=`dirname $0` test -n "$srcdir" && cd "$srcdir" -dnn/download_model.sh b6095cf +dnn/download_model.sh 58923f6 echo "Updating build configuration files, please wait...." diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c index 59cc8943..e2b19b14 100644 --- a/dnn/dred_rdovae_dec.c +++ b/dnn/dred_rdovae_dec.c @@ -98,35 +98,35 @@ void dred_rdovae_decode_qframe( output_index += DEC_DENSE1_OUT_SIZE; compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer); - OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE); + compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state); output_index += DEC_GRU1_OUT_SIZE; conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized); compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH); output_index += DEC_CONV1_OUT_SIZE; compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer); - OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE); + compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state); output_index += DEC_GRU2_OUT_SIZE; conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized); compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH); output_index += DEC_CONV2_OUT_SIZE; compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer); - OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE); + compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state); output_index += DEC_GRU3_OUT_SIZE; conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized); compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH); output_index += DEC_CONV3_OUT_SIZE; compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer); - OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE); + compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state); output_index += DEC_GRU4_OUT_SIZE; conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized); compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH); output_index += DEC_CONV4_OUT_SIZE; compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer); - OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE); + compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state); output_index += DEC_GRU5_OUT_SIZE; conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized); compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH); diff --git a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py index 43fb28d4..2600cd01 100644 --- a/dnn/torch/lpcnet/utils/sparsification/common.py +++ b/dnn/torch/lpcnet/utils/sparsification/common.py @@ -29,7 +29,7 @@ import torch -def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False): +def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False): """ sparsifies matrix with specified block size Parameters: @@ -118,4 +118,4 @@ def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=Fal # activations estimated by 10 flops per activation flops += 30 * hidden_size - return flops
\ No newline at end of file + return flops diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index a7585c9d..3ef9fabd 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -225,10 +225,15 @@ f""" # decoder decoder_dense_layers = [ - ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False), - ('core_decoder.module.output' , 'dec_output', 'LINEAR', True), + ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH', False), + ('core_decoder.module.glu1.gate' , 'dec_glu1', 'TANH', True), + ('core_decoder.module.glu2.gate' , 'dec_glu2', 'TANH', True), + ('core_decoder.module.glu3.gate' , 'dec_glu3', 'TANH', True), + ('core_decoder.module.glu4.gate' , 'dec_glu4', 'TANH', True), + ('core_decoder.module.glu5.gate' , 'dec_glu5', 'TANH', True), + ('core_decoder.module.output' , 'dec_output', 'LINEAR', True), ('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH', False), - ('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH', True), + ('core_decoder.module.gru_init' , 'dec_gru_init','TANH', True), ] for name, export_name, _, quantize in decoder_dense_layers: @@ -338,6 +343,13 @@ if __name__ == "__main__": checkpoint = torch.load(args.checkpoint, map_location='cpu') model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False) + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + model.apply(_remove_weight_norm) + if len(missing_keys) > 0: raise ValueError(f"error: missing keys in state dict") diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index 09b6801a..cdb07b46 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -34,6 +34,12 @@ import math as m import torch from torch import nn import torch.nn.functional as F +import sys +import os +source_dir = os.path.split(os.path.abspath(__file__))[0] +sys.path.append(os.path.join(source_dir, "../../lpcnet/")) +from utils.sparsification import GRUSparsifier +from torch.nn.utils import weight_norm # Quantization and rate related utily functions @@ -227,6 +233,32 @@ def n(x): # RDOVAE module and submodules +sparsify_start = 12000 +sparsify_stop = 24000 +sparsify_interval = 100 +sparsify_exponent = 3 +#sparsify_start = 0 +#sparsify_stop = 0 + +sparse_params1 = { +# 'W_hr' : (1.0, [8, 4], True), +# 'W_hz' : (1.0, [8, 4], True), +# 'W_hn' : (1.0, [8, 4], True), + 'W_ir' : (0.6, [8, 4], False), + 'W_iz' : (0.4, [8, 4], False), + 'W_in' : (0.8, [8, 4], False) + } + +sparse_params2 = { +# 'W_hr' : (1.0, [8, 4], True), +# 'W_hz' : (1.0, [8, 4], True), +# 'W_hn' : (1.0, [8, 4], True), + 'W_ir' : (0.3, [8, 4], False), + 'W_iz' : (0.2, [8, 4], False), + 'W_in' : (0.4, [8, 4], False) + } + + class MyConv(nn.Module): def __init__(self, input_dim, output_dim, dilation=1): super(MyConv, self).__init__() @@ -239,6 +271,29 @@ class MyConv(nn.Module): conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1) return torch.tanh(self.conv(conv_in)).permute(0, 2, 1) +class GLU(nn.Module): + def __init__(self, feat_size): + super(GLU, self).__init__() + + torch.manual_seed(5) + + self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False)) + + self.init_weights() + + def init_weights(self): + + for m in self.modules(): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\ + or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding): + nn.init.orthogonal_(m.weight.data) + + def forward(self, x): + + out = x * torch.sigmoid(self.gate(x)) + + return out + class CoreEncoder(nn.Module): STATE_HIDDEN = 128 FRAMES_PER_STEP = 2 @@ -355,7 +410,11 @@ class CoreDecoder(nn.Module): self.gru5 = nn.GRU(608, 96, batch_first=True) self.conv5 = MyConv(704, 32) self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim) - + self.glu1 = GLU(96) + self.glu2 = GLU(96) + self.glu3 = GLU(96) + self.glu4 = GLU(96) + self.glu5 = GLU(96) self.hidden_init = nn.Linear(self.state_size, 128) self.gru_init = nn.Linear(128, 480) @@ -363,6 +422,16 @@ class CoreDecoder(nn.Module): print(f"decoder: {nb_params} weights") # initialize weights self.apply(init_weights) + self.sparsifier = [] + self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) + self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) + self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) + self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) + self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) + + def sparsify(self): + for sparsifier in self.sparsifier: + sparsifier.step() def forward(self, z, initial_state): @@ -377,15 +446,15 @@ class CoreDecoder(nn.Module): # run decoding layer stack x = n(torch.tanh(self.dense_1(z))) - x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1) + x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1) x = torch.cat([x, n(self.conv1(x))], -1) - x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1) + x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1) x = torch.cat([x, n(self.conv2(x))], -1) - x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1) + x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1) x = torch.cat([x, n(self.conv3(x))], -1) - x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1) + x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1) x = torch.cat([x, n(self.conv4(x))], -1) - x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1) + x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1) x = torch.cat([x, n(self.conv5(x))], -1) # output layer and reshaping @@ -490,6 +559,10 @@ class RDOVAE(nn.Module): if not type(self.weight_clip_fn) == type(None): self.apply(self.weight_clip_fn) + def sparsify(self): + #self.core_encoder.module.sparsify() + self.core_decoder.module.sparsify() + def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4): enc_stride = self.enc_stride diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py index 35c0861c..d9a43b33 100644 --- a/dnn/torch/rdovae/train_rdovae.py +++ b/dnn/torch/rdovae/train_rdovae.py @@ -84,7 +84,7 @@ sequence_length = args.sequence_length lr_decay_factor = args.lr_decay_factor split_mode = args.split_mode # not exposed -adam_betas = [0.9, 0.99] +adam_betas = [0.8, 0.95] adam_eps = 1e-8 checkpoint['batch_size'] = batch_size @@ -239,6 +239,7 @@ if __name__ == '__main__': optimizer.step() model.clip_weights() + model.sparsify() scheduler.step() diff --git a/silk/dred_config.h b/silk/dred_config.h index 207908fc..86da3b00 100644 --- a/silk/dred_config.h +++ b/silk/dred_config.h @@ -32,7 +32,7 @@ #define DRED_EXTENSION_ID 126 /* Remove these two completely once DRED gets an extension number assigned. */ -#define DRED_EXPERIMENTAL_VERSION 7 +#define DRED_EXPERIMENTAL_VERSION 8 #define DRED_EXPERIMENTAL_BYTES 2 |