Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/osce/models/lace.py')
-rw-r--r--dnn/torch/osce/models/lace.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 78c1a717..7e8e739c 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -68,7 +68,9 @@ class LACE(NNSBase):
partial_lookahead=True,
norm_p=2,
softquant=False,
- sparsify=False):
+ sparsify=False,
+ sparsification_schedule=[10000, 30000, 100],
+ sparsification_density=0.5):
super().__init__(skip=skip, preemph=preemph)
@@ -93,7 +95,7 @@ class LACE(NNSBase):
# feature net
if partial_lookahead:
- self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+ self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
@@ -107,7 +109,7 @@ class LACE(NNSBase):
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant)
if sparsify:
- self.sparsify = create_sparsifier(self, 500, 2000, 100)
+ self.sparsify = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):