diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-26 02:29:30 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-28 23:15:04 +0300 |
commit | b57ddadf9821c52ade6be7b40b4004fcc2c89a17 (patch) | |
tree | 0b1f293934b2525b9538b7d6b9401eb9ea718dae | |
parent | 7e770ffb3ae1931185b3563831868ea946a330d0 (diff) |
Simplifications
-rw-r--r-- | dnn/torch/fargan/fargan.py | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 8988148f..3e67351f 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -151,29 +151,28 @@ class FWConv(nn.Module): return out, xcat[:,self.in_size:] class FARGANCond(nn.Module): - def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64): + def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12): super(FARGANCond, self).__init__() self.feature_dim = feature_dim self.cond_size = cond_size - self.pembed = nn.Embedding(256, pembed_dims) - self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False) - self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) - self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) - self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False) + self.pembed = nn.Embedding(224, pembed_dims) + self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False) + self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False) + self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False) self.apply(init_weights) def forward(self, features, period): - p = self.pembed(period) + p = self.pembed(period-32) features = torch.cat((features, p), -1) tmp = torch.tanh(self.fdense1(features)) tmp = tmp.permute(0, 2, 1) tmp = torch.tanh(self.fconv1(tmp)) tmp = torch.tanh(self.fconv2(tmp)) tmp = tmp.permute(0, 2, 1) - tmp = torch.tanh(self.fdense2(tmp)) + #tmp = torch.tanh(self.fdense2(tmp)) return tmp class FARGANSub(nn.Module): |