diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-25 18:20:44 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-28 23:15:03 +0300 |
commit | 7e770ffb3ae1931185b3563831868ea946a330d0 (patch) | |
tree | 1f27abf621664861d02c515c0a1cb1bad2a59405 | |
parent | 5fd10ee92112f6bfc62c599578f91717ab915d9b (diff) |
remove phase
-rw-r--r-- | dnn/torch/fargan/fargan.py | 26 |
1 files changed, 10 insertions, 16 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 66a20fbc..8988148f 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -185,7 +185,7 @@ class FARGANSub(nn.Module): self.cond_size = cond_size #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) - self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size) + self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size) self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) @@ -206,11 +206,10 @@ class FARGANSub(nn.Module): self.apply(init_weights) - def forward(self, cond, prev, exc_mem, phase, period, states, gain=None): + def forward(self, cond, prev_pred, exc_mem, period, states, gain=None): device = exc_mem.device #print(cond.shape, prev.shape) - dump_signal(prev, 'prev_in.f32') idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254) rng = torch.arange(self.subframe_size+4, device=device) @@ -218,11 +217,13 @@ class FARGANSub(nn.Module): pred = torch.gather(exc_mem, 1, idx) pred = pred/(1e-5+gain) + prev = exc_mem[:,-self.subframe_size:] + dump_signal(prev, 'prev_in.f32') prev = prev/(1e-5+gain) dump_signal(prev, 'pitch_exc.f32') dump_signal(exc_mem, 'exc_mem.f32') - tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1) + tmp = torch.cat((cond, pred, prev), 1) #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] fpitch = pred[:,2:-2] @@ -251,8 +252,9 @@ class FARGANSub(nn.Module): #sig_out = (sig_out + pitch_gain*fpitch) * gain sig_out = sig_out * gain exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1) + prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1) dump_signal(sig_out, 'sig_out.f32') - return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state) + return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state) class FARGAN(nn.Module): def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None): @@ -271,10 +273,7 @@ class FARGAN(nn.Module): device = features.device batch_size = features.size(0) - phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size) - #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw') - - prev = torch.zeros(batch_size, self.subframe_size, device=device) + prev = torch.zeros(batch_size, 256, device=device) exc_mem = torch.zeros(batch_size, 256, device=device) nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0 @@ -282,26 +281,22 @@ class FARGAN(nn.Module): torch.zeros(batch_size, self.cond_size, device=device), torch.zeros(batch_size, self.cond_size, device=device), torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device) + torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device) ) sig = torch.zeros((batch_size, 0), device=device) cond = self.cond_net(features, period) if pre is not None: - prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size] exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size] start = 1 if nb_pre_frames>0 else 0 for n in range(start, nb_frames+nb_pre_frames): for k in range(self.nb_subframes): pos = n*self.frame_size + k*self.subframe_size - preal = phase_real[:, pos:pos+self.subframe_size] - pimag = phase_imag[:, pos:pos+self.subframe_size] - phase = torch.cat([preal, pimag], 1) #print("now: ", preal.shape, prev.shape, sig_in.shape) pitch = period[:, 3+n] gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0)) #gain = gain[:,:,None] - out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain) + out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain) if n < nb_pre_frames: out = pre[:, pos:pos+self.subframe_size] @@ -309,6 +304,5 @@ class FARGAN(nn.Module): else: sig = torch.cat([sig, out], 1) - prev = out states = [s.detach() for s in states] return sig, states |