diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-05 19:16:45 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-13 05:50:47 +0300 |
commit | d54b9fb49af339c8ee72a8f54ee7e5beadbd724f (patch) | |
tree | 1d7a24f2122b339ee714dab6ff8930ca93c032fb | |
parent | fb570ed8bb2648e07e84faf40f30d93b7a0311d7 (diff) |
Adds skip connections
-rw-r--r-- | dnn/torch/fargan/fargan.py | 35 |
1 file changed, 20 insertions, 15 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 2d826da0..daa13f17 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -140,7 +140,7 @@ class FARGANSub(nn.Module): print("has_gain:", self.has_gain) print("passthrough_size:", self.passthrough_size) - self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size+4, self.cond_size, bias=False) + self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) @@ -151,11 +151,11 @@ class FARGANSub(nn.Module): self.gru1_glu = GLU(self.cond_size) self.gru2_glu = GLU(self.cond_size) self.gru3_glu = GLU(self.cond_size) - self.ptaps_dense = nn.Linear(self.cond_size, 5) + self.ptaps_dense = nn.Linear(4*self.cond_size, 5) - self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False) + self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size+self.passthrough_size, bias=False) if self.has_gain: - self.gain_dense_out = nn.Linear(self.cond_size, 1) + self.gain_dense_out = nn.Linear(4*self.cond_size, 1) self.apply(init_weights) @@ -173,30 +173,35 @@ class FARGANSub(nn.Module): pred = pred/(1e-5+gain) prev = prev/(1e-5+gain) - #prev = prev*0 dump_signal(prev, 'pitch_exc.f32') dump_signal(exc_mem, 'exc_mem.f32') passthrough = states[3] - tmp = torch.cat((cond, pred, prev, passthrough, phase), 1) + tmp = torch.cat((cond, pred[:,2:-2], prev, passthrough, phase), 1) tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp))) - tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp))) - gru1_state = self.gru1(tmp, states[0]) - gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1]) - gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2]) + dense2_out = 
self.dense2_glu(torch.tanh(self.sig_dense2(tmp))) + gru1_state = self.gru1(dense2_out, states[0]) + gru1_out = self.gru1_glu(gru1_state) + #gru1_out = torch.cat([gru1_out, fpitch], 1) + gru2_state = self.gru2(gru1_out, states[1]) + gru2_out = self.gru2_glu(gru2_state) + #gru2_out = torch.cat([gru2_out, fpitch], 1) + gru3_state = self.gru3(gru2_out, states[2]) gru3_out = self.gru3_glu(gru3_state) + gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1) sig_out = torch.tanh(self.sig_dense_out(gru3_out)) if self.passthrough_size != 0: passthrough = sig_out[:,self.subframe_size:] sig_out = sig_out[:,:self.subframe_size] dump_signal(sig_out, 'exc_out.f32') + taps = self.ptaps_dense(gru3_out) + taps = .2*taps + torch.exp(taps) + taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True)) + dump_signal(taps, 'taps.f32') + fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] + if self.has_gain: - taps = self.ptaps_dense(gru3_out) - taps = .2*taps + torch.exp(taps) - taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True)) - dump_signal(taps, 'taps.f32') - fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] pitch_gain = torch.exp(self.gain_dense_out(gru3_out)) dump_signal(pitch_gain, 'pgain.f32') sig_out = (sig_out + pitch_gain*fpitch) * gain |