Diffstat (limited to 'dnn/torch/osce/models/no_lace.py')
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 77 |
1 file changed, 57 insertions, 20 deletions
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 0e0fb1b3..801857a4 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -27,9 +27,13 @@
  */
 """
 
+import numbers
+
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+
 import numpy as np
 
@@ -43,6 +47,11 @@
 from models.silk_feature_net_pl import SilkFeatureNetPL
 from models.silk_feature_net import SilkFeatureNet
 from .scale_embedding import ScaleEmbedding
 
+import sys
+sys.path.append('../dnntools')
+from dnntools.quantization import soft_quant
+from dnntools.sparsification import create_sparsifier, mark_for_sparsification
+
 class NoLACE(NNSBase):
     """ Non-Linear Adaptive Coding Enhancer """
     FRAME_SIZE=80
@@ -64,11 +73,15 @@ class NoLACE(NNSBase):
                  partial_lookahead=True,
                  norm_p=2,
                  avg_pool_k=4,
-                 pool_after=False):
+                 pool_after=False,
+                 softquant=False,
+                 sparsify=False,
+                 sparsification_schedule=[100, 1000, 100],
+                 sparsification_density=0.5,
+                 apply_weight_norm=False):
 
         super().__init__(skip=skip, preemph=preemph)
 
-
         self.num_features = num_features
         self.cond_dim = cond_dim
         self.pitch_max = pitch_max
@@ -81,6 +94,11 @@ class NoLACE(NNSBase):
         self.hidden_feature_dim = hidden_feature_dim
         self.partial_lookahead = partial_lookahead
 
+        if isinstance(sparsification_density, numbers.Number):
+            sparsification_density = 10 * [sparsification_density]
+
+        norm = weight_norm if apply_weight_norm else lambda x, name=None: x
+
         # pitch embedding
         self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
 
@@ -89,36 +107,52 @@ class NoLACE(NNSBase):
 
         # feature net
         if partial_lookahead:
-            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm)
         else:
             self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
 
         # comb filters
         left_pad = self.kernel_size // 2
         right_pad = self.kernel_size - 1 - left_pad
-        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
-        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
 
         # spectral shaping
-        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
 
         # non-linear transforms
-        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
-        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
-        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
+        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm)
 
         # combinators
-        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
-        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
-        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
+        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm)
 
         # feature transforms
-        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
-        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
-        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
-        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
-        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+        self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+        self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+        self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+        self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2))
+
+        if softquant:
+            self.post_cf1 = soft_quant(self.post_cf1)
+            self.post_cf2 = soft_quant(self.post_cf2)
+            self.post_af1 = soft_quant(self.post_af1)
+            self.post_af2 = soft_quant(self.post_af2)
+            self.post_af3 = soft_quant(self.post_af3)
+
+        if sparsify:
+            mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
+            mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
+            mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
+            mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
+            mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
+
+            self.sparsifier = create_sparsifier(self, *sparsification_schedule)
 
 
     def flop_count(self, rate=16000, verbose=False):
@@ -141,9 +175,12 @@ class NoLACE(NNSBase):
         return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
 
     def feature_transform(self, f, layer):
-        f = f.permute(0, 2, 1)
-        f = F.pad(f, [1, 0])
-        f = torch.tanh(layer(f))
+        f0 = f.permute(0, 2, 1)
+        f = F.pad(f0, [1, 0])
+        if self.residual_in_feature_transform:
+            f = torch.tanh(layer(f) + f0)
+        else:
+            f = torch.tanh(layer(f))
        return f.permute(0, 2, 1)
 
     def forward(self, x, features, periods, numbits, debug=False):
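
Note on the constructor changes: they follow a wrap-then-mark pattern in which each submodule is optionally wrapped in weight_norm, optionally soft-quantized via soft_quant, and optionally tagged with mark_for_sparsification before a single sparsifier is created for the whole model. Below is a minimal, self-contained sketch of the weight-norm switch only; it uses stock PyTorch, since dnntools is a repo-local package (hence the sys.path.append in the diff), and the make_norm helper name is hypothetical.

    import torch
    from torch import nn
    from torch.nn.utils import weight_norm

    def make_norm(apply_weight_norm: bool):
        # Same trick as the diff: either weight_norm itself, or an identity
        # that accepts the same (module, name) call signature.
        return weight_norm if apply_weight_norm else (lambda x, name=None: x)

    norm = make_norm(apply_weight_norm=True)
    conv = norm(nn.Conv1d(256, 256, 2))
    # weight_norm reparameterizes 'weight' into magnitude ('weight_g')
    # and direction ('weight_v') tensors.
    print(hasattr(conv, 'weight_g'))   # True; False when apply_weight_norm=False

Because the identity branch accepts the same call signature as weight_norm, every construction site can call norm(module) unconditionally, which is why the diff can wrap all five post_* convolutions with a single local.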
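The reworked feature_transform adds an optional residual path gated by self.residual_in_feature_transform (the flag is read here but not set anywhere in this diff). The shape bookkeeping is the subtle part: left-padding by one sample makes the kernel-2 Conv1d causal and length-preserving, which is what makes the layer(f) + f0 residual valid. A standalone check, with cond_dim = 256 assumed purely for illustration:

    import torch
    import torch.nn.functional as F
    from torch import nn

    cond_dim, frames = 256, 100             # cond_dim assumed for illustration
    layer = nn.Conv1d(cond_dim, cond_dim, 2)
    f0 = torch.randn(1, cond_dim, frames)   # (batch, channels, frames), post-permute

    # F.pad(f0, [1, 0]) left-pads the time axis by one sample, so the
    # kernel-2 convolution preserves the frame count and the residual
    # add is shape-compatible.
    y = torch.tanh(layer(F.pad(f0, [1, 0])) + f0)
    assert y.shape == f0.shape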