diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-23 09:30:36 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-23 09:30:36 +0300 |
commit | c5385cccc546e8d5ffaf771709a05fb19eb06dec (patch) | |
tree | 2b089de8cd42e733b4d217282184385073eee7e3 | |
parent | dff28610b7909907a2a32824fe49babd56da3d43 (diff) |
LSP-domain interpolation
-rw-r--r-- | dnn/torch/fargan/dataset.py | 4 | ||||
-rw-r--r-- | dnn/torch/fargan/fargan.py | 29 | ||||
-rw-r--r-- | dnn/torch/fargan/stft_loss.py | 10 | ||||
-rw-r--r-- | dnn/torch/fargan/test_fargan.py | 22 | ||||
-rw-r--r-- | dnn/torch/fargan/train_fargan.py | 7 |
5 files changed, 60 insertions, 12 deletions
diff --git a/dnn/torch/fargan/dataset.py b/dnn/torch/fargan/dataset.py index 6195c6af..4e10c453 100644 --- a/dnn/torch/fargan/dataset.py +++ b/dnn/torch/fargan/dataset.py @@ -1,5 +1,6 @@ import torch import numpy as np +import fargan class FARGANDataset(torch.utils.data.Dataset): def __init__(self, @@ -51,5 +52,8 @@ class FARGANDataset(torch.utils.data.Dataset): lpc = self.lpc[index, 4:, :].copy() data = self.data[index, :].copy().astype(np.float32) / 2**15 periods = self.periods[index, :].copy() + lpc=lpc[None,:,:] + lpc = fargan.interp_lpc(lpc*(.92**np.arange(1,17)), 4) + lpc=lpc[0,:,:] return features, periods, data, lpc diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 7f405e87..ff2fdfbe 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -4,6 +4,8 @@ from torch import nn import torch.nn.functional as F import filters from torch.nn.utils import weight_norm +from convert_lsp import lpc_to_lsp, lsp_to_lpc +from rc import lpc2rc, rc2lpc Fs = 16000 @@ -27,6 +29,27 @@ def sig_loss(y_true, y_pred): p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True)) return torch.mean(1.-torch.sum(p*t, dim=-1)) +def interp_lpc(lpc, factor): + #print(lpc.shape) + f = (np.arange(factor)+.5*((factor+1)%2))/factor + lsp = lpc_to_lsp(lpc) + #print("lsp0:") + #print(lsp) + shape = lsp.shape + #print("shape is", shape) + shape = (shape[0], shape[1]*factor, shape[2]) + interp_lsp = np.zeros(shape, dtype='float32') + for k in range(factor): + f = (k+.5*((factor+1)%2))/factor + interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:] + interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp + for k in range(factor//2): + interp_lsp[:,k,:] = interp_lsp[:,factor//2,:] + for k in range((factor+1)//2): + interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:] + #print("lsp:") + #print(interp_lsp) + return lsp_to_lpc(interp_lsp) def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9): device = x.device @@ -39,9 +62,9 @@ def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9): x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size)) out = torch.zeros((batch_size, 0), device=device) - if gamma is not None: - bw = gamma**(torch.arange(1, 17, device=device)) - lpc = lpc*bw[None,None,:] + #if gamma is not None: + # bw = gamma**(torch.arange(1, 17, device=device)) + # lpc = lpc*bw[None,None,:] ones = torch.ones((*(lpc.shape[:-1]), 1), device=device) zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device) a = torch.cat([ones, lpc], -1) diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py index 52cf1711..5227eb67 100644 --- a/dnn/torch/fargan/stft_loss.py +++ b/dnn/torch/fargan/stft_loss.py @@ -138,26 +138,26 @@ class STFTLoss(torch.nn.Module): class MultiResolutionSTFTLoss(torch.nn.Module): - def __init__(self, + '''def __init__(self, device, fft_sizes=[2048, 1024, 512, 256, 128, 64], hop_sizes=[512, 256, 128, 64, 32, 16], win_lengths=[2048, 1024, 512, 256, 128, 64], - window="hann_window"): + window="hann_window"):''' - '''def __init__(self, + '''def __init__(self, device, fft_sizes=[2048, 1024, 512, 256, 128, 64], hop_sizes=[256, 128, 64, 32, 16, 8], win_lengths=[1024, 512, 256, 128, 64, 32], window="hann_window"):''' - '''def __init__(self, + def __init__(self, device, fft_sizes=[2560, 1280, 640, 320, 160, 80], hop_sizes=[640, 320, 160, 80, 40, 20], win_lengths=[2560, 1280, 640, 320, 160, 80], - window="hann_window"):''' + window="hann_window"): super(MultiResolutionSTFTLoss, self).__init__() assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) diff --git a/dnn/torch/fargan/test_fargan.py b/dnn/torch/fargan/test_fargan.py index 8a6d2c25..a6da878c 100644 --- a/dnn/torch/fargan/test_fargan.py +++ b/dnn/torch/fargan/test_fargan.py @@ -90,6 +90,21 @@ def inverse_perceptual_weighting (pw_signal, filters, weighting_vector): buffer[:] = out_sig_frame[-16:] return signal +def inverse_perceptual_weighting40 (pw_signal, filters): + + #inverse perceptual weighting= H_preemph / W(z/gamma) + + signal = np.zeros_like(pw_signal) + buffer = np.zeros(16) + num_frames = pw_signal.shape[0] //40 + assert num_frames == filters.shape[0] + for frame_idx in range(0, num_frames): + + in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:] + out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer) + signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:] + buffer[:] = out_sig_frame[-16:] + return signal if __name__ == '__main__': @@ -97,11 +112,14 @@ if __name__ == '__main__': features = torch.tensor(features).to(device) #lpc = torch.tensor(lpc).to(device) periods = torch.tensor(periods).to(device) + weighting = gamma**np.arange(1, 17) + lpc = lpc*weighting + lpc = fargan.interp_lpc(lpc, 4) sig, _ = model(features, periods, nb_frames - 4) - weighting_vector = np.array([gamma**i for i in range(16,0,-1)]) + #weighting_vector = np.array([gamma**i for i in range(16,0,-1)]) sig = sig.detach().numpy().flatten() - sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector) + sig = inverse_perceptual_weighting40(sig, lpc[0,:,:]) pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16') pcm.tofile(signal_file) diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py index 7f2555a4..494421f5 100644 --- a/dnn/torch/fargan/train_fargan.py +++ b/dnn/torch/fargan/train_fargan.py @@ -114,11 +114,13 @@ if __name__ == '__main__': for i, (features, periods, target, lpc) in enumerate(tepoch): optimizer.zero_grad() features = features.to(device) + #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4)) + #print("interp size", lpc.shape) lpc = lpc.to(device) periods = periods.to(device) if (np.random.rand() > 0.1): target = target[:, :sequence_length*160] - lpc = lpc[:,:sequence_length,:] + lpc = lpc[:,:sequence_length*4,:] features = features[:,:sequence_length+4,:] periods = periods[:,:sequence_length+4] else: @@ -127,7 +129,8 @@ if __name__ == '__main__': features=features[::2,:] periods=periods[::2,:] target = target.to(device) - target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma) + #print(target.shape, lpc.shape) + target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma) #nb_pre = random.randrange(1, 6) nb_pre = 2 |