diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-15 14:18:30 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:44:27 +0300 |
commit | 2d765480c444107ff5ee4b43293e51738338be46 (patch) | |
tree | ae80e99995a1ce8e5bc3449192dc286f1e52e086 | |
parent | f7d2e78ec93008f6f91352cc8513099ed303e72c (diff) |
Added an option to use residual connections in the NoLACE feature transforms.
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 13 |
1 files changed, 9 insertions, 4 deletions
```diff
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 5ac3cbfa..480fcc55 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -79,7 +79,8 @@ class NoLACE(NNSBase):
                  sparsify=False,
                  sparsification_schedule=[100, 1000, 100],
                  sparsification_density=0.5,
-                 apply_weight_norm=False):
+                 apply_weight_norm=False,
+                 residual_in_feature_transform=False):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -94,6 +95,7 @@ class NoLACE(NNSBase):
         self.numbits_embedding_dim = numbits_embedding_dim
         self.hidden_feature_dim = hidden_feature_dim
         self.partial_lookahead = partial_lookahead
+        self.residual_in_feature_transform = residual_in_feature_transform
 
         if isinstance(sparsification_density, numbers.Number):
             sparsification_density = 10 * [sparsification_density]
@@ -176,9 +178,12 @@ class NoLACE(NNSBase):
         return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
 
     def feature_transform(self, f, layer):
-        f = f.permute(0, 2, 1)
-        f = F.pad(f, [1, 0])
-        f = torch.tanh(layer(f))
+        f0 = f.permute(0, 2, 1)
+        f = F.pad(f0, [1, 0])
+        if self.residual_in_feature_transform:
+            f = torch.tanh(layer(f) + f0)
+        else:
+            f = torch.tanh(layer(f))
         return f.permute(0, 2, 1)
 
     def forward(self, x, features, periods, numbits, debug=False):
```