diff options
| Field | Value | Date |
|---|---|---|
| author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-08-31 23:32:15 +0300 |
| committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-13 05:50:46 +0300 |
| commit | 4f63743f8f814d608df4b249ac796f8edd15ade0 (patch) | |
| tree | eb64645cc41e223c28e38b72da81135858e312a2 | |
| parent | 1b13f6313e8413056f6d9f1f15fa994d0dff7a57 (diff) | |
explicit signal gain, explicit pitch predictor
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | dnn/torch/fargan/fargan.py | 16 |
| -rw-r--r-- | dnn/torch/fargan/train_fargan.py | 4 |

2 files changed, 11 insertions, 9 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 987cc8e5..d895eba0 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -9,7 +9,7 @@ Fs = 16000 fid_dict = {} def dump_signal(x, filename): - return + #return if filename in fid_dict: fid = fid_dict[filename] else: @@ -162,7 +162,7 @@ class FARGANSub(nn.Module): self.apply(init_weights) - def forward(self, cond, prev, exc_mem, phase, period, states): + def forward(self, cond, prev, exc_mem, phase, period, states, gain=None): device = exc_mem.device #print(cond.shape, prev.shape) @@ -176,7 +176,7 @@ class FARGANSub(nn.Module): dump_signal(prev, 'pitch_exc.f32') dump_signal(exc_mem, 'exc_mem.f32') if self.has_gain: - gain = torch.norm(prev, dim=1, p=2, keepdim=True) + #gain = torch.norm(prev, dim=1, p=2, keepdim=True) prev = prev/(1e-5+gain) prev = torch.cat([prev, torch.log(1e-5+gain)], 1) @@ -193,10 +193,10 @@ class FARGANSub(nn.Module): if self.passthrough_size != 0: passthrough = sig_out[:,self.subframe_size:] sig_out = sig_out[:,:self.subframe_size] - if self.has_gain: - out_gain = torch.exp(self.gain_dense_out(gru3_out)) - sig_out = sig_out * out_gain dump_signal(sig_out, 'exc_out.f32') + if self.has_gain: + pitch_gain = torch.exp(self.gain_dense_out(gru3_out)) + sig_out = (sig_out + pitch_gain*prev[:,:-1]) * gain exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1) dump_signal(sig_out, 'sig_out.f32') return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough) @@ -246,7 +246,9 @@ class FARGAN(nn.Module): phase = torch.cat([preal, pimag], 1) #print("now: ", preal.shape, prev.shape, sig_in.shape) pitch = period[:, 3+n] - out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states) + gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0)) + #gain = gain[:,:,None] + out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain) if n < nb_pre_frames: out = pre[:, 
pos:pos+self.subframe_size] diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py index 117518b6..b58eabb3 100644 --- a/dnn/torch/fargan/train_fargan.py +++ b/dnn/torch/fargan/train_fargan.py @@ -127,9 +127,9 @@ if __name__ == '__main__': sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None) sig = torch.cat([pre, sig], -1) - cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80]) + cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80]) specc_loss = spect_loss(sig, target.detach()) - loss = .2*cont_loss + specc_loss + loss = .00*cont_loss + specc_loss loss.backward() optimizer.step() |