Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/osce/models/no_lace.py')
-rw-r--r--dnn/torch/osce/models/no_lace.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 19bc6caf..5ac3cbfa 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -27,6 +27,8 @@
*/
"""
+import numbers
+
import torch
from torch import nn
import torch.nn.functional as F
@@ -93,6 +95,9 @@ class NoLACE(NNSBase):
self.hidden_feature_dim = hidden_feature_dim
self.partial_lookahead = partial_lookahead
+ if isinstance(sparsification_density, numbers.Number):
+ sparsification_density = 10 * [sparsification_density]
+
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
# pitch embedding
@@ -142,11 +147,11 @@ class NoLACE(NNSBase):
if sparsify:
- mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
+ mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
+ mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
+ mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
+ mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
self.sparsifier = create_sparsifier(self, *sparsification_schedule)