Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-10-10 07:51:57 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-10-10 07:51:57 +0300
commit9e76a7bfb835ebe7cb97cf24da98462b78de0207 (patch)
treedd209e366796acd6130ed59d219daae1a0fbfb4c
parentd1c5b32add990473df84e42a8db64851b2dd65f6 (diff)
update fargan to match version 45
-rw-r--r--dnn/torch/fargan/adv_train_fargan.py25
-rw-r--r--dnn/torch/fargan/dataset.py8
-rw-r--r--dnn/torch/fargan/fargan.py156
-rw-r--r--dnn/torch/fargan/rc.py29
-rw-r--r--dnn/torch/fargan/stft_loss.py14
-rw-r--r--dnn/torch/fargan/test_fargan.py27
-rw-r--r--dnn/torch/fargan/train_fargan.py17
7 files changed, 194 insertions, 82 deletions
diff --git a/dnn/torch/fargan/adv_train_fargan.py b/dnn/torch/fargan/adv_train_fargan.py
index 23f5b2d0..94817cbc 100644
--- a/dnn/torch/fargan/adv_train_fargan.py
+++ b/dnn/torch/fargan/adv_train_fargan.py
@@ -132,6 +132,10 @@ states = None
spect_loss = MultiResolutionSTFTLoss(device).to(device)
+for param in model.parameters():
+ param.requires_grad = False
+
+batch_count = 0
if __name__ == '__main__':
model.to(device)
disc.to(device)
@@ -153,22 +157,28 @@ if __name__ == '__main__':
print(f"training epoch {epoch}...")
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
for i, (features, periods, target, lpc) in enumerate(tepoch):
+ if epoch == 1 and i == 400:
+ for param in model.parameters():
+ param.requires_grad = True
+
optimizer.zero_grad()
features = features.to(device)
- lpc = lpc.to(device)
+ #lpc = lpc.to(device)
+ #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
+ #lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device)
if True:
target = target[:, :sequence_length*160]
- lpc = lpc[:,:sequence_length,:]
+ #lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4]
else:
target=target[::2, :]
- lpc=lpc[::2,:]
+ #lpc=lpc[::2,:]
features=features[::2,:]
periods=periods[::2,:]
target = target.to(device)
- target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+ #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#nb_pre = random.randrange(1, 6)
nb_pre = 2
@@ -208,7 +218,7 @@ if __name__ == '__main__':
cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
specc_loss = spect_loss(output, target.detach())
- reg_loss = args.reg_weight * (.00*cont_loss + specc_loss)
+ reg_loss = (.00*cont_loss + specc_loss)
loss_gen = 0
for scale in scores_gen:
@@ -216,7 +226,8 @@ if __name__ == '__main__':
feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)
- gen_loss = reg_loss + feat_loss + loss_gen
+ reg_weight = args.reg_weight + 15./(1 + (batch_count/7600.))
+ gen_loss = reg_weight * reg_loss + feat_loss + loss_gen
model.zero_grad()
@@ -238,12 +249,14 @@ if __name__ == '__main__':
tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
+ reg_weight=f"{reg_weight:8.5f}",
gen_loss=f"{running_gen_loss/(i+1):8.5f}",
disc_loss=f"{running_disc_loss/(i+1):8.5f}",
fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
reg_loss=f"{running_reg_loss/(i+1):8.5f}",
wc = f"{running_wc/(i+1):8.5f}",
)
+ batch_count = batch_count + 1
# save checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
diff --git a/dnn/torch/fargan/dataset.py b/dnn/torch/fargan/dataset.py
index 6195c6af..2dfbb0b5 100644
--- a/dnn/torch/fargan/dataset.py
+++ b/dnn/torch/fargan/dataset.py
@@ -1,5 +1,6 @@
import torch
import numpy as np
+import fargan
class FARGANDataset(torch.utils.data.Dataset):
def __init__(self,
@@ -34,7 +35,8 @@ class FARGANDataset(torch.utils.data.Dataset):
sizeof = self.features.strides[-1]
self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
- self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
+ #self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
+ self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int')
self.lpc = self.features[:, :, self.nb_used_features:]
self.features = self.features[:, :, :self.nb_used_features]
@@ -51,5 +53,9 @@ class FARGANDataset(torch.utils.data.Dataset):
lpc = self.lpc[index, 4:, :].copy()
data = self.data[index, :].copy().astype(np.float32) / 2**15
periods = self.periods[index, :].copy()
+ #lpc = lpc*(self.gamma**np.arange(1,17))
+ #lpc=lpc[None,:,:]
+ #lpc = fargan.interp_lpc(lpc, 4)
+ #lpc=lpc[0,:,:]
return features, periods, data, lpc
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index e9cc687a..65f0a97b 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -4,6 +4,8 @@ from torch import nn
import torch.nn.functional as F
import filters
from torch.nn.utils import weight_norm
+#from convert_lsp import lpc_to_lsp, lsp_to_lpc
+from rc import lpc2rc, rc2lpc
Fs = 16000
@@ -27,6 +29,27 @@ def sig_loss(y_true, y_pred):
p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
return torch.mean(1.-torch.sum(p*t, dim=-1))
+def interp_lpc(lpc, factor):
+ #print(lpc.shape)
+ #f = (np.arange(factor)+.5*((factor+1)%2))/factor
+ lsp = torch.atanh(lpc2rc(lpc))
+ #print("lsp0:")
+ #print(lsp)
+ shape = lsp.shape
+ #print("shape is", shape)
+ shape = (shape[0], shape[1]*factor, shape[2])
+ interp_lsp = torch.zeros(shape, device=lpc.device)
+ for k in range(factor):
+ f = (k+.5*((factor+1)%2))/factor
+ interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
+ interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
+ for k in range(factor//2):
+ interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
+ for k in range((factor+1)//2):
+ interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
+ #print("lsp:")
+ #print(interp_lsp)
+ return rc2lpc(torch.tanh(interp_lsp))
def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
device = x.device
@@ -39,9 +62,9 @@ def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
out = torch.zeros((batch_size, 0), device=device)
- if gamma is not None:
- bw = gamma**(torch.arange(1, 17, device=device))
- lpc = lpc*bw[None,None,:]
+ #if gamma is not None:
+ # bw = gamma**(torch.arange(1, 17, device=device))
+ # lpc = lpc*bw[None,None,:]
ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
a = torch.cat([ones, lpc], -1)
@@ -127,30 +150,34 @@ class FWConv(nn.Module):
out = self.glu(torch.tanh(self.conv(xcat)))
return out, xcat[:,self.in_size:]
+def n(x):
+ return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
+
class FARGANCond(nn.Module):
- def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
+ def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
super(FARGANCond, self).__init__()
self.feature_dim = feature_dim
self.cond_size = cond_size
- self.pembed = nn.Embedding(256, pembed_dims)
- self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
- self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
- self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
- self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
+ self.pembed = nn.Embedding(224, pembed_dims)
+ self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
+ self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
+ self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False)
self.apply(init_weights)
+ nb_params = sum(p.numel() for p in self.parameters())
+ print(f"cond model: {nb_params} weights")
def forward(self, features, period):
- p = self.pembed(period)
+ p = self.pembed(period-32)
features = torch.cat((features, p), -1)
tmp = torch.tanh(self.fdense1(features))
tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fconv1(tmp))
tmp = torch.tanh(self.fconv2(tmp))
tmp = tmp.permute(0, 2, 1)
- tmp = torch.tanh(self.fdense2(tmp))
+ #tmp = torch.tanh(self.fdense2(tmp))
return tmp
class FARGANSub(nn.Module):
@@ -160,70 +187,87 @@ class FARGANSub(nn.Module):
self.subframe_size = subframe_size
self.nb_subframes = nb_subframes
self.cond_size = cond_size
+ self.cond_gain_dense = nn.Linear(80, 1)
#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
- self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
- self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
- self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
- self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
- self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+ self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
+ self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+ self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False)
+ self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
self.dense1_glu = GLU(self.cond_size)
- self.dense2_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size)
- self.gru2_glu = GLU(self.cond_size)
- self.gru3_glu = GLU(self.cond_size)
- self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
+ self.gru2_glu = GLU(128)
+ self.gru3_glu = GLU(128)
+ self.skip_glu = GLU(self.cond_size)
+ #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
- self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False)
- self.gain_dense_out = nn.Linear(4*self.cond_size, 1)
+ self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+ self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
+ self.gain_dense_out = nn.Linear(self.cond_size, 4)
self.apply(init_weights)
+ nb_params = sum(p.numel() for p in self.parameters())
+ print(f"subframe model: {nb_params} weights")
- def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
+ def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
device = exc_mem.device
#print(cond.shape, prev.shape)
- dump_signal(prev, 'prev_in.f32')
-
- idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254)
+ cond = n(cond)
+ dump_signal(gain, 'gain0.f32')
+ gain = torch.exp(self.cond_gain_dense(cond))
+ dump_signal(gain, 'gain1.f32')
+ idx = 256-period[:,None]
rng = torch.arange(self.subframe_size+4, device=device)
idx = idx + rng[None,:] - 2
+ mask = idx >= 256
+ idx = idx - mask*period[:,None]
pred = torch.gather(exc_mem, 1, idx)
- pred = pred/(1e-5+gain)
+ pred = n(pred/(1e-5+gain))
- prev = prev/(1e-5+gain)
+ prev = exc_mem[:,-self.subframe_size:]
+ dump_signal(prev, 'prev_in.f32')
+ prev = n(prev/(1e-5+gain))
dump_signal(prev, 'pitch_exc.f32')
dump_signal(exc_mem, 'exc_mem.f32')
- tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
+ tmp = torch.cat((cond, pred, prev), 1)
+ #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+ fpitch = pred[:,2:-2]
#tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
- dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
- gru1_state = self.gru1(dense2_out, states[0])
- gru1_out = self.gru1_glu(gru1_state)
- gru2_state = self.gru2(gru1_out, states[1])
- gru2_out = self.gru2_glu(gru2_state)
- gru3_state = self.gru3(gru2_out, states[2])
- gru3_out = self.gru3_glu(gru3_state)
- gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
- sig_out = torch.tanh(self.sig_dense_out(gru3_out))
+ fwc0_out = n(fwc0_out)
+ pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))
+
+ gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
+ gru1_out = self.gru1_glu(n(gru1_state))
+ gru1_out = n(gru1_out)
+ gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
+ gru2_out = self.gru2_glu(n(gru2_state))
+ gru2_out = n(gru2_out)
+ gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
+ gru3_out = self.gru3_glu(n(gru3_state))
+ gru3_out = n(gru3_out)
+ gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
+ skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
+ skip_out = self.skip_glu(n(skip_out))
+ sig_out = torch.tanh(self.sig_dense_out(skip_out))
dump_signal(sig_out, 'exc_out.f32')
- taps = self.ptaps_dense(gru3_out)
- taps = .2*taps + torch.exp(taps)
- taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
- dump_signal(taps, 'taps.f32')
- #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
- fpitch = pred[:,2:-2]
+ #taps = self.ptaps_dense(gru3_out)
+ #taps = .2*taps + torch.exp(taps)
+ #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
+ #dump_signal(taps, 'taps.f32')
- pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
dump_signal(pitch_gain, 'pgain.f32')
- sig_out = (sig_out + pitch_gain*fpitch) * gain
+ #sig_out = (sig_out + pitch_gain*fpitch) * gain
+ sig_out = sig_out * gain
exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
+ prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
dump_signal(sig_out, 'sig_out.f32')
- return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
+ return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
class FARGAN(nn.Module):
def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -242,37 +286,30 @@ class FARGAN(nn.Module):
device = features.device
batch_size = features.size(0)
- phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
- #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
-
- prev = torch.zeros(batch_size, self.subframe_size, device=device)
+ prev = torch.zeros(batch_size, 256, device=device)
exc_mem = torch.zeros(batch_size, 256, device=device)
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
states = (
torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
+ torch.zeros(batch_size, 128, device=device),
+ torch.zeros(batch_size, 128, device=device),
+ torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
)
sig = torch.zeros((batch_size, 0), device=device)
cond = self.cond_net(features, period)
if pre is not None:
- prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
start = 1 if nb_pre_frames>0 else 0
for n in range(start, nb_frames+nb_pre_frames):
for k in range(self.nb_subframes):
pos = n*self.frame_size + k*self.subframe_size
- preal = phase_real[:, pos:pos+self.subframe_size]
- pimag = phase_imag[:, pos:pos+self.subframe_size]
- phase = torch.cat([preal, pimag], 1)
#print("now: ", preal.shape, prev.shape, sig_in.shape)
pitch = period[:, 3+n]
gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
#gain = gain[:,:,None]
- out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
+ out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)
if n < nb_pre_frames:
out = pre[:, pos:pos+self.subframe_size]
@@ -280,6 +317,5 @@ class FARGAN(nn.Module):
else:
sig = torch.cat([sig, out], 1)
- prev = out
states = [s.detach() for s in states]
return sig, states
diff --git a/dnn/torch/fargan/rc.py b/dnn/torch/fargan/rc.py
new file mode 100644
index 00000000..7f67016a
--- /dev/null
+++ b/dnn/torch/fargan/rc.py
@@ -0,0 +1,29 @@
+import torch
+
+
+
+def rc2lpc(rc):
+ order = rc.shape[-1]
+ lpc=rc[...,0:1]
+ for i in range(1, order):
+ lpc = torch.cat([lpc + rc[...,i:i+1]*torch.flip(lpc,dims=(-1,)), rc[...,i:i+1]], -1)
+ #print("to:", lpc)
+ return lpc
+
+def lpc2rc(lpc):
+ order = lpc.shape[-1]
+ rc = lpc[...,-1:]
+ for i in range(order-1, 0, -1):
+ ki = lpc[...,-1:]
+ lpc = lpc[...,:-1]
+ lpc = (lpc - ki*torch.flip(lpc,dims=(-1,)))/(1 - ki*ki)
+ rc = torch.cat([lpc[...,-1:] , rc], -1)
+ return rc
+
+if __name__ == "__main__":
+ rc = torch.tensor([[.5, -.5, .6, -.6]])
+ print(rc)
+ lpc = rc2lpc(rc)
+ print(lpc)
+ rc2 = lpc2rc(lpc)
+ print(rc2)
diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py
index accf2f4a..8c904054 100644
--- a/dnn/torch/fargan/stft_loss.py
+++ b/dnn/torch/fargan/stft_loss.py
@@ -44,7 +44,9 @@ class SpectralConvergenceLoss(torch.nn.Module):
Returns:
Tensor: Spectral convergence loss value.
"""
- return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+ x_mag = torch.sqrt(x_mag)
+ y_mag = torch.sqrt(y_mag)
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
@@ -136,26 +138,26 @@ class STFTLoss(torch.nn.Module):
class MultiResolutionSTFTLoss(torch.nn.Module):
- def __init__(self,
+ '''def __init__(self,
device,
fft_sizes=[2048, 1024, 512, 256, 128, 64],
hop_sizes=[512, 256, 128, 64, 32, 16],
win_lengths=[2048, 1024, 512, 256, 128, 64],
- window="hann_window"):
+ window="hann_window"):'''
- '''def __init__(self,
+ '''def __init__(self,
device,
fft_sizes=[2048, 1024, 512, 256, 128, 64],
hop_sizes=[256, 128, 64, 32, 16, 8],
win_lengths=[1024, 512, 256, 128, 64, 32],
window="hann_window"):'''
- '''def __init__(self,
+ def __init__(self,
device,
fft_sizes=[2560, 1280, 640, 320, 160, 80],
hop_sizes=[640, 320, 160, 80, 40, 20],
win_lengths=[2560, 1280, 640, 320, 160, 80],
- window="hann_window"):'''
+ window="hann_window"):
super(MultiResolutionSTFTLoss, self).__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
diff --git a/dnn/torch/fargan/test_fargan.py b/dnn/torch/fargan/test_fargan.py
index 76e1f854..d3aeb613 100644
--- a/dnn/torch/fargan/test_fargan.py
+++ b/dnn/torch/fargan/test_fargan.py
@@ -48,7 +48,9 @@ model.load_state_dict(checkpoint['state_dict'], strict=False)
features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
lpc = features[:,4-1:-1,nb_used_features:]
features = features[:, :, :nb_used_features]
-periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
+#periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
+periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int')
+
nb_frames = features.shape[1]
#nb_frames = 1000
@@ -90,18 +92,37 @@ def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
buffer[:] = out_sig_frame[-16:]
return signal
+def inverse_perceptual_weighting40 (pw_signal, filters):
+
+ #inverse perceptual weighting= H_preemph / W(z/gamma)
+
+ signal = np.zeros_like(pw_signal)
+ buffer = np.zeros(16)
+ num_frames = pw_signal.shape[0] //40
+ assert num_frames == filters.shape[0]
+ for frame_idx in range(0, num_frames):
+ in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:]
+ out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer)
+ signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:]
+ buffer[:] = out_sig_frame[-16:]
+ return signal
+from scipy.signal import lfilter
if __name__ == '__main__':
model.to(device)
features = torch.tensor(features).to(device)
#lpc = torch.tensor(lpc).to(device)
periods = torch.tensor(periods).to(device)
+ weighting = gamma**np.arange(1, 17)
+ lpc = lpc*weighting
+ lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()
sig, _ = model(features, periods, nb_frames - 4)
- weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
+ #weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
sig = sig.detach().numpy().flatten()
- sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
+ sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
+ #sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])
pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
pcm.tofile(signal_file)
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index 4ab20045..dc6feb2d 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -114,20 +114,25 @@ if __name__ == '__main__':
for i, (features, periods, target, lpc) in enumerate(tepoch):
optimizer.zero_grad()
features = features.to(device)
- lpc = lpc.to(device)
+ #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
+ #print("interp size", lpc.shape)
+ #lpc = lpc.to(device)
+ #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
+ #lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device)
if (np.random.rand() > 0.1):
target = target[:, :sequence_length*160]
- lpc = lpc[:,:sequence_length,:]
+ #lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4]
else:
target=target[::2, :]
- lpc=lpc[::2,:]
+ #lpc=lpc[::2,:]
features=features[::2,:]
periods=periods[::2,:]
target = target.to(device)
- target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+ #print(target.shape, lpc.shape)
+ #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#nb_pre = random.randrange(1, 6)
nb_pre = 2
@@ -135,9 +140,9 @@ if __name__ == '__main__':
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
sig = torch.cat([pre, sig], -1)
- cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+ cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
specc_loss = spect_loss(sig, target.detach())
- loss = .00*cont_loss + specc_loss
+ loss = .03*cont_loss + specc_loss
loss.backward()
optimizer.step()