diff options
Diffstat (limited to 'dnn/torch/osce/models/lace.py')
-rw-r--r-- | dnn/torch/osce/models/lace.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py index 78c1a717..7e8e739c 100644 --- a/dnn/torch/osce/models/lace.py +++ b/dnn/torch/osce/models/lace.py @@ -68,7 +68,9 @@ class LACE(NNSBase): partial_lookahead=True, norm_p=2, softquant=False, - sparsify=False): + sparsify=False, + sparsification_schedule=[10000, 30000, 100], + sparsification_density=0.5): super().__init__(skip=skip, preemph=preemph) @@ -93,7 +95,7 @@ class LACE(NNSBase): # feature net if partial_lookahead: - self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant) + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density) else: self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) @@ -107,7 +109,7 @@ class LACE(NNSBase): self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) if sparsify: - self.sparsify = create_sparsifier(self, 500, 2000, 100) + self.sparsify = create_sparsifier(self, *sparsification_schedule) def flop_count(self, rate=16000, verbose=False): |