gitlab.xiph.org/xiph/opus.git
author     Jean-Marc Valin <jmvalin@amazon.com>  2023-08-31 23:32:15 +0300
committer  Jean-Marc Valin <jmvalin@amazon.com>  2023-09-13 05:50:46 +0300
commit     4f63743f8f814d608df4b249ac796f8edd15ade0 (patch)
tree       eb64645cc41e223c28e38b72da81135858e312a2
parent     1b13f6313e8413056f6d9f1f15fa994d0dff7a57 (diff)

explicit signal gain, explicit pitch predictor

-rw-r--r--  dnn/torch/fargan/fargan.py       | 16
-rw-r--r--  dnn/torch/fargan/train_fargan.py |  4
2 files changed, 11 insertions(+), 9 deletions(-)
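Summary, as read from the diff below: FARGANSub no longer derives the signal gain internally as the L2 norm of the previous pitch excitation; FARGAN.forward() now computes an explicit gain from the first feature and passes it down through a new gain= argument. The output stage also makes the pitch predictor explicit: a learned pitch_gain scales the normalized previous excitation, which is added to the network's excitation output before the result is denormalized by the explicit gain.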
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index 987cc8e5..d895eba0 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -9,7 +9,7 @@ Fs = 16000
 
 fid_dict = {}
 def dump_signal(x, filename):
-    return
+    #return
     if filename in fid_dict:
         fid = fid_dict[filename]
     else:
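Commenting out the early return re-enables dump_signal(), so the intermediate signals referenced below ('pitch_exc.f32', 'exc_mem.f32', 'exc_out.f32', 'sig_out.f32') are written out for debugging.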
@@ -162,7 +162,7 @@ class FARGANSub(nn.Module):
         self.apply(init_weights)
 
-    def forward(self, cond, prev, exc_mem, phase, period, states):
+    def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
         device = exc_mem.device
         #print(cond.shape, prev.shape)
@@ -176,7 +176,7 @@ class FARGANSub(nn.Module):
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
         if self.has_gain:
-            gain = torch.norm(prev, dim=1, p=2, keepdim=True)
+            #gain = torch.norm(prev, dim=1, p=2, keepdim=True)
             prev = prev/(1e-5+gain)
             prev = torch.cat([prev, torch.log(1e-5+gain)], 1)
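For illustration only: a minimal standalone sketch of the normalization above, with hypothetical shapes (batch of 32, subframe of 40 samples); prev is a chunk of the previous pitch excitation and gain is the explicit (batch, 1) gain now passed in from FARGAN.forward().

import torch

batch, subframe_size = 32, 40                        # hypothetical sizes
prev = torch.randn(batch, subframe_size)             # previous pitch excitation
gain = torch.full((batch, 1), 0.1)                   # explicit gain argument
prev = prev / (1e-5 + gain)                          # normalize the excitation by the gain
prev = torch.cat([prev, torch.log(1e-5 + gain)], 1)  # append log-gain as an extra input
print(prev.shape)                                    # torch.Size([32, 41])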
@@ -193,10 +193,10 @@ class FARGANSub(nn.Module):
         if self.passthrough_size != 0:
             passthrough = sig_out[:,self.subframe_size:]
             sig_out = sig_out[:,:self.subframe_size]
-        if self.has_gain:
-            out_gain = torch.exp(self.gain_dense_out(gru3_out))
-            sig_out = sig_out * out_gain
         dump_signal(sig_out, 'exc_out.f32')
+        if self.has_gain:
+            pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
+            sig_out = (sig_out + pitch_gain*prev[:,:-1]) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
         return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
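A minimal sketch of the new output stage, again with hypothetical shapes: pitch_gain stands in for torch.exp(self.gain_dense_out(gru3_out)), and prev[:, :-1] drops the appended log-gain column so only the normalized previous excitation enters the pitch predictor.

import torch

batch, subframe_size = 32, 40                    # hypothetical sizes
sig_out = torch.randn(batch, subframe_size)      # network excitation output
prev = torch.randn(batch, subframe_size + 1)     # normalized excitation + log-gain column
gain = torch.full((batch, 1), 0.1)               # explicit signal gain
pitch_gain = torch.rand(batch, 1)                # stand-in for exp(gain_dense_out(gru3_out))
# explicit pitch prediction, then denormalization by the explicit gain
sig_out = (sig_out + pitch_gain * prev[:, :-1]) * gain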
@@ -246,7 +246,9 @@ class FARGAN(nn.Module):
             phase = torch.cat([preal, pimag], 1)
             #print("now: ", preal.shape, prev.shape, sig_in.shape)
             pitch = period[:, 3+n]
-            out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states)
+            gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
+            #gain = gain[:,:,None]
+            out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain)
             if n < nb_pre_frames:
                 out = pre[:, pos:pos+self.subframe_size]
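The explicit gain comes from the first feature. A hedged reading, assuming features[:, :, 0] is a DCT-scaled log-energy coefficient: dividing by np.sqrt(18.0) undoes the DCT scaling, 10**(0.5*x) maps the log value back to a linear amplitude, and .03 is a fixed level adjustment. A shape check with hypothetical dimensions:

import numpy as np
import torch

features = torch.randn(32, 12, 20)    # hypothetical (batch, frames, nb_features)
n = 0
gain = .03 * 10**(0.5 * features[:, 3+n, 0:1] / np.sqrt(18.0))
print(gain.shape)                      # torch.Size([32, 1]), matching sig_net's gain= input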
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index 117518b6..b58eabb3 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -127,9 +127,9 @@ if __name__ == '__main__':
             sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
             sig = torch.cat([pre, sig], -1)
 
-            cont_loss = fargan.sig_l1(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+            cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
             specc_loss = spect_loss(sig, target.detach())
-            loss = .2*cont_loss + specc_loss
+            loss = .00*cont_loss + specc_loss
 
             loss.backward()
             optimizer.step()
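Note, as read from the diff: the continuity term switches from fargan.sig_l1() to fargan.sig_loss(), and its weight drops from .2 to .00, so the term is still computed (e.g. for logging) but no longer contributes to the gradient.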