diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-26 04:19:48 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-28 23:15:04 +0300 |
commit | a5d34094ed2764715393e4805e8f77596348a62d (patch) | |
tree | 522d61b35f45989f472cc9c8452d0bce28ad5ec0 | |
parent | b57ddadf9821c52ade6be7b40b4004fcc2c89a17 (diff) |
more simplificationsexp_fargan43g2
-rw-r--r-- | dnn/torch/fargan/fargan.py | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 3e67351f..39b5ddaf 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -185,20 +185,18 @@ class FARGANSub(nn.Module): #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size) - self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) - self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) - self.gru3 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False) + self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False) self.dense1_glu = GLU(self.cond_size) - self.dense2_glu = GLU(self.cond_size) self.gru1_glu = GLU(self.cond_size) - self.gru2_glu = GLU(self.cond_size) - self.gru3_glu = GLU(self.cond_size) + self.gru2_glu = GLU(128) + self.gru3_glu = GLU(128) self.skip_glu = GLU(self.cond_size) #self.ptaps_dense = nn.Linear(4*self.cond_size, 5) - self.skip_dense = nn.Linear(4*self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False) self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False) self.gain_dense_out = nn.Linear(self.cond_size, 4) @@ -228,16 +226,15 @@ class FARGANSub(nn.Module): #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp))) fwc0_out, fwc0_state = self.fwc0(tmp, states[3]) - dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out))) - pitch_gain = torch.sigmoid(self.gain_dense_out(dense2_out)) + pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out)) - gru1_state = self.gru1(torch.cat([dense2_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0]) + gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0]) gru1_out = self.gru1_glu(gru1_state) gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1]) gru2_out = self.gru2_glu(gru2_state) gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2]) gru3_out = self.gru3_glu(gru3_state) - gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1) + gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1) skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1))) skip_out = self.skip_glu(skip_out) sig_out = torch.tanh(self.sig_dense_out(skip_out)) @@ -278,8 +275,8 @@ class FARGAN(nn.Module): states = ( torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, 128, device=device), + torch.zeros(batch_size, 128, device=device), torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device) ) |