diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-12 06:28:52 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-13 05:50:48 +0300 |
commit | 2f8b36d691a3802714a54abd7409234e41ec3e21 (patch) | |
tree | 0ab66685e1e0295671d13b5332fbfe28aaf176f7 | |
parent | 72c5ea4129dc6473fb1d82ef3ec3d5714fab8cbd (diff) |
Add conditioning interpolation, fwconv layer

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | dnn/torch/fargan/fargan.py | 56 |
| -rw-r--r-- | dnn/torch/fargan/train_fargan.py | 2 |

2 files changed, 41 insertions, 17 deletions

```diff
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index 952c1b84..b532f268 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -101,6 +101,31 @@ class GLU(nn.Module):
 
         return out
 
+class FWConv(nn.Module):
+    def __init__(self, in_size, out_size, kernel_size=3):
+        super(FWConv, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.in_size = in_size
+        self.kernel_size = kernel_size
+        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
+        self.glu = GLU(out_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, state):
+        xcat = torch.cat((state, x), -1)
+        #print(x.shape, state.shape, xcat.shape, self.in_size, self.kernel_size)
+        out = self.glu(torch.tanh(self.conv(xcat)))
+        return out, xcat[:,self.in_size:]
 
 class FARGANCond(nn.Module):
     def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
@@ -113,7 +138,7 @@ class FARGANCond(nn.Module):
         self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
         self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
         self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
-        self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+        self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
 
         self.apply(init_weights)
 
@@ -138,9 +163,10 @@ class FARGANSub(nn.Module):
         self.has_gain = has_gain
         self.passthrough_size = passthrough_size
 
-        print("has_gain:", self.has_gain)
-        print("passthrough_size:", self.passthrough_size)
-        self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
+        #print("has_gain:", self.has_gain)
+        #print("passthrough_size:", self.passthrough_size)
+        #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
+        self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
         self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
         self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
         self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
@@ -176,30 +202,26 @@ class FARGANSub(nn.Module):
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
 
-        passthrough = states[3]
-        tmp = torch.cat((cond, pred[:,2:-2], prev, passthrough, phase), 1)
+        tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
 
-        tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
-        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
+        #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
+        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
+        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
         gru1_state = self.gru1(dense2_out, states[0])
         gru1_out = self.gru1_glu(gru1_state)
-        #gru1_out = torch.cat([gru1_out, fpitch], 1)
         gru2_state = self.gru2(gru1_out, states[1])
         gru2_out = self.gru2_glu(gru2_state)
-        #gru2_out = torch.cat([gru2_out, fpitch], 1)
         gru3_state = self.gru3(gru2_out, states[2])
         gru3_out = self.gru3_glu(gru3_state)
         gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
         sig_out = torch.tanh(self.sig_dense_out(gru3_out))
-        if self.passthrough_size != 0:
-            passthrough = sig_out[:,self.subframe_size:]
-            sig_out = sig_out[:,:self.subframe_size]
         dump_signal(sig_out, 'exc_out.f32')
         taps = self.ptaps_dense(gru3_out)
         taps = .2*taps + torch.exp(taps)
         taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
         dump_signal(taps, 'taps.f32')
-        fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        fpitch = pred[:,2:-2]
 
         if self.has_gain:
             pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
@@ -207,7 +229,7 @@ class FARGANSub(nn.Module):
             sig_out = (sig_out + pitch_gain*fpitch) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
-        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
+        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
     def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -239,7 +261,7 @@ class FARGAN(nn.Module):
             torch.zeros(batch_size, self.cond_size, device=device),
             torch.zeros(batch_size, self.cond_size, device=device),
             torch.zeros(batch_size, self.cond_size, device=device),
-            torch.zeros(batch_size, self.passthrough_size, device=device)
+            torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
         )
 
         sig = torch.zeros((batch_size, 0), device=device)
@@ -259,7 +281,7 @@ class FARGAN(nn.Module):
                 pitch = period[:, 3+n]
                 gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                 #gain = gain[:,:,None]
-                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain)
+                out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
 
                 if n < nb_pre_frames:
                     out = pre[:, pos:pos+self.subframe_size]
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index 852b87be..3904253e 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -121,6 +121,8 @@ if __name__ == '__main__':
                 if (np.random.rand() > 0.1):
                     target = target[:, :sequence_length*160]
                     lpc = lpc[:,:sequence_length,:]
+                    features = features[:,:sequence_length+4,:]
+                    periods = periods[:,:sequence_length+4]
                 else:
                     target=target[::2, :]
                     lpc=lpc[::2,:]
```