diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-22 19:57:04 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-28 06:58:23 +0300 |
commit | b53d0ca7ef9f28b323b6cc2c98447d56824b9e4f (patch) | |
tree | 5c1fc6d80609fd1f3b4eea46fccb12cb45ba2d2e | |
parent | 25c65a0c0b9ce8282cfc713a7c0581664c93ab18 (diff) |
version 36
-rw-r--r-- | dnn/torch/fargan/fargan.py | 42 | ||||
-rw-r--r-- | dnn/torch/fargan/stft_loss.py | 4 | ||||
-rw-r--r-- | dnn/torch/fargan/train_fargan.py | 4 |
3 files changed, 29 insertions, 21 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index e9cc687a..7f405e87 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -164,19 +164,21 @@ class FARGANSub(nn.Module): #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size) self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) - self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) - self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) - self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) + self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.gru3 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) self.dense1_glu = GLU(self.cond_size) self.dense2_glu = GLU(self.cond_size) self.gru1_glu = GLU(self.cond_size) self.gru2_glu = GLU(self.cond_size) self.gru3_glu = GLU(self.cond_size) - self.ptaps_dense = nn.Linear(4*self.cond_size, 5) + self.skip_glu = GLU(self.cond_size) + #self.ptaps_dense = nn.Linear(4*self.cond_size, 5) - self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False) - self.gain_dense_out = nn.Linear(4*self.cond_size, 1) + self.skip_dense = nn.Linear(4*self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False) + self.gain_dense_out = nn.Linear(self.cond_size, 4) self.apply(init_weights) @@ -198,29 +200,33 @@ class FARGANSub(nn.Module): dump_signal(exc_mem, 'exc_mem.f32') tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1) + #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] + fpitch = pred[:,2:-2] #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp))) fwc0_out, fwc0_state = self.fwc0(tmp, states[3]) dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out))) - gru1_state = self.gru1(dense2_out, states[0]) + pitch_gain = torch.sigmoid(self.gain_dense_out(dense2_out)) + + gru1_state = self.gru1(torch.cat([dense2_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0]) gru1_out = self.gru1_glu(gru1_state) - gru2_state = self.gru2(gru1_out, states[1]) + gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1]) gru2_out = self.gru2_glu(gru2_state) - gru3_state = self.gru3(gru2_out, states[2]) + gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2]) gru3_out = self.gru3_glu(gru3_state) gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1) - sig_out = torch.tanh(self.sig_dense_out(gru3_out)) + skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1))) + skip_out = self.skip_glu(skip_out) + sig_out = torch.tanh(self.sig_dense_out(skip_out)) dump_signal(sig_out, 'exc_out.f32') - taps = self.ptaps_dense(gru3_out) - taps = .2*taps + torch.exp(taps) - taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True)) - dump_signal(taps, 'taps.f32') - #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] - fpitch = pred[:,2:-2] + #taps = self.ptaps_dense(gru3_out) + #taps = .2*taps + torch.exp(taps) + #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True)) + #dump_signal(taps, 'taps.f32') - pitch_gain = torch.exp(self.gain_dense_out(gru3_out)) dump_signal(pitch_gain, 'pgain.f32') - sig_out = (sig_out + pitch_gain*fpitch) * gain + #sig_out = (sig_out + pitch_gain*fpitch) * gain + sig_out = sig_out * gain exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1) dump_signal(sig_out, 'sig_out.f32') return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state) diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py index accf2f4a..542d0e9b 100644 --- a/dnn/torch/fargan/stft_loss.py +++ b/dnn/torch/fargan/stft_loss.py @@ -44,7 +44,9 @@ class SpectralConvergenceLoss(torch.nn.Module): Returns: Tensor: Spectral convergence loss value. """ - return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + x_mag = torch.sqrt(x_mag) + y_mag = torch.sqrt(y_mag) + return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1) class LogSTFTMagnitudeLoss(torch.nn.Module): """Log STFT magnitude loss module.""" diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py index 4ab20045..be2a1a66 100644 --- a/dnn/torch/fargan/train_fargan.py +++ b/dnn/torch/fargan/train_fargan.py @@ -135,9 +135,9 @@ if __name__ == '__main__': sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None) sig = torch.cat([pre, sig], -1) - cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80]) + cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160]) specc_loss = spect_loss(sig, target.detach()) - loss = .00*cont_loss + specc_loss + loss = .03*cont_loss + specc_loss loss.backward() optimizer.step() |