Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-08-31 01:36:09 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-09-13 05:50:46 +0300
commit1b13f6313e8413056f6d9f1f15fa994d0dff7a57 (patch)
tree06ac09e827d72576069e9905ca106d4f6c5d7825
parent4f4b6242099998d7acf89e17c287dc7f605af607 (diff)
FARGAN initial commit in Opus
Copied/adapted from LPCNet repo
-rw-r--r--dnn/torch/fargan/dataset.py52
-rw-r--r--dnn/torch/fargan/fargan.py260
-rw-r--r--dnn/torch/fargan/filters.py46
-rw-r--r--dnn/torch/fargan/stft_loss.py184
-rw-r--r--dnn/torch/fargan/test_fargan.py107
-rw-r--r--dnn/torch/fargan/train_fargan.py155
6 files changed, 804 insertions, 0 deletions
diff --git a/dnn/torch/fargan/dataset.py b/dnn/torch/fargan/dataset.py
new file mode 100644
index 00000000..f33bed36
--- /dev/null
+++ b/dnn/torch/fargan/dataset.py
@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+
+class FARGANDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ feature_file,
+ signal_file,
+ frame_size=160,
+ sequence_length=15,
+ lookahead=1,
+ nb_used_features=20,
+ nb_features=36):
+
+ self.frame_size = frame_size
+ self.sequence_length = sequence_length
+ self.lookahead = lookahead
+ self.nb_features = nb_features
+ self.nb_used_features = nb_used_features
+ pcm_chunk_size = self.frame_size*self.sequence_length
+
+ self.data = np.memmap(signal_file, dtype='int16', mode='r')
+ #self.data = self.data[1::2]
+ self.nb_sequences = len(self.data)//(pcm_chunk_size)-1
+ self.data = self.data[(4-self.lookahead)*self.frame_size:]
+ self.data = self.data[:self.nb_sequences*pcm_chunk_size]
+
+
+ self.data = np.reshape(self.data, (self.nb_sequences, pcm_chunk_size))
+
+ self.features = np.reshape(np.memmap(feature_file, dtype='float32', mode='r'), (-1, nb_features))
+ sizeof = self.features.strides[-1]
+ self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length+4, nb_features),
+ strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
+ self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
+
+ self.lpc = self.features[:, :, self.nb_used_features:]
+ self.features = self.features[:, :, :self.nb_used_features]
+ print("lpc_size:", self.lpc.shape)
+
+ def __len__(self):
+ return self.nb_sequences
+
+ def __getitem__(self, index):
+ features = self.features[index, :, :].copy()
+ if self.lookahead != 0:
+ lpc = self.lpc[index, 4-self.lookahead:-self.lookahead, :].copy()
+ else:
+ lpc = self.lpc[index, 4:, :].copy()
+ data = self.data[index, :].copy().astype(np.float32) / 2**15
+ periods = self.periods[index, :].copy()
+
+ return features, periods, data, lpc
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
new file mode 100644
index 00000000..987cc8e5
--- /dev/null
+++ b/dnn/torch/fargan/fargan.py
@@ -0,0 +1,260 @@
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import filters
+from torch.nn.utils import weight_norm
+
+Fs = 16000
+
+fid_dict = {}
+def dump_signal(x, filename):
+ return
+ if filename in fid_dict:
+ fid = fid_dict[filename]
+ else:
+ fid = open(filename, "w")
+ fid_dict[filename] = fid
+ x = x.detach().numpy().astype('float32')
+ x.tofile(fid)
+
+
+def sig_l1(y_true, y_pred):
+ return torch.mean(abs(y_true-y_pred))/torch.mean(abs(y_true))
+
+def sig_loss(y_true, y_pred):
+ t = y_true/(1e-15+torch.norm(y_true, dim=-1, p=2, keepdim=True))
+ p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
+ return torch.mean(1.-torch.sum(p*t, dim=-1))
+
+
+def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
+ device = x.device
+ batch_size = lpc.size(0)
+
+ nb_frames = lpc.shape[1]
+
+
+ sig = torch.zeros(batch_size, subframe_size+16, device=device)
+ x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
+ out = torch.zeros((batch_size, 0), device=device)
+
+ if gamma is not None:
+ bw = gamma**(torch.arange(1, 17, device=device))
+ lpc = lpc*bw[None,None,:]
+ ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
+ zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
+ a = torch.cat([ones, lpc], -1)
+ a_big = torch.cat([a, zeros], -1)
+ fir_mat_big = filters.toeplitz_from_filter(a_big)
+
+ #print(a_big[:,0,:])
+ for n in range(nb_frames):
+ for k in range(nb_subframes):
+
+ sig = torch.cat([sig[:,subframe_size:], x[:,n*nb_subframes + k, :]], 1)
+ exc = torch.bmm(fir_mat_big[:,n,:,:], sig[:,:,None])
+ out = torch.cat([out, exc[:,-subframe_size:,0]], 1)
+
+ return out
+
+
+# weight initialization and clipping
+def init_weights(module):
+ if isinstance(module, nn.GRU):
+ for p in module.named_parameters():
+ if p[0].startswith('weight_hh_'):
+ nn.init.orthogonal_(p[1])
+
+def gen_phase_embedding(periods, frame_size):
+ device = periods.device
+ batch_size = periods.size(0)
+ nb_frames = periods.size(1)
+ w0 = 2*torch.pi/periods
+ w0_shift = torch.cat([2*torch.pi*torch.rand((batch_size, 1), device=device)/frame_size, w0[:,:-1]], 1)
+ cum_phase = frame_size*torch.cumsum(w0_shift, 1)
+ fine_phase = w0[:,:,None]*torch.broadcast_to(torch.arange(frame_size, device=device), (batch_size, nb_frames, frame_size))
+ embed = torch.unsqueeze(cum_phase, 2) + fine_phase
+ embed = torch.reshape(embed, (batch_size, -1))
+ return torch.cos(embed), torch.sin(embed)
+
+class GLU(nn.Module):
+ def __init__(self, feat_size):
+ super(GLU, self).__init__()
+
+ torch.manual_seed(5)
+
+ self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+ self.init_weights()
+
+ def init_weights(self):
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+ or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+ nn.init.orthogonal_(m.weight.data)
+
+ def forward(self, x):
+
+ out = x * torch.sigmoid(self.gate(x))
+
+ return out
+
+
+class FARGANCond(nn.Module):
+ def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
+ super(FARGANCond, self).__init__()
+
+ self.feature_dim = feature_dim
+ self.cond_size = cond_size
+
+ self.pembed = nn.Embedding(256, pembed_dims)
+ self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
+ self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
+ self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
+ self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+
+ self.apply(init_weights)
+
+ def forward(self, features, period):
+ p = self.pembed(period)
+ features = torch.cat((features, p), -1)
+ tmp = torch.tanh(self.fdense1(features))
+ tmp = tmp.permute(0, 2, 1)
+ tmp = torch.tanh(self.fconv1(tmp))
+ tmp = torch.tanh(self.fconv2(tmp))
+ tmp = tmp.permute(0, 2, 1)
+ tmp = torch.tanh(self.fdense2(tmp))
+ return tmp
+
+class FARGANSub(nn.Module):
+ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_size=0, has_gain=False):
+ super(FARGANSub, self).__init__()
+
+ self.subframe_size = subframe_size
+ self.nb_subframes = nb_subframes
+ self.cond_size = cond_size
+ self.has_gain = has_gain
+ self.passthrough_size = passthrough_size
+
+ print("has_gain:", self.has_gain)
+ print("passthrough_size:", self.passthrough_size)
+
+ gain_param = 1 if self.has_gain else 0
+
+ self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
+ self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+ self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+ self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+ self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+
+ self.dense1_glu = GLU(self.cond_size)
+ self.dense2_glu = GLU(self.cond_size)
+ self.gru1_glu = GLU(self.cond_size)
+ self.gru2_glu = GLU(self.cond_size)
+ self.gru3_glu = GLU(self.cond_size)
+
+ self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size+self.passthrough_size, bias=False)
+ if self.has_gain:
+ self.gain_dense_out = nn.Linear(self.cond_size, 1)
+
+
+ self.apply(init_weights)
+
+ def forward(self, cond, prev, exc_mem, phase, period, states):
+ device = exc_mem.device
+ #print(cond.shape, prev.shape)
+
+ dump_signal(prev, 'prev_in.f32')
+
+ idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None])
+ rng = torch.arange(self.subframe_size, device=device)
+ idx = idx + rng[None,:]
+ prev = torch.gather(exc_mem, 1, idx)
+ #prev = prev*0
+ dump_signal(prev, 'pitch_exc.f32')
+ dump_signal(exc_mem, 'exc_mem.f32')
+ if self.has_gain:
+ gain = torch.norm(prev, dim=1, p=2, keepdim=True)
+ prev = prev/(1e-5+gain)
+ prev = torch.cat([prev, torch.log(1e-5+gain)], 1)
+
+ passthrough = states[3]
+ tmp = torch.cat((cond, prev, passthrough, phase), 1)
+
+ tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
+ tmp = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
+ gru1_state = self.gru1(tmp, states[0])
+ gru2_state = self.gru2(self.gru1_glu(gru1_state), states[1])
+ gru3_state = self.gru3(self.gru2_glu(gru2_state), states[2])
+ gru3_out = self.gru3_glu(gru3_state)
+ sig_out = torch.tanh(self.sig_dense_out(gru3_out))
+ if self.passthrough_size != 0:
+ passthrough = sig_out[:,self.subframe_size:]
+ sig_out = sig_out[:,:self.subframe_size]
+ if self.has_gain:
+ out_gain = torch.exp(self.gain_dense_out(gru3_out))
+ sig_out = sig_out * out_gain
+ dump_signal(sig_out, 'exc_out.f32')
+ exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
+ dump_signal(sig_out, 'sig_out.f32')
+ return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
+
+class FARGAN(nn.Module):
+ def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
+ super(FARGAN, self).__init__()
+
+ self.subframe_size = subframe_size
+ self.nb_subframes = nb_subframes
+ self.frame_size = self.subframe_size*self.nb_subframes
+ self.feature_dim = feature_dim
+ self.cond_size = cond_size
+ self.has_gain = has_gain
+ self.passthrough_size = passthrough_size
+
+ self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
+ self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, has_gain=has_gain, passthrough_size=passthrough_size)
+
+ def forward(self, features, period, nb_frames, pre=None, states=None):
+ device = features.device
+ batch_size = features.size(0)
+
+ phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
+ #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
+
+ prev = torch.zeros(batch_size, self.subframe_size, device=device)
+ exc_mem = torch.zeros(batch_size, 256, device=device)
+ nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
+
+ if states is None:
+ states = (
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.passthrough_size, device=device)
+ )
+
+ sig = torch.zeros((batch_size, 0), device=device)
+ cond = self.cond_net(features, period)
+ passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
+ for n in range(nb_frames+nb_pre_frames):
+ for k in range(self.nb_subframes):
+ pos = n*self.frame_size + k*self.subframe_size
+ preal = phase_real[:, pos:pos+self.subframe_size]
+ pimag = phase_imag[:, pos:pos+self.subframe_size]
+ phase = torch.cat([preal, pimag], 1)
+ #print("now: ", preal.shape, prev.shape, sig_in.shape)
+ pitch = period[:, 3+n]
+ out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states)
+
+ if n < nb_pre_frames:
+ out = pre[:, pos:pos+self.subframe_size]
+ exc_mem[:,-self.subframe_size:] = out
+ else:
+ sig = torch.cat([sig, out], 1)
+
+ prev = out
+ states = [s.detach() for s in states]
+ return sig, states
+
diff --git a/dnn/torch/fargan/filters.py b/dnn/torch/fargan/filters.py
new file mode 100644
index 00000000..4a4a86f8
--- /dev/null
+++ b/dnn/torch/fargan/filters.py
@@ -0,0 +1,46 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import math
+
+def toeplitz_from_filter(a):
+ device = a.device
+ L = a.size(-1)
+ size0 = (*(a.shape[:-1]), L, L+1)
+ size = (*(a.shape[:-1]), L, L)
+ rnge = torch.arange(0, L, dtype=torch.int64, device=device)
+ z = torch.tensor(0, device=device)
+ idx = torch.maximum(rnge[:,None] - rnge[None,:] + 1, z)
+ a = torch.cat([a[...,:1]*0, a], -1)
+ #print(a)
+ a = a[...,None,:]
+ #print(idx)
+ a = torch.broadcast_to(a, size0)
+ idx = torch.broadcast_to(idx, size)
+ #print(idx)
+ return torch.gather(a, -1, idx)
+
+def filter_iir_response(a, N):
+ device = a.device
+ L = a.size(-1)
+ ar = a.flip(dims=(2,))
+ size = (*(a.shape[:-1]), N)
+ R = torch.zeros(size, device=device)
+ R[:,:,0] = torch.ones((a.shape[:-1]), device=device)
+ for i in range(1, L):
+ R[:,:,i] = - torch.sum(ar[:,:,L-i-1:-1] * R[:,:,:i], axis=-1)
+ #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,L-i-1:-1], R[:,:,:i])
+ for i in range(L, N):
+ R[:,:,i] = - torch.sum(ar[:,:,:-1] * R[:,:,i-L+1:i], axis=-1)
+ #R[:,:,i] = - torch.einsum('ijk,ijk->ij', ar[:,:,:-1], R[:,:,i-L+1:i])
+ return R
+
+if __name__ == '__main__':
+ #a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]], [[1, .9, 0], [1, .8, 0]]])
+ a = torch.tensor([ [[1, -.9, 0.02], [1, -.8, .01]]])
+ A = toeplitz_from_filter(a)
+ #print(A)
+ R = filter_iir_response(a, 5)
+
+ RA = toeplitz_from_filter(R)
+ print(RA)
diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py
new file mode 100644
index 00000000..98a60ec6
--- /dev/null
+++ b/dnn/torch/fargan/stft_loss.py
@@ -0,0 +1,184 @@
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import torchaudio
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+ """Perform STFT and convert to magnitude spectrogram.
+ Args:
+ x (Tensor): Input signal tensor (B, T).
+ fft_size (int): FFT size.
+ hop_size (int): Hop size.
+ win_length (int): Window length.
+ window (str): Window function type.
+ Returns:
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+ """
+
+ #x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
+ #real = x_stft[..., 0]
+ #imag = x_stft[..., 1]
+
+ # (kan-bayashi): clamp is needed to avoid nan or inf
+ #return torchaudio.functional.amplitude_to_DB(torch.abs(x_stft),db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)
+ #return torch.clamp(torch.abs(x_stft), min=1e-7)
+
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+ return torch.clamp(torch.abs(x_stft), min=1e-7)
+
+class SpectralConvergenceLoss(torch.nn.Module):
+ """Spectral convergence loss module."""
+
+ def __init__(self):
+ """Initilize spectral convergence loss module."""
+ super(SpectralConvergenceLoss, self).__init__()
+
+ def forward(self, x_mag, y_mag):
+ """Calculate forward propagation.
+ Args:
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ """
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+class LogSTFTMagnitudeLoss(torch.nn.Module):
+ """Log STFT magnitude loss module."""
+
+ def __init__(self):
+ """Initilize los STFT magnitude loss module."""
+ super(LogSTFTMagnitudeLoss, self).__init__()
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ Tensor: Log STFT magnitude loss value.
+ """
+ #F.l1_loss(torch.sqrt(y_mag), torch.sqrt(x_mag)) +
+ #F.l1_loss(torchaudio.functional.amplitude_to_DB(y_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80),\
+ #torchaudio.functional.amplitude_to_DB(x_mag,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80))
+
+ #y_mag[:,:y_mag.size(1)//2,:] = y_mag[:,:y_mag.size(1)//2,:] *0.0
+
+ #return F.l1_loss(torch.log(y_mag) + torch.sqrt(y_mag), torch.log(x_mag) + torch.sqrt(x_mag))
+
+ #return F.l1_loss(y_mag, x_mag)
+
+ error_loss = F.l1_loss(y, x) #+ F.l1_loss(torch.sqrt(y), torch.sqrt(x))#F.l1_loss(torch.log(y), torch.log(x))#
+
+ #x = torch.log(x)
+ #y = torch.log(y)
+ #x = x.permute(0,2,1).contiguous()
+ #y = y.permute(0,2,1).contiguous()
+
+ '''mean_x = torch.mean(x, dim=1, keepdim=True)
+ mean_y = torch.mean(y, dim=1, keepdim=True)
+
+ var_x = torch.var(x, dim=1, keepdim=True)
+ var_y = torch.var(y, dim=1, keepdim=True)
+
+ std_x = torch.std(x, dim=1, keepdim=True)
+ std_y = torch.std(y, dim=1, keepdim=True)
+
+ x_minus_mean = x - mean_x
+ y_minus_mean = y - mean_y
+
+ pearson_corr = torch.sum(x_minus_mean * y_minus_mean, dim=1, keepdim=True) / \
+ (torch.sqrt(torch.sum(x_minus_mean ** 2, dim=1, keepdim=True) + 1e-7) * \
+ torch.sqrt(torch.sum(y_minus_mean ** 2, dim=1, keepdim=True) + 1e-7))
+
+ numerator = 2.0 * pearson_corr * std_x * std_y
+ denominator = var_x + var_y + (mean_y - mean_x)**2
+
+ ccc = numerator/denominator
+
+ ccc_loss = F.l1_loss(1.0 - ccc, torch.zeros_like(ccc))'''
+
+ return error_loss #+ ccc_loss#+ ccc_loss
+
+
+class STFTLoss(torch.nn.Module):
+ """STFT loss module."""
+
+ def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
+ """Initialize STFT loss module."""
+ super(STFTLoss, self).__init__()
+ self.fft_size = fft_size
+ self.shift_size = shift_size
+ self.win_length = win_length
+ self.window = getattr(torch, window)(win_length).to(device)
+ self.spectral_convergenge_loss = SpectralConvergenceLoss()
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ Tensor: Log STFT magnitude loss value.
+ """
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+ return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+
+ def __init__(self,
+ device,
+ fft_sizes=[2048, 1024, 512, 256, 128, 64],
+ hop_sizes=[512, 256, 128, 64, 32, 16],
+ win_lengths=[2048, 1024, 512, 256, 128, 64],
+ window="hann_window"):
+
+ '''def __init__(self,
+ device,
+ fft_sizes=[2048, 1024, 512, 256, 128, 64],
+ hop_sizes=[256, 128, 64, 32, 16, 8],
+ win_lengths=[1024, 512, 256, 128, 64, 32],
+ window="hann_window"):'''
+
+ '''def __init__(self,
+ device,
+ fft_sizes=[2560, 1280, 640, 320, 160, 80],
+ hop_sizes=[640, 320, 160, 80, 40, 20],
+ win_lengths=[2560, 1280, 640, 320, 160, 80],
+ window="hann_window"):'''
+
+ super(MultiResolutionSTFTLoss, self).__init__()
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+ self.stft_losses = torch.nn.ModuleList()
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+ self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Multi resolution spectral convergence loss value.
+ Tensor: Multi resolution log STFT magnitude loss value.
+ """
+ sc_loss = 0.0
+ mag_loss = 0.0
+ for f in self.stft_losses:
+ sc_l, mag_l = f(x, y)
+ sc_loss += sc_l
+ #mag_loss += mag_l
+ sc_loss /= len(self.stft_losses)
+ mag_loss /= len(self.stft_losses)
+
+ return sc_loss #mag_loss #+
diff --git a/dnn/torch/fargan/test_fargan.py b/dnn/torch/fargan/test_fargan.py
new file mode 100644
index 00000000..8a6d2c25
--- /dev/null
+++ b/dnn/torch/fargan/test_fargan.py
@@ -0,0 +1,107 @@
+import os
+import argparse
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+
+import fargan
+from dataset import FARGANDataset
+
+nb_features = 36
+nb_used_features = 20
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('model', type=str, help='CELPNet model')
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('output', type=str, help='path to output file (16-bit PCM)')
+
+parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+
+args = parser.parse_args()
+
+if args.cuda_visible_devices != None:
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+
+features_file = args.features
+signal_file = args.output
+
+
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint = torch.load(args.model, map_location='cpu')
+
+model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
+lpc = features[:,4-1:-1,nb_used_features:]
+features = features[:, :, :nb_used_features]
+periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
+
+nb_frames = features.shape[1]
+#nb_frames = 1000
+gamma = checkpoint['model_kwargs']['gamma']
+
+def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
+
+ out = np.zeros_like(frame)
+ filt = np.flip(filt)
+
+ inp = frame[:]
+
+
+ for i in range(0, inp.shape[0]):
+
+ s = inp[i] - np.dot(buffer*weighting_vector, filt)
+
+ buffer[0] = s
+
+ buffer = np.roll(buffer, -1)
+
+ out[i] = s
+
+ return out
+
+def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
+
+ #inverse perceptual weighting= H_preemph / W(z/gamma)
+
+ signal = np.zeros_like(pw_signal)
+ buffer = np.zeros(16)
+ num_frames = pw_signal.shape[0] //160
+ assert num_frames == filters.shape[0]
+ for frame_idx in range(0, num_frames):
+
+ in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
+ out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
+ signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
+ buffer[:] = out_sig_frame[-16:]
+ return signal
+
+
+
+if __name__ == '__main__':
+ model.to(device)
+ features = torch.tensor(features).to(device)
+ #lpc = torch.tensor(lpc).to(device)
+ periods = torch.tensor(periods).to(device)
+
+ sig, _ = model(features, periods, nb_frames - 4)
+ weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
+ sig = sig.detach().numpy().flatten()
+ sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
+
+ pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
+ pcm.tofile(signal_file)
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
new file mode 100644
index 00000000..117518b6
--- /dev/null
+++ b/dnn/torch/fargan/train_fargan.py
@@ -0,0 +1,155 @@
+import os
+import argparse
+import random
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+
+import fargan
+from dataset import FARGANDataset
+from stft_loss import *
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('signal', type=str, help='path to signal file in .s16 format')
+parser.add_argument('output', type=str, help='path to output folder')
+
+parser.add_argument('--suffix', type=str, help="model name suffix", default="")
+parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--has-gain', action='store_true', help="use gain-shape network")
+model_group.add_argument('--passthrough-size', type=int, help="state passing through in addition to audio, default: 0", default=0)
+model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+
+training_group = parser.add_argument_group(title="training parameters")
+training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
+training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
+training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
+training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
+training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
+training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
+
+args = parser.parse_args()
+
+if args.cuda_visible_devices != None:
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# checkpoints
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+checkpoint = dict()
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+
+# training parameters
+batch_size = args.batch_size
+lr = args.lr
+epochs = args.epochs
+sequence_length = args.sequence_length
+lr_decay = args.lr_decay
+
+adam_betas = [0.9, 0.99]
+adam_eps = 1e-8
+features_file = args.features
+signal_file = args.signal
+
+# model parameters
+cond_size = args.cond_size
+
+
+checkpoint['batch_size'] = batch_size
+checkpoint['lr'] = lr
+checkpoint['lr_decay'] = lr_decay
+checkpoint['epochs'] = epochs
+checkpoint['sequence_length'] = sequence_length
+checkpoint['adam_betas'] = adam_betas
+
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint['model_args'] = ()
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'has_gain': args.has_gain, 'passthrough_size': args.passthrough_size, 'gamma': args.gamma}
+print(checkpoint['model_kwargs'])
+model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+#model = fargan.FARGAN()
+#model = nn.DataParallel(model)
+
+if type(args.initial_checkpoint) != type(None):
+ checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+checkpoint['state_dict'] = model.state_dict()
+
+
+dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_length)
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
+
+states = None
+
+spect_loss = MultiResolutionSTFTLoss(device).to(device)
+
+if __name__ == '__main__':
+ model.to(device)
+
+ for epoch in range(1, epochs + 1):
+
+ running_specc = 0
+ running_cont_loss = 0
+ running_loss = 0
+
+ print(f"training epoch {epoch}...")
+ with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+ for i, (features, periods, target, lpc) in enumerate(tepoch):
+ optimizer.zero_grad()
+ features = features.to(device)
+ lpc = lpc.to(device)
+ periods = periods.to(device)
+ target = target.to(device)
+ target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+
+ #nb_pre = random.randrange(1, 6)
+ nb_pre = 2
+ pre = target[:, :nb_pre*160]
+ sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
+ sig = torch.cat([pre, sig], -1)
+
+ cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+ specc_loss = spect_loss(sig, target.detach())
+ loss = .2*cont_loss + specc_loss
+
+ loss.backward()
+ optimizer.step()
+
+ #model.clip_weights()
+
+ scheduler.step()
+
+ running_specc += specc_loss.detach().cpu().item()
+ running_cont_loss += cont_loss.detach().cpu().item()
+
+ running_loss += loss.detach().cpu().item()
+ tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
+ cont_loss=f"{running_cont_loss/(i+1):8.5f}",
+ specc=f"{running_specc/(i+1):8.5f}",
+ )
+
+ # save checkpoint
+ checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
+ checkpoint['state_dict'] = model.state_dict()
+ checkpoint['loss'] = running_loss / len(dataloader)
+ checkpoint['epoch'] = epoch
+ torch.save(checkpoint, checkpoint_path)