diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-10-10 07:51:57 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-10-10 07:51:57 +0300 |
commit | 9e76a7bfb835ebe7cb97cf24da98462b78de0207 (patch) | |
tree | dd209e366796acd6130ed59d219daae1a0fbfb4c | |
parent | d1c5b32add990473df84e42a8db64851b2dd65f6 (diff) |
update fargan to match version 45
-rw-r--r-- | dnn/torch/fargan/adv_train_fargan.py | 25 | ||||
-rw-r--r-- | dnn/torch/fargan/dataset.py | 8 | ||||
-rw-r--r-- | dnn/torch/fargan/fargan.py | 156 | ||||
-rw-r--r-- | dnn/torch/fargan/rc.py | 29 | ||||
-rw-r--r-- | dnn/torch/fargan/stft_loss.py | 14 | ||||
-rw-r--r-- | dnn/torch/fargan/test_fargan.py | 27 | ||||
-rw-r--r-- | dnn/torch/fargan/train_fargan.py | 17 |
7 files changed, 194 insertions, 82 deletions
diff --git a/dnn/torch/fargan/adv_train_fargan.py b/dnn/torch/fargan/adv_train_fargan.py index 23f5b2d0..94817cbc 100644 --- a/dnn/torch/fargan/adv_train_fargan.py +++ b/dnn/torch/fargan/adv_train_fargan.py @@ -132,6 +132,10 @@ states = None spect_loss = MultiResolutionSTFTLoss(device).to(device) +for param in model.parameters(): + param.requires_grad = False + +batch_count = 0 if __name__ == '__main__': model.to(device) disc.to(device) @@ -153,22 +157,28 @@ if __name__ == '__main__': print(f"training epoch {epoch}...") with tqdm.tqdm(dataloader, unit='batch') as tepoch: for i, (features, periods, target, lpc) in enumerate(tepoch): + if epoch == 1 and i == 400: + for param in model.parameters(): + param.requires_grad = True + optimizer.zero_grad() features = features.to(device) - lpc = lpc.to(device) + #lpc = lpc.to(device) + #lpc = lpc*(args.gamma**torch.arange(1,17, device=device)) + #lpc = fargan.interp_lpc(lpc, 4) periods = periods.to(device) if True: target = target[:, :sequence_length*160] - lpc = lpc[:,:sequence_length,:] + #lpc = lpc[:,:sequence_length*4,:] features = features[:,:sequence_length+4,:] periods = periods[:,:sequence_length+4] else: target=target[::2, :] - lpc=lpc[::2,:] + #lpc=lpc[::2,:] features=features[::2,:] periods=periods[::2,:] target = target.to(device) - target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma) + #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma) #nb_pre = random.randrange(1, 6) nb_pre = 2 @@ -208,7 +218,7 @@ if __name__ == '__main__': cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80]) specc_loss = spect_loss(output, target.detach()) - reg_loss = args.reg_weight * (.00*cont_loss + specc_loss) + reg_loss = (.00*cont_loss + specc_loss) loss_gen = 0 for scale in scores_gen: @@ -216,7 +226,8 @@ if __name__ == '__main__': feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen) - gen_loss = reg_loss + feat_loss + loss_gen + reg_weight = args.reg_weight + 15./(1 + (batch_count/7600.)) + gen_loss = reg_weight * reg_loss + feat_loss + loss_gen model.zero_grad() @@ -238,12 +249,14 @@ if __name__ == '__main__': tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}", + reg_weight=f"{reg_weight:8.5f}", gen_loss=f"{running_gen_loss/(i+1):8.5f}", disc_loss=f"{running_disc_loss/(i+1):8.5f}", fmap_loss=f"{running_fmap_loss/(i+1):8.5f}", reg_loss=f"{running_reg_loss/(i+1):8.5f}", wc = f"{running_wc/(i+1):8.5f}", ) + batch_count = batch_count + 1 # save checkpoint checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth') diff --git a/dnn/torch/fargan/dataset.py b/dnn/torch/fargan/dataset.py index 6195c6af..2dfbb0b5 100644 --- a/dnn/torch/fargan/dataset.py +++ b/dnn/torch/fargan/dataset.py @@ -1,5 +1,6 @@ import torch import numpy as np +import fargan class FARGANDataset(torch.utils.data.Dataset): def __init__(self, @@ -34,7 +35,8 @@ class FARGANDataset(torch.utils.data.Dataset): sizeof = self.features.strides[-1] self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features), strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof)) - self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int') + #self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int') + self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int') self.lpc = self.features[:, :, self.nb_used_features:] self.features = self.features[:, :, :self.nb_used_features] @@ -51,5 +53,9 @@ class FARGANDataset(torch.utils.data.Dataset): lpc = self.lpc[index, 4:, :].copy() data = self.data[index, :].copy().astype(np.float32) / 2**15 periods = self.periods[index, :].copy() + #lpc = lpc*(self.gamma**np.arange(1,17)) + #lpc=lpc[None,:,:] + #lpc = fargan.interp_lpc(lpc, 4) + #lpc=lpc[0,:,:] return features, periods, data, lpc diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index e9cc687a..65f0a97b 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -4,6 +4,8 @@ from torch import nn import torch.nn.functional as F import filters from torch.nn.utils import weight_norm +#from convert_lsp import lpc_to_lsp, lsp_to_lpc +from rc import lpc2rc, rc2lpc Fs = 16000 @@ -27,6 +29,27 @@ def sig_loss(y_true, y_pred): p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True)) return torch.mean(1.-torch.sum(p*t, dim=-1)) +def interp_lpc(lpc, factor): + #print(lpc.shape) + #f = (np.arange(factor)+.5*((factor+1)%2))/factor + lsp = torch.atanh(lpc2rc(lpc)) + #print("lsp0:") + #print(lsp) + shape = lsp.shape + #print("shape is", shape) + shape = (shape[0], shape[1]*factor, shape[2]) + interp_lsp = torch.zeros(shape, device=lpc.device) + for k in range(factor): + f = (k+.5*((factor+1)%2))/factor + interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:] + interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp + for k in range(factor//2): + interp_lsp[:,k,:] = interp_lsp[:,factor//2,:] + for k in range((factor+1)//2): + interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:] + #print("lsp:") + #print(interp_lsp) + return rc2lpc(torch.tanh(interp_lsp)) def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9): device = x.device @@ -39,9 +62,9 @@ def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9): x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size)) out = torch.zeros((batch_size, 0), device=device) - if gamma is not None: - bw = gamma**(torch.arange(1, 17, device=device)) - lpc = lpc*bw[None,None,:] + #if gamma is not None: + # bw = gamma**(torch.arange(1, 17, device=device)) + # lpc = lpc*bw[None,None,:] ones = torch.ones((*(lpc.shape[:-1]), 1), device=device) zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device) a = torch.cat([ones, lpc], -1) @@ -127,30 +150,34 @@ class FWConv(nn.Module): out = self.glu(torch.tanh(self.conv(xcat))) return out, xcat[:,self.in_size:] +def n(x): + return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.) + class FARGANCond(nn.Module): - def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64): + def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12): super(FARGANCond, self).__init__() self.feature_dim = feature_dim self.cond_size = cond_size - self.pembed = nn.Embedding(256, pembed_dims) - self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False) - self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) - self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) - self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False) + self.pembed = nn.Embedding(224, pembed_dims) + self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False) + self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False) + self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False) self.apply(init_weights) + nb_params = sum(p.numel() for p in self.parameters()) + print(f"cond model: {nb_params} weights") def forward(self, features, period): - p = self.pembed(period) + p = self.pembed(period-32) features = torch.cat((features, p), -1) tmp = torch.tanh(self.fdense1(features)) tmp = tmp.permute(0, 2, 1) tmp = torch.tanh(self.fconv1(tmp)) tmp = torch.tanh(self.fconv2(tmp)) tmp = tmp.permute(0, 2, 1) - tmp = torch.tanh(self.fdense2(tmp)) + #tmp = torch.tanh(self.fdense2(tmp)) return tmp class FARGANSub(nn.Module): @@ -160,70 +187,87 @@ class FARGANSub(nn.Module): self.subframe_size = subframe_size self.nb_subframes = nb_subframes self.cond_size = cond_size + self.cond_gain_dense = nn.Linear(80, 1) #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) - self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size) - self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) - self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) - self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) - self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) + self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size) + self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False) + self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False) self.dense1_glu = GLU(self.cond_size) - self.dense2_glu = GLU(self.cond_size) self.gru1_glu = GLU(self.cond_size) - self.gru2_glu = GLU(self.cond_size) - self.gru3_glu = GLU(self.cond_size) - self.ptaps_dense = nn.Linear(4*self.cond_size, 5) + self.gru2_glu = GLU(128) + self.gru3_glu = GLU(128) + self.skip_glu = GLU(self.cond_size) + #self.ptaps_dense = nn.Linear(4*self.cond_size, 5) - self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False) - self.gain_dense_out = nn.Linear(4*self.cond_size, 1) + self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False) + self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False) + self.gain_dense_out = nn.Linear(self.cond_size, 4) self.apply(init_weights) + nb_params = sum(p.numel() for p in self.parameters()) + print(f"subframe model: {nb_params} weights") - def forward(self, cond, prev, exc_mem, phase, period, states, gain=None): + def forward(self, cond, prev_pred, exc_mem, period, states, gain=None): device = exc_mem.device #print(cond.shape, prev.shape) - dump_signal(prev, 'prev_in.f32') - - idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254) + cond = n(cond) + dump_signal(gain, 'gain0.f32') + gain = torch.exp(self.cond_gain_dense(cond)) + dump_signal(gain, 'gain1.f32') + idx = 256-period[:,None] rng = torch.arange(self.subframe_size+4, device=device) idx = idx + rng[None,:] - 2 + mask = idx >= 256 + idx = idx - mask*period[:,None] pred = torch.gather(exc_mem, 1, idx) - pred = pred/(1e-5+gain) + pred = n(pred/(1e-5+gain)) - prev = prev/(1e-5+gain) + prev = exc_mem[:,-self.subframe_size:] + dump_signal(prev, 'prev_in.f32') + prev = n(prev/(1e-5+gain)) dump_signal(prev, 'pitch_exc.f32') dump_signal(exc_mem, 'exc_mem.f32') - tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1) + tmp = torch.cat((cond, pred, prev), 1) + #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] + fpitch = pred[:,2:-2] #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp))) fwc0_out, fwc0_state = self.fwc0(tmp, states[3]) - dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out))) - gru1_state = self.gru1(dense2_out, states[0]) - gru1_out = self.gru1_glu(gru1_state) - gru2_state = self.gru2(gru1_out, states[1]) - gru2_out = self.gru2_glu(gru2_state) - gru3_state = self.gru3(gru2_out, states[2]) - gru3_out = self.gru3_glu(gru3_state) - gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1) - sig_out = torch.tanh(self.sig_dense_out(gru3_out)) + fwc0_out = n(fwc0_out) + pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out)) + + gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0]) + gru1_out = self.gru1_glu(n(gru1_state)) + gru1_out = n(gru1_out) + gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1]) + gru2_out = self.gru2_glu(n(gru2_state)) + gru2_out = n(gru2_out) + gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2]) + gru3_out = self.gru3_glu(n(gru3_state)) + gru3_out = n(gru3_out) + gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1) + skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1))) + skip_out = self.skip_glu(n(skip_out)) + sig_out = torch.tanh(self.sig_dense_out(skip_out)) dump_signal(sig_out, 'exc_out.f32') - taps = self.ptaps_dense(gru3_out) - taps = .2*taps + torch.exp(taps) - taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True)) - dump_signal(taps, 'taps.f32') - #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:] - fpitch = pred[:,2:-2] + #taps = self.ptaps_dense(gru3_out) + #taps = .2*taps + torch.exp(taps) + #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True)) + #dump_signal(taps, 'taps.f32') - pitch_gain = torch.exp(self.gain_dense_out(gru3_out)) dump_signal(pitch_gain, 'pgain.f32') - sig_out = (sig_out + pitch_gain*fpitch) * gain + #sig_out = (sig_out + pitch_gain*fpitch) * gain + sig_out = sig_out * gain exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1) + prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1) dump_signal(sig_out, 'sig_out.f32') - return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state) + return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state) class FARGAN(nn.Module): def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None): @@ -242,37 +286,30 @@ class FARGAN(nn.Module): device = features.device batch_size = features.size(0) - phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size) - #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw') - - prev = torch.zeros(batch_size, self.subframe_size, device=device) + prev = torch.zeros(batch_size, 256, device=device) exc_mem = torch.zeros(batch_size, 256, device=device) nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0 states = ( torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device) + torch.zeros(batch_size, 128, device=device), + torch.zeros(batch_size, 128, device=device), + torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device) ) sig = torch.zeros((batch_size, 0), device=device) cond = self.cond_net(features, period) if pre is not None: - prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size] exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size] start = 1 if nb_pre_frames>0 else 0 for n in range(start, nb_frames+nb_pre_frames): for k in range(self.nb_subframes): pos = n*self.frame_size + k*self.subframe_size - preal = phase_real[:, pos:pos+self.subframe_size] - pimag = phase_imag[:, pos:pos+self.subframe_size] - phase = torch.cat([preal, pimag], 1) #print("now: ", preal.shape, prev.shape, sig_in.shape) pitch = period[:, 3+n] gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0)) #gain = gain[:,:,None] - out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain) + out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain) if n < nb_pre_frames: out = pre[:, pos:pos+self.subframe_size] @@ -280,6 +317,5 @@ class FARGAN(nn.Module): else: sig = torch.cat([sig, out], 1) - prev = out states = [s.detach() for s in states] return sig, states diff --git a/dnn/torch/fargan/rc.py b/dnn/torch/fargan/rc.py new file mode 100644 index 00000000..7f67016a --- /dev/null +++ b/dnn/torch/fargan/rc.py @@ -0,0 +1,29 @@ +import torch + + + +def rc2lpc(rc): + order = rc.shape[-1] + lpc=rc[...,0:1] + for i in range(1, order): + lpc = torch.cat([lpc + rc[...,i:i+1]*torch.flip(lpc,dims=(-1,)), rc[...,i:i+1]], -1) + #print("to:", lpc) + return lpc + +def lpc2rc(lpc): + order = lpc.shape[-1] + rc = lpc[...,-1:] + for i in range(order-1, 0, -1): + ki = lpc[...,-1:] + lpc = lpc[...,:-1] + lpc = (lpc - ki*torch.flip(lpc,dims=(-1,)))/(1 - ki*ki) + rc = torch.cat([lpc[...,-1:] , rc], -1) + return rc + +if __name__ == "__main__": + rc = torch.tensor([[.5, -.5, .6, -.6]]) + print(rc) + lpc = rc2lpc(rc) + print(lpc) + rc2 = lpc2rc(lpc) + print(rc2) diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py index accf2f4a..8c904054 100644 --- a/dnn/torch/fargan/stft_loss.py +++ b/dnn/torch/fargan/stft_loss.py @@ -44,7 +44,9 @@ class SpectralConvergenceLoss(torch.nn.Module): Returns: Tensor: Spectral convergence loss value. """ - return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + x_mag = torch.sqrt(x_mag) + y_mag = torch.sqrt(y_mag) + return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1) class LogSTFTMagnitudeLoss(torch.nn.Module): """Log STFT magnitude loss module.""" @@ -136,26 +138,26 @@ class STFTLoss(torch.nn.Module): class MultiResolutionSTFTLoss(torch.nn.Module): - def __init__(self, + '''def __init__(self, device, fft_sizes=[2048, 1024, 512, 256, 128, 64], hop_sizes=[512, 256, 128, 64, 32, 16], win_lengths=[2048, 1024, 512, 256, 128, 64], - window="hann_window"): + window="hann_window"):''' - '''def __init__(self, + '''def __init__(self, device, fft_sizes=[2048, 1024, 512, 256, 128, 64], hop_sizes=[256, 128, 64, 32, 16, 8], win_lengths=[1024, 512, 256, 128, 64, 32], window="hann_window"):''' - '''def __init__(self, + def __init__(self, device, fft_sizes=[2560, 1280, 640, 320, 160, 80], hop_sizes=[640, 320, 160, 80, 40, 20], win_lengths=[2560, 1280, 640, 320, 160, 80], - window="hann_window"):''' + window="hann_window"): super(MultiResolutionSTFTLoss, self).__init__() assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) diff --git a/dnn/torch/fargan/test_fargan.py b/dnn/torch/fargan/test_fargan.py index 76e1f854..d3aeb613 100644 --- a/dnn/torch/fargan/test_fargan.py +++ b/dnn/torch/fargan/test_fargan.py @@ -48,7 +48,9 @@ model.load_state_dict(checkpoint['state_dict'], strict=False) features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features)) lpc = features[:,4-1:-1,nb_used_features:] features = features[:, :, :nb_used_features] -periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int') +#periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int') +periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int') + nb_frames = features.shape[1] #nb_frames = 1000 @@ -90,18 +92,37 @@ def inverse_perceptual_weighting (pw_signal, filters, weighting_vector): buffer[:] = out_sig_frame[-16:] return signal +def inverse_perceptual_weighting40 (pw_signal, filters): + + #inverse perceptual weighting= H_preemph / W(z/gamma) + + signal = np.zeros_like(pw_signal) + buffer = np.zeros(16) + num_frames = pw_signal.shape[0] //40 + assert num_frames == filters.shape[0] + for frame_idx in range(0, num_frames): + in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:] + out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer) + signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:] + buffer[:] = out_sig_frame[-16:] + return signal +from scipy.signal import lfilter if __name__ == '__main__': model.to(device) features = torch.tensor(features).to(device) #lpc = torch.tensor(lpc).to(device) periods = torch.tensor(periods).to(device) + weighting = gamma**np.arange(1, 17) + lpc = lpc*weighting + lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy() sig, _ = model(features, periods, nb_frames - 4) - weighting_vector = np.array([gamma**i for i in range(16,0,-1)]) + #weighting_vector = np.array([gamma**i for i in range(16,0,-1)]) sig = sig.detach().numpy().flatten() - sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector) + sig = lfilter(np.array([1.]), np.array([1., -.85]), sig) + #sig = inverse_perceptual_weighting40(sig, lpc[0,:,:]) pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16') pcm.tofile(signal_file) diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py index 4ab20045..dc6feb2d 100644 --- a/dnn/torch/fargan/train_fargan.py +++ b/dnn/torch/fargan/train_fargan.py @@ -114,20 +114,25 @@ if __name__ == '__main__': for i, (features, periods, target, lpc) in enumerate(tepoch): optimizer.zero_grad() features = features.to(device) - lpc = lpc.to(device) + #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4)) + #print("interp size", lpc.shape) + #lpc = lpc.to(device) + #lpc = lpc*(args.gamma**torch.arange(1,17, device=device)) + #lpc = fargan.interp_lpc(lpc, 4) periods = periods.to(device) if (np.random.rand() > 0.1): target = target[:, :sequence_length*160] - lpc = lpc[:,:sequence_length,:] + #lpc = lpc[:,:sequence_length*4,:] features = features[:,:sequence_length+4,:] periods = periods[:,:sequence_length+4] else: target=target[::2, :] - lpc=lpc[::2,:] + #lpc=lpc[::2,:] features=features[::2,:] periods=periods[::2,:] target = target.to(device) - target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma) + #print(target.shape, lpc.shape) + #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma) #nb_pre = random.randrange(1, 6) nb_pre = 2 @@ -135,9 +140,9 @@ if __name__ == '__main__': sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None) sig = torch.cat([pre, sig], -1) - cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80]) + cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160]) specc_loss = spect_loss(sig, target.detach()) - loss = .00*cont_loss + specc_loss + loss = .03*cont_loss + specc_loss loss.backward() optimizer.step() |