diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-23 19:10:34 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-23 19:10:34 +0300 |
commit | 7df2c67be1a976cf10b7094b289180b1b5bb1c94 (patch) | |
tree | 9d2282270d5f1c90e99a15c8e959f5b59d5ea607 | |
parent | 3499d0aac76d20ba14918cafb8020278154bf2e6 (diff) |
fixes in osce python codeopus-ng
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 5 | ||||
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 2 |
2 files changed, 2 insertions, 5 deletions
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py index 801857a4..78c3a301 100644 --- a/dnn/torch/osce/models/no_lace.py +++ b/dnn/torch/osce/models/no_lace.py @@ -177,10 +177,7 @@ class NoLACE(NNSBase): def feature_transform(self, f, layer): f0 = f.permute(0, 2, 1) f = F.pad(f0, [1, 0]) - if self.residual_in_feature_transform: - f = torch.tanh(layer(f) + f0) - else: - f = torch.tanh(layer(f)) + f = torch.tanh(layer(f)) return f.permute(0, 2, 1) def forward(self, x, features, periods, numbits, debug=False): diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py index 5efa7e70..e6a536fe 100644 --- a/dnn/torch/osce/models/silk_feature_net_pl.py +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -92,7 +92,7 @@ class SilkFeatureNetPL(nn.Module): def flop_count(self, rate=200): count = 0 - for conv in [self.conv1, self.conv2] if self.repeat_upsamp else [self.conv1, self.conv2, self.tconv]: + for conv in self.conv1, self.conv2, self.tconv: count += _conv1d_flop_count(conv, rate) count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate |