diff options
Diffstat (limited to 'dnn/torch/osce/models/silk_feature_net_pl.py')
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 25 |
1 files changed, 4 insertions, 21 deletions
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py index 89e5cc6b..c766d0ab 100644 --- a/dnn/torch/osce/models/silk_feature_net_pl.py +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -50,9 +50,7 @@ class SilkFeatureNetPL(nn.Module): softquant=False, sparsify=True, sparsification_density=0.5, - apply_weight_norm=False, - repeat_upsamp=False, - repeat_upsamp_dim=16): + apply_weight_norm=False): super(SilkFeatureNetPL, self).__init__() @@ -62,17 +60,12 @@ class SilkFeatureNetPL(nn.Module): self.feature_dim = feature_dim self.num_channels = num_channels self.hidden_feature_dim = hidden_feature_dim - self.repeat_upsamp = repeat_upsamp - self.repeat_upsamp_dim = 16 norm = weight_norm if apply_weight_norm else lambda x, name=None: x self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)) self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)) - if self.repeat_upsamp: - self.upsamp_embedding = nn.Embedding(4, self.repeat_upsamp_dim) - else: - self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4)) + self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4)) gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0') @@ -127,18 +120,8 @@ class SilkFeatureNetPL(nn.Module): c = torch.tanh(self.conv2(F.pad(c, [1, 0]))) # upsampling - if self.repeat_upsamp: - a = torch.arange(num_frames, device=features.device) % 4 - embeddings = torch.repeat_interleave( - torch.tanh(self.upsamp_embedding(a)).unsqueeze(0), - batch_size, - 0 - ) - c = c.permute(0, 2, 1) - c = torch.cat((torch.repeat_interleave(c, 4, 1), embeddings), dim=2) - else: - c = torch.tanh(self.tconv(c)) - c = c.permute(0, 2, 1) + c = torch.tanh(self.tconv(c)) + c = c.permute(0, 2, 1) c, _ = self.gru(c, state) |