Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2024-01-08 14:00:49 +0300
committerJan Buethe <jbuethe@amazon.de>2024-01-08 14:05:04 +0300
commit999ddbe09ca1e521cc8d71005c1ecd8513e47611 (patch)
tree85022704a9352cb03751863f9a1ae788afcc805b /dnn/torch/osce/models/no_lace.py
parente968878f06fb96e022f09c0fab60b88ab3f3ac81 (diff)
more sparsification stuff
Diffstat (limited to 'dnn/torch/osce/models/no_lace.py')
-rw-r--r--dnn/torch/osce/models/no_lace.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 5654db6c..7e021930 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -72,7 +72,9 @@ class NoLACE(NNSBase):
avg_pool_k=4,
pool_after=False,
softquant=False,
- sparsify=False):
+ sparsify=False,
+ sparsification_schedule=[100, 1000, 100],
+ sparsification_density=0.5):
super().__init__(skip=skip, preemph=preemph)
@@ -97,7 +99,7 @@ class NoLACE(NNSBase):
# feature net
if partial_lookahead:
- self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+ self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
else:
self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
@@ -136,13 +138,13 @@ class NoLACE(NNSBase):
if sparsify:
- mark_for_sparsification(self.post_cf1, (0.25, [8, 4]))
- mark_for_sparsification(self.post_cf2, (0.25, [8, 4]))
- mark_for_sparsification(self.post_af1, (0.25, [8, 4]))
- mark_for_sparsification(self.post_af2, (0.25, [8, 4]))
- mark_for_sparsification(self.post_af3, (0.25, [8, 4]))
+ mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4]))
- self.sparsify = create_sparsifier(self, 500, 1000, 100)
+ self.sparsifier = create_sparsifier(self, *sparsification_schedule)
def flop_count(self, rate=16000, verbose=False):