diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-08-31 01:36:09 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-13 05:50:46 +0300 |
commit | 1b13f6313e8413056f6d9f1f15fa994d0dff7a57 (patch) | |
tree | 06ac09e827d72576069e9905ca106d4f6c5d7825 | |
parent | 4f4b6242099998d7acf89e17c287dc7f605af607 (diff) |
FARGAN initial commit in Opus
Copied/adapted from LPCNet repo
-rw-r--r-- | dnn/torch/fargan/dataset.py | 52 | ||||
-rw-r--r-- | dnn/torch/fargan/fargan.py | 260 | ||||
-rw-r--r-- | dnn/torch/fargan/filters.py | 46 | ||||
-rw-r--r-- | dnn/torch/fargan/stft_loss.py | 184 | ||||
-rw-r--r-- | dnn/torch/fargan/test_fargan.py | 107 | ||||
-rw-r--r-- | dnn/torch/fargan/train_fargan.py | 155 |
6 files changed, 804 insertions, 0 deletions
diff --git a/dnn/torch/fargan/dataset.py b/dnn/torch/fargan/dataset.py new file mode 100644 index 00000000..f33bed36 --- /dev/null +++ b/dnn/torch/fargan/dataset.py @@ -0,0 +1,52 @@ +import torch +import numpy as np + +class FARGANDataset(torch.utils.data.Dataset): + def __init__(self, + feature_file, + signal_file, + frame_size=160, + sequence_length=15, + lookahead=1, + nb_used_features=20, + nb_features=36): + + self.frame_size = frame_size + self.sequence_length = sequence_length + self.lookahead = lookahead + self.nb_features = nb_features + self.nb_used_features = nb_used_features + pcm_chunk_size = self.frame_size*self.sequence_length + + self.data = np.memmap(signal_file, dtype='int16', mode='r') + #self.data = self.data[1::2] + self.nb_sequences = len(self.data)//(pcm_chunk_size)-1 + self.data = self.data[(4-self.lookahead)*self.frame_size:] + self.data = self.data[:self.nb_sequences*pcm_chunk_size] + + + self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size)) + + self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features)) + sizeof = self.features.strides[-1] + self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length+4, nb_features), + strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof)) + self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int') + + self.lpc = self.features[:, :, self.nb_used_features:] + self.features = self.features[:, :, :self.nb_used_features] + print("lpc_size:", self.lpc.shape) + + def __len__(self): + return self.nb_sequences + + def __getitem__(self, index): + features = self.features[index, :, :].copy() + if self.lookahead != 0: + lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy() + else: + lpc = self.lpc[index, 4:, :].copy() + data = self.data[index, :].copy().astype(np.float32) / 2**15 + periods = self.periods[index, :].copy() + + return features, periods, data, lpc diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py new file mode 100644 index 00000000..987cc8e5 --- /dev/null +++ b/dnn/torch/fargan/fargan.py @@ -0,0 +1,260 @@ +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import filters +from torch.nn.utils import weight_norm + +Fs = 16000 + +fid_dict = {} +def dump_signal(x, filename): + return + if filename in fid_dict: + fid = fid_dict[filename] + else: + fid = open(filename, "w") + fid_dict[filename] = fid + x = x.detach().numpy().astype('float32') + x.tofile(fid) + + +def sig_l1(y_true, y_pred): + return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true)) + +def sig_loss(y_true, y_pred): + t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True)) + p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True)) + return torch.mean(1.-torch.sum(p*t, dim=-1)) + + +def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9): + device = x.device + batch_size = lpc.size(0) + + nb_frames = lpc.shape[1] + + + sig = torch.zeros(batch_size, subframe_size+16, device=device) + x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size)) + out = torch.zeros((batch_size, 0), device=device) + + if gamma is not None: + bw = gamma**(torch.arange(1, 17, device=device)) + lpc = lpc*bw[None,None,:] + ones = torch.ones((*(lpc.shape[:-1]), 1), device=device) + zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device) + a = torch.cat([ones, lpc], -1) + a_big = torch.cat([a, zeros], -1) + fir_mat_big = filters.toeplitz_from_filter(a_big) + + #print(a_big[:,0,:]) + for n in range(nb_frames): + for k in range(nb_subframes): + + sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1) + exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None]) + out = torch.cat([out, exc[:,-subframe_size:,0]], 1) + + return out + + +# weight initialization and clipping +def init_weights(module): + if isinstance(module, nn.GRU): + for p in module.named_parameters(): + if p[0].startswith('weight_hh_'): + nn.init.orthogonal_(p[1]) + +def gen_phase_embedding(periods, frame_size): + device = periods.device + batch_size = periods.size(0) + nb_frames = periods.size(1) + w0 = 2*torch.pi/periods + w0_shift = torch.cat([2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size, w0[:,:-1]], 1) + cum_phase = frame_size*torch.cumsum(w0_shift, 1) + fine_phase = w0[:,:,None]*torch.broadcast_to(torch.arange(frame_size, device=device), (batch_size, nb_frames, frame_size)) + embed = torch.unsqueeze(cum_phase, 2) + fine_phase + embed = torch.reshape(embed, (batch_size, -1)) + return torch.cos(embed), torch.sin(embed) + +class GLU(nn.Module): + def __init__(self, feat_size): + super(GLU, self).__init__() + + torch.manual_seed(5) + + self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False)) + + self.init_weights() + + def init_weights(self): + + for m in self.modules(): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\ + or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding): + nn.init.orthogonal_(m.weight.data) + + def forward(self, x): + + out = x * torch.sigmoid(self.gate(x)) + + return out + + +class FARGANCond(nn.Module): + def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64): + super(FARGANCond, self).__init__() + + self.feature_dim = feature_dim + self.cond_size = cond_size + + self.pembed = nn.Embedding(256, pembed_dims) + self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False) + self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) + self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False) + self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) + + self.apply(init_weights) + + def forward(self, features, period): + p = self.pembed(period) + features = torch.cat((features, p), -1) + tmp = torch.tanh(self.fdense1(features)) + tmp = tmp.permute(0, 2, 1) + tmp = torch.tanh(self.fconv1(tmp)) + tmp = torch.tanh(self.fconv2(tmp)) + tmp = tmp.permute(0, 2, 1) + tmp = torch.tanh(self.fdense2(tmp)) + return tmp + +class FARGANSub(nn.Module): + def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_size=0, has_gain=False): + super(FARGANSub, self).__init__() + + self.subframe_size = subframe_size + self.nb_subframes = nb_subframes + self.cond_size = cond_size + self.has_gain = has_gain + self.passthrough_size = passthrough_size + + print("has_gain:", self.has_gain) + print("passthrough_size:", self.passthrough_size) + + gain_param = 1 if self.has_gain else 0 + + self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False) + self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False) + self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) + self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) + self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False) + + self.dense1_glu = GLU(self.cond_size) + self.dense2_glu = GLU(self.cond_size) + self.gru1_glu = GLU(self.cond_size) + self.gru2_glu = GLU(self.cond_size) + self.gru3_glu = GLU(self.cond_size) + + self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False) + if self.has_gain: + self.gain_dense_out = nn.Linear(self.cond_size, 1) + + + self.apply(init_weights) + + def forward(self, cond, prev, exc_mem, phase, period, states): + device = exc_mem.device + #print(cond.shape, prev.shape) + + dump_signal(prev, 'prev_in.f32') + + idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None]) + rng = torch.arange(self.subframe_size, device=device) + idx = idx + rng[None,:] + prev = torch.gather(exc_mem, 1, idx) + #prev = prev*0 + dump_signal(prev, 'pitch_exc.f32') + dump_signal(exc_mem, 'exc_mem.f32') + if self.has_gain: + gain = torch.norm(prev, dim=1, p=2, keepdim=True) + prev = prev/(1e-5+gain) + prev = torch.cat([prev, torch.log(1e-5+gain)], 1) + + passthrough = states[3] + tmp = torch.cat((cond, prev, passthrough, phase), 1) + + tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp))) + tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp))) + gru1_state = self.gru1(tmp, states[0]) + gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1]) + gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2]) + gru3_out = self.gru3_glu(gru3_state) + sig_out = torch.tanh(self.sig_dense_out(gru3_out)) + if self.passthrough_size != 0: + passthrough = sig_out[:,self.subframe_size:] + sig_out = sig_out[:,:self.subframe_size] + if self.has_gain: + out_gain = torch.exp(self.gain_dense_out(gru3_out)) + sig_out = sig_out * out_gain + dump_signal(sig_out, 'exc_out.f32') + exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1) + dump_signal(sig_out, 'sig_out.f32') + return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough) + +class FARGAN(nn.Module): + def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None): + super(FARGAN, self).__init__() + + self.subframe_size = subframe_size + self.nb_subframes = nb_subframes + self.frame_size = self.subframe_size*self.nb_subframes + self.feature_dim = feature_dim + self.cond_size = cond_size + self.has_gain = has_gain + self.passthrough_size = passthrough_size + + self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size) + self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, has_gain=has_gain, passthrough_size=passthrough_size) + + def forward(self, features, period, nb_frames, pre=None, states=None): + device = features.device + batch_size = features.size(0) + + phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size) + #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw') + + prev = torch.zeros(batch_size, self.subframe_size, device=device) + exc_mem = torch.zeros(batch_size, 256, device=device) + nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0 + + if states is None: + states = ( + torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, self.passthrough_size, device=device) + ) + + sig = torch.zeros((batch_size, 0), device=device) + cond = self.cond_net(features, period) + passthrough = torch.zeros(batch_size, self.passthrough_size, device=device) + for n in range(nb_frames+nb_pre_frames): + for k in range(self.nb_subframes): + pos = n*self.frame_size + k*self.subframe_size + preal = phase_real[:, pos:pos+self.subframe_size] + pimag = phase_imag[:, pos:pos+self.subframe_size] + phase = torch.cat([preal, pimag], 1) + #print("now: ", preal.shape, prev.shape, sig_in.shape) + pitch = period[:, 3+n] + out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states) + + if n < nb_pre_frames: + out = pre[:, pos:pos+self.subframe_size] + exc_mem[:,-self.subframe_size:] = out + else: + sig = torch.cat([sig, out], 1) + + prev = out + states = [s.detach() for s in states] + return sig, states + diff --git a/dnn/torch/fargan/filters.py b/dnn/torch/fargan/filters.py new file mode 100644 index 00000000..4a4a86f8 --- /dev/null +++ b/dnn/torch/fargan/filters.py @@ -0,0 +1,46 @@ +import torch +from torch import nn +import torch.nn.functional as F +import math + +def toeplitz_from_filter(a): + device = a.device + L = a.size(-1) + size0 = (*(a.shape[:-1]), L, L+1) + size = (*(a.shape[:-1]), L, L) + rnge = torch.arange(0, L, dtype=torch.int64, device=device) + z = torch.tensor(0, device=device) + idx = torch.maximum(rnge[:,None] - rnge[None,:] + 1, z) + a = torch.cat([a[...,:1]*0, a], -1) + #print(a) + a = a[...,None,:] + #print(idx) + a = torch.broadcast_to(a, size0) + idx = torch.broadcast_to(idx, size) + #print(idx) + return torch.gather(a, -1, idx) + +def filter_iir_response(a, N): + device = a.device + L = a.size(-1) + ar = a.flip(dims=(2,)) + size = (*(a.shape[:-1]), N) + R = torch.zeros(size, device=device) + R[:,:,0] = torch.ones((a.shape[:-1]), device=device) + for i in range(1, L): + R[:,:,i] = - torch.sum(ar[:,:,L-i-1:-1] * R[:,:,:i], axis=-1) + #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,L-i-1:-1], R[:,:,:i]) + for i in range(L, N): + R[:,:,i] = - torch.sum(ar[:,:,:-1] * R[:,:,i-L+1:i], axis=-1) + #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,:-1], R[:,:,i-L+1:i]) + return R + +if __name__ == '__main__': + #a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]]) + a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]]) + A = toeplitz_from_filter(a) + #print(A) + R = filter_iir_response(a, 5) + + RA = toeplitz_from_filter(R) + print(RA) diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py new file mode 100644 index 00000000..98a60ec6 --- /dev/null +++ b/dnn/torch/fargan/stft_loss.py @@ -0,0 +1,184 @@ +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F +import numpy as np +import torchaudio + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + """ + + #x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False) + #real = x_stft[..., 0] + #imag = x_stft[..., 1] + + # (kan-bayashi): clamp is needed to avoid nan or inf + #return torchaudio.functional.amplitude_to_DB(torch.abs(x_stft),db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80) + #return torch.clamp(torch.abs(x_stft), min=1e-7) + + x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True) + return torch.clamp(torch.abs(x_stft), min=1e-7) + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + """Initilize spectral convergence loss module.""" + super(SpectralConvergenceLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Spectral convergence loss value. + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self): + """Initilize los STFT magnitude loss module.""" + super(LogSTFTMagnitudeLoss, self).__init__() + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Log STFT magnitude loss value. + """ + #F.l1_loss(torch.sqrt(y_mag), torch.sqrt(x_mag)) + + #F.l1_loss(torchaudio.functional.amplitude_to_DB(y_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80),\ + #torchaudio.functional.amplitude_to_DB(x_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)) + + #y_mag[:,:y_mag.size(1)//2,:] = y_mag[:,:y_mag.size(1)//2,:] *0.0 + + #return F.l1_loss(torch.log(y_mag) + torch.sqrt(y_mag), torch.log(x_mag) + torch.sqrt(x_mag)) + + #return F.l1_loss(y_mag, x_mag) + + error_loss = F.l1_loss(y, x) #+ F.l1_loss(torch.sqrt(y), torch.sqrt(x))#F.l1_loss(torch.log(y), torch.log(x))# + + #x = torch.log(x) + #y = torch.log(y) + #x = x.permute(0,2,1).contiguous() + #y = y.permute(0,2,1).contiguous() + + '''mean_x = torch.mean(x, dim=1, keepdim=True) + mean_y = torch.mean(y, dim=1, keepdim=True) + + var_x = torch.var(x, dim=1, keepdim=True) + var_y = torch.var(y, dim=1, keepdim=True) + + std_x = torch.std(x, dim=1, keepdim=True) + std_y = torch.std(y, dim=1, keepdim=True) + + x_minus_mean = x - mean_x + y_minus_mean = y - mean_y + + pearson_corr = torch.sum(x_minus_mean * y_minus_mean, dim=1, keepdim=True) / \ + (torch.sqrt(torch.sum(x_minus_mean ** 2, dim=1, keepdim=True) + 1e-7) * \ + torch.sqrt(torch.sum(y_minus_mean ** 2, dim=1, keepdim=True) + 1e-7)) + + numerator = 2.0 * pearson_corr * std_x * std_y + denominator = var_x + var_y + (mean_y - mean_x)**2 + + ccc = numerator/denominator + + ccc_loss = F.l1_loss(1.0 - ccc, torch.zeros_like(ccc))''' + + return error_loss #+ ccc_loss#+ ccc_loss + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length).to(device) + self.spectral_convergenge_loss = SpectralConvergenceLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + + def __init__(self, + device, + fft_sizes=[2048, 1024, 512, 256, 128, 64], + hop_sizes=[512, 256, 128, 64, 32, 16], + win_lengths=[2048, 1024, 512, 256, 128, 64], + window="hann_window"): + + '''def __init__(self, + device, + fft_sizes=[2048, 1024, 512, 256, 128, 64], + hop_sizes=[256, 128, 64, 32, 16, 8], + win_lengths=[1024, 512, 256, 128, 64, 32], + window="hann_window"):''' + + '''def __init__(self, + device, + fft_sizes=[2560, 1280, 640, 320, 160, 80], + hop_sizes=[640, 320, 160, 80, 40, 20], + win_lengths=[2560, 1280, 640, 320, 160, 80], + window="hann_window"):''' + + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(device, fs, ss, wl, window)] + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + #mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss #mag_loss #+ diff --git a/dnn/torch/fargan/test_fargan.py b/dnn/torch/fargan/test_fargan.py new file mode 100644 index 00000000..8a6d2c25 --- /dev/null +++ b/dnn/torch/fargan/test_fargan.py @@ -0,0 +1,107 @@ +import os +import argparse +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +import tqdm + +import fargan +from dataset import FARGANDataset + +nb_features = 36 +nb_used_features = 20 + +parser = argparse.ArgumentParser() + +parser.add_argument('model', type=str, help='CELPNet model') +parser.add_argument('features', type=str, help='path to feature file in .f32 format') +parser.add_argument('output', type=str, help='path to output file (16-bit PCM)') + +parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None) + + +model_group = parser.add_argument_group(title="model parameters") +model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256) + +args = parser.parse_args() + +if args.cuda_visible_devices != None: + os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices + + +features_file = args.features +signal_file = args.output + + + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +checkpoint = torch.load(args.model, map_location='cpu') + +model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs']) + + +model.load_state_dict(checkpoint['state_dict'], strict=False) + +features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features)) +lpc = features[:,4-1:-1,nb_used_features:] +features = features[:, :, :nb_used_features] +periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int') + +nb_frames = features.shape[1] +#nb_frames = 1000 +gamma = checkpoint['model_kwargs']['gamma'] + +def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)): + + out = np.zeros_like(frame) + filt = np.flip(filt) + + inp = frame[:] + + + for i in range(0, inp.shape[0]): + + s = inp[i] - np.dot(buffer*weighting_vector, filt) + + buffer[0] = s + + buffer = np.roll(buffer, -1) + + out[i] = s + + return out + +def inverse_perceptual_weighting (pw_signal, filters, weighting_vector): + + #inverse perceptual weighting= H_preemph / W(z/gamma) + + signal = np.zeros_like(pw_signal) + buffer = np.zeros(16) + num_frames = pw_signal.shape[0] //160 + assert num_frames == filters.shape[0] + for frame_idx in range(0, num_frames): + + in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:] + out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector) + signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:] + buffer[:] = out_sig_frame[-16:] + return signal + + + +if __name__ == '__main__': + model.to(device) + features = torch.tensor(features).to(device) + #lpc = torch.tensor(lpc).to(device) + periods = torch.tensor(periods).to(device) + + sig, _ = model(features, periods, nb_frames - 4) + weighting_vector = np.array([gamma**i for i in range(16,0,-1)]) + sig = sig.detach().numpy().flatten() + sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector) + + pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16') + pcm.tofile(signal_file) diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py new file mode 100644 index 00000000..117518b6 --- /dev/null +++ b/dnn/torch/fargan/train_fargan.py @@ -0,0 +1,155 @@ +import os +import argparse +import random +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +import tqdm + +import fargan +from dataset import FARGANDataset +from stft_loss import * + +parser = argparse.ArgumentParser() + +parser.add_argument('features', type=str, help='path to feature file in .f32 format') +parser.add_argument('signal', type=str, help='path to signal file in .s16 format') +parser.add_argument('output', type=str, help='path to output folder') + +parser.add_argument('--suffix', type=str, help="model name suffix", default="") +parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None) + + +model_group = parser.add_argument_group(title="model parameters") +model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256) +model_group.add_argument('--has-gain', action='store_true', help="use gain-shape network") +model_group.add_argument('--passthrough-size', type=int, help="state passing through in addition to audio, default: 0", default=0) +model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9) + +training_group = parser.add_argument_group(title="training parameters") +training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512) +training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3) +training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20) +training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15) +training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4) +training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None) + +args = parser.parse_args() + +if args.cuda_visible_devices != None: + os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices + +# checkpoints +checkpoint_dir = os.path.join(args.output, 'checkpoints') +checkpoint = dict() +os.makedirs(checkpoint_dir, exist_ok=True) + + +# training parameters +batch_size = args.batch_size +lr = args.lr +epochs = args.epochs +sequence_length = args.sequence_length +lr_decay = args.lr_decay + +adam_betas = [0.9, 0.99] +adam_eps = 1e-8 +features_file = args.features +signal_file = args.signal + +# model parameters +cond_size = args.cond_size + + +checkpoint['batch_size'] = batch_size +checkpoint['lr'] = lr +checkpoint['lr_decay'] = lr_decay +checkpoint['epochs'] = epochs +checkpoint['sequence_length'] = sequence_length +checkpoint['adam_betas'] = adam_betas + + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +checkpoint['model_args'] = () +checkpoint['model_kwargs'] = {'cond_size': cond_size, 'has_gain': args.has_gain, 'passthrough_size': args.passthrough_size, 'gamma': args.gamma} +print(checkpoint['model_kwargs']) +model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs']) + +#model = fargan.FARGAN() +#model = nn.DataParallel(model) + +if type(args.initial_checkpoint) != type(None): + checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') + model.load_state_dict(checkpoint['state_dict'], strict=False) + +checkpoint['state_dict'] = model.state_dict() + + +dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length) +dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4) + + +optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps) + + +# learning rate scheduler +scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x)) + +states = None + +spect_loss = MultiResolutionSTFTLoss(device).to(device) + +if __name__ == '__main__': + model.to(device) + + for epoch in range(1, epochs + 1): + + running_specc = 0 + running_cont_loss = 0 + running_loss = 0 + + print(f"training epoch {epoch}...") + with tqdm.tqdm(dataloader, unit='batch') as tepoch: + for i, (features, periods, target, lpc) in enumerate(tepoch): + optimizer.zero_grad() + features = features.to(device) + lpc = lpc.to(device) + periods = periods.to(device) + target = target.to(device) + target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma) + + #nb_pre = random.randrange(1, 6) + nb_pre = 2 + pre = target[:, :nb_pre*160] + sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None) + sig = torch.cat([pre, sig], -1) + + cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80]) + specc_loss = spect_loss(sig, target.detach()) + loss = .2*cont_loss + specc_loss + + loss.backward() + optimizer.step() + + #model.clip_weights() + + scheduler.step() + + running_specc += specc_loss.detach().cpu().item() + running_cont_loss += cont_loss.detach().cpu().item() + + running_loss += loss.detach().cpu().item() + tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}", + cont_loss=f"{running_cont_loss/(i+1):8.5f}", + specc=f"{running_specc/(i+1):8.5f}", + ) + + # save checkpoint + checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth') + checkpoint['state_dict'] = model.state_dict() + checkpoint['loss'] = running_loss / len(dataloader) + checkpoint['epoch'] = epoch + torch.save(checkpoint, checkpoint_path) |