Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-09-22 19:57:04 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-09-28 06:58:23 +0300
commitb53d0ca7ef9f28b323b6cc2c98447d56824b9e4f (patch)
tree5c1fc6d80609fd1f3b4eea46fccb12cb45ba2d2e
parent25c65a0c0b9ce8282cfc713a7c0581664c93ab18 (diff)
version 36
-rw-r--r--dnn/torch/fargan/fargan.py42
-rw-r--r--dnn/torch/fargan/stft_loss.py4
-rw-r--r--dnn/torch/fargan/train_fargan.py4
3 files changed, 29 insertions, 21 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index e9cc687a..7f405e87 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -164,19 +164,21 @@ class FARGANSub(nn.Module):
#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
- self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
- self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
- self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+ self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+ self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+ self.gru3 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.dense1_glu = GLU(self.cond_size)
self.dense2_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size)
self.gru2_glu = GLU(self.cond_size)
self.gru3_glu = GLU(self.cond_size)
- self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
+ self.skip_glu = GLU(self.cond_size)
+ #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
- self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False)
- self.gain_dense_out = nn.Linear(4*self.cond_size, 1)
+ self.skip_dense = nn.Linear(4*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+ self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
+ self.gain_dense_out = nn.Linear(self.cond_size, 4)
self.apply(init_weights)
@@ -198,29 +200,33 @@ class FARGANSub(nn.Module):
dump_signal(exc_mem, 'exc_mem.f32')
tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
+ #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+ fpitch = pred[:,2:-2]
#tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
- gru1_state = self.gru1(dense2_out, states[0])
+ pitch_gain = torch.sigmoid(self.gain_dense_out(dense2_out))
+
+ gru1_state = self.gru1(torch.cat([dense2_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
gru1_out = self.gru1_glu(gru1_state)
- gru2_state = self.gru2(gru1_out, states[1])
+ gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
gru2_out = self.gru2_glu(gru2_state)
- gru3_state = self.gru3(gru2_out, states[2])
+ gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
gru3_out = self.gru3_glu(gru3_state)
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
- sig_out = torch.tanh(self.sig_dense_out(gru3_out))
+ skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
+ skip_out = self.skip_glu(skip_out)
+ sig_out = torch.tanh(self.sig_dense_out(skip_out))
dump_signal(sig_out, 'exc_out.f32')
- taps = self.ptaps_dense(gru3_out)
- taps = .2*taps + torch.exp(taps)
- taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
- dump_signal(taps, 'taps.f32')
- #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
- fpitch = pred[:,2:-2]
+ #taps = self.ptaps_dense(gru3_out)
+ #taps = .2*taps + torch.exp(taps)
+ #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
+ #dump_signal(taps, 'taps.f32')
- pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
dump_signal(pitch_gain, 'pgain.f32')
- sig_out = (sig_out + pitch_gain*fpitch) * gain
+ #sig_out = (sig_out + pitch_gain*fpitch) * gain
+ sig_out = sig_out * gain
exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
dump_signal(sig_out, 'sig_out.f32')
return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
diff --git a/dnn/torch/fargan/stft_loss.py b/dnn/torch/fargan/stft_loss.py
index accf2f4a..542d0e9b 100644
--- a/dnn/torch/fargan/stft_loss.py
+++ b/dnn/torch/fargan/stft_loss.py
@@ -44,7 +44,9 @@ class SpectralConvergenceLoss(torch.nn.Module):
Returns:
Tensor: Spectral convergence loss value.
"""
- return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+ x_mag = torch.sqrt(x_mag)
+ y_mag = torch.sqrt(y_mag)
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index 4ab20045..be2a1a66 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -135,9 +135,9 @@ if __name__ == '__main__':
sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
sig = torch.cat([pre, sig], -1)
- cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+ cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
specc_loss = spect_loss(sig, target.detach())
- loss = .00*cont_loss + specc_loss
+ loss = .03*cont_loss + specc_loss
loss.backward()
optimizer.step()