diff options
Diffstat (limited to 'dnn/torch/osce/models/lace.py')
-rw-r--r-- | dnn/torch/osce/models/lace.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py index 58293de4..51d65c3e 100644 --- a/dnn/torch/osce/models/lace.py +++ b/dnn/torch/osce/models/lace.py @@ -41,6 +41,12 @@ from models.silk_feature_net_pl import SilkFeatureNetPL from models.silk_feature_net import SilkFeatureNet from .scale_embedding import ScaleEmbedding +import sys +sys.path.append('../dnntools') + +from dnntools.sparsification import create_sparsifier + + class LACE(NNSBase): """ Linear-Adaptive Coding Enhancer """ FRAME_SIZE=80 @@ -60,7 +66,12 @@ class LACE(NNSBase): numbits_embedding_dim=8, hidden_feature_dim=64, partial_lookahead=True, - norm_p=2): + norm_p=2, + softquant=False, + sparsify=False, + sparsification_schedule=[10000, 30000, 100], + sparsification_density=0.5, + apply_weight_norm=False): super().__init__(skip=skip, preemph=preemph) @@ -85,18 +96,21 @@ class LACE(NNSBase): # feature net if partial_lookahead: - self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim) + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm) else: self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) # comb filters left_pad = self.kernel_size // 2 right_pad = self.kernel_size - 1 - left_pad - self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) - self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) # spectral shaping - self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) + + if sparsify: + self.sparsifier = create_sparsifier(self, *sparsification_schedule) def flop_count(self, rate=16000, verbose=False): |