Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2024-01-15 14:18:30 +0300
committerJan Buethe <jbuethe@amazon.de>2024-01-19 18:44:27 +0300
commit2d765480c444107ff5ee4b43293e51738338be46 (patch)
treeae80e99995a1ce8e5bc3449192dc286f1e52e086
parentf7d2e78ec93008f6f91352cc8513099ed303e72c (diff)
added option of having residual connections in NoLACE feature transforms
-rw-r--r--dnn/torch/osce/models/no_lace.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 5ac3cbfa..480fcc55 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -79,7 +79,8 @@ class NoLACE(NNSBase):
sparsify=False,
sparsification_schedule=[100, 1000, 100],
sparsification_density=0.5,
- apply_weight_norm=False):
+ apply_weight_norm=False,
+ residual_in_feature_transform=False):
super().__init__(skip=skip, preemph=preemph)
@@ -94,6 +95,7 @@ class NoLACE(NNSBase):
self.numbits_embedding_dim = numbits_embedding_dim
self.hidden_feature_dim = hidden_feature_dim
self.partial_lookahead = partial_lookahead
+ self.residual_in_feature_transform = residual_in_feature_transform
if isinstance(sparsification_density, numbers.Number):
sparsification_density = 10 * [sparsification_density]
@@ -176,9 +178,12 @@ class NoLACE(NNSBase):
return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
def feature_transform(self, f, layer):
- f = f.permute(0, 2, 1)
- f = F.pad(f, [1, 0])
- f = torch.tanh(layer(f))
+ f0 = f.permute(0, 2, 1)
+ f = F.pad(f0, [1, 0])
+ if self.residual_in_feature_transform:
+ f = torch.tanh(layer(f) + f0)
+ else:
+ f = torch.tanh(layer(f))
return f.permute(0, 2, 1)
def forward(self, x, features, periods, numbits, debug=False):