
gitlab.xiph.org/xiph/opus.git
Diffstat (limited to 'dnn/torch/osce/models/silk_feature_net_pl.py')
-rw-r--r--  dnn/torch/osce/models/silk_feature_net_pl.py | 25
1 file changed, 4 insertions(+), 21 deletions(-)
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index 89e5cc6b..c766d0ab 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -50,9 +50,7 @@ class SilkFeatureNetPL(nn.Module):
                  softquant=False,
                  sparsify=True,
                  sparsification_density=0.5,
-                 apply_weight_norm=False,
-                 repeat_upsamp=False,
-                 repeat_upsamp_dim=16):
+                 apply_weight_norm=False):
 
         super(SilkFeatureNetPL, self).__init__()
 
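With the two repeat-upsampling arguments gone, construction reduces to the remaining options. A minimal usage sketch, where the import path and all argument values are assumptions rather than something taken from the commit:

# Hypothetical construction after this change; values are illustrative only.
from models.silk_feature_net_pl import SilkFeatureNetPL  # import path is an assumption

feature_net = SilkFeatureNetPL(
    num_channels=256,          # assumed width
    softquant=False,
    sparsify=True,
    sparsification_density=0.5,
    apply_weight_norm=False,   # repeat_upsamp / repeat_upsamp_dim are no longer accepted
)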
@@ -62,17 +60,12 @@
         self.feature_dim = feature_dim
         self.num_channels = num_channels
         self.hidden_feature_dim = hidden_feature_dim
-        self.repeat_upsamp = repeat_upsamp
-        self.repeat_upsamp_dim = 16
 
         norm = weight_norm if apply_weight_norm else lambda x, name=None: x
 
         self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
         self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
-        if self.repeat_upsamp:
-            self.upsamp_embedding = nn.Embedding(4, self.repeat_upsamp_dim)
-        else:
-            self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
+        self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
 
         gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
         self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
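The single remaining upsampler is the transposed convolution with kernel size 4 and stride 4, i.e. a fixed 4x increase of the frame rate. A standalone sketch of that behaviour, where num_channels=256 and the 10-frame input are assumptions:

import torch
import torch.nn as nn

num_channels = 256                                            # assumed value
tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)

c = torch.randn(1, num_channels, 10)                          # (batch, channels, frames)
print(tconv(c).shape)                                         # torch.Size([1, 256, 40]): 4x more frames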
@@ -127,18 +120,8 @@
         c = torch.tanh(self.conv2(F.pad(c, [1, 0])))
 
         # upsampling
-        if self.repeat_upsamp:
-            a = torch.arange(num_frames, device=features.device) % 4
-            embeddings = torch.repeat_interleave(
-                torch.tanh(self.upsamp_embedding(a)).unsqueeze(0),
-                batch_size,
-                0
-            )
-            c = c.permute(0, 2, 1)
-            c = torch.cat((torch.repeat_interleave(c, 4, 1), embeddings), dim=2)
-        else:
-            c = torch.tanh(self.tconv(c))
-            c = c.permute(0, 2, 1)
+        c = torch.tanh(self.tconv(c))
+        c = c.permute(0, 2, 1)
 
         c, _ = self.gru(c, state)
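Taken together, the retained forward path upsamples with the transposed convolution, moves channels last, and feeds the GRU. A self-contained sketch of just those steps; the shapes, num_channels and the zero initial state are assumptions, and the conv1/conv2 stages and sparsification are omitted:

import torch
import torch.nn as nn

num_channels = 256                                            # assumed value
tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4)
gru = nn.GRU(num_channels, num_channels, batch_first=True)

c = torch.randn(1, num_channels, 10)                          # frame-rate features: (batch, channels, frames)
c = torch.tanh(tconv(c))                                      # upsample time axis by 4 -> (1, 256, 40)
c = c.permute(0, 2, 1)                                        # (batch, time, channels) for the GRU
state = torch.zeros(1, 1, num_channels)                       # stand-in for the state the model passes
c, _ = gru(c, state)
print(c.shape)                                                # torch.Size([1, 40, 256])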