
gitlab.xiph.org/xiph/opus.git
author     Jean-Marc Valin <jmvalin@amazon.com>    2023-09-12 06:28:52 +0300
committer  Jean-Marc Valin <jmvalin@amazon.com>    2023-09-13 05:50:48 +0300
commit     2f8b36d691a3802714a54abd7409234e41ec3e21 (patch)
tree       0ab66685e1e0295671d13b5332fbfe28aaf176f7
parent     72c5ea4129dc6473fb1d82ef3ec3d5714fab8cbd (diff)
Add conditioning interpolation, fwconv layer
-rw-r--r--  dnn/torch/fargan/fargan.py       | 56
-rw-r--r--  dnn/torch/fargan/train_fargan.py |  2
2 files changed, 41 insertions, 17 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index 952c1b84..b532f268 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -101,6 +101,31 @@ class GLU(nn.Module):
         return out
 
+class FWConv(nn.Module):
+    def __init__(self, in_size, out_size, kernel_size=3):
+        super(FWConv, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.in_size = in_size
+        self.kernel_size = kernel_size
+        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
+        self.glu = GLU(out_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, state):
+        xcat = torch.cat((state, x), -1)
+        #print(x.shape, state.shape, xcat.shape, self.in_size, self.kernel_size)
+        out = self.glu(torch.tanh(self.conv(xcat)))
+        return out, xcat[:,self.in_size:]
+
 class FARGANCond(nn.Module):
     def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
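For readers unfamiliar with the new layer: FWConv is a stateful, subframe-rate causal convolution implemented as a Linear over the current input concatenated with the previous kernel_size-1 inputs kept in `state`. A minimal sketch (not part of the patch), assuming the sizes used elsewhere in this diff (subframe_size=40, 80 conditioning values per subframe, cond_size=256):

    import torch
    import torch.nn as nn
    from torch.nn.utils import weight_norm

    in_size, out_size, kernel_size = 4*40 + 80, 256, 3   # 240-dim input per subframe
    batch = 8

    # state holds the previous (kernel_size-1) input frames, flattened
    state = torch.zeros(batch, in_size * (kernel_size - 1))
    x = torch.randn(batch, in_size)

    lin = weight_norm(nn.Linear(in_size * kernel_size, out_size, bias=False))
    xcat = torch.cat((state, x), -1)        # [batch, in_size*kernel_size]
    out = torch.tanh(lin(xcat))             # the real layer also applies a GLU
    new_state = xcat[:, in_size:]           # drop the oldest frame, keep the last two

Run one subframe at a time like this, the layer behaves as a kernel-size-3 causal convolution over the sequence of subframe inputs.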
@@ -113,7 +138,7 @@ class FARGANCond(nn.Module):
         self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
         self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
         self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
-        self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
+        self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
 
         self.apply(init_weights)
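With this change the conditioning network emits 80*4 = 320 values per frame rather than cond_size: one 80-dim conditioning vector for each of the 4 subframes, which FARGAN.forward slices out per subframe later in this diff. A shape-only sketch (illustrative names, not part of the patch):

    import torch
    batch, nb_frames = 8, 10
    cond = torch.zeros(batch, nb_frames, 80*4)     # output shape implied by this hunk
    for n in range(nb_frames):
        for k in range(4):                         # 4 subframes per 160-sample frame
            sub_cond = cond[:, n, k*80:(k+1)*80]   # [batch, 80], fed to FARGANSub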
@@ -138,9 +163,10 @@ class FARGANSub(nn.Module):
         self.has_gain = has_gain
         self.passthrough_size = passthrough_size
 
-        print("has_gain:", self.has_gain)
-        print("passthrough_size:", self.passthrough_size)
-        self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
+        #print("has_gain:", self.has_gain)
+        #print("passthrough_size:", self.passthrough_size)
+        #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
+        self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
         self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
         self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
         self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
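The FWConv input width here is 4*self.subframe_size + 80 = 240 with the default subframe_size=40 (the 80 being the per-subframe conditioning that replaces the old cond_size-wide input). With kernel_size=3 the layer carries its previous two inputs as state, which is where the (4*self.subframe_size+80)*2 zeros in FARGAN.forward later in this diff come from. A quick consistency check, assuming the defaults:

    subframe_size, kernel_size = 40, 3
    in_size = 4*subframe_size + 80                  # 240, input width of self.fwc0
    state_size = in_size * (kernel_size - 1)        # 480
    assert state_size == (4*subframe_size + 80)*2   # matches the new state init below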
@@ -176,30 +202,26 @@ class FARGANSub(nn.Module):
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
 
-        passthrough = states[3]
-        tmp = torch.cat((cond, pred[:,2:-2], prev, passthrough, phase), 1)
+        tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
 
-        tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
-        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
+        #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
+        fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
+        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
         gru1_state = self.gru1(dense2_out, states[0])
         gru1_out = self.gru1_glu(gru1_state)
-        #gru1_out = torch.cat([gru1_out, fpitch], 1)
         gru2_state = self.gru2(gru1_out, states[1])
         gru2_out = self.gru2_glu(gru2_state)
-        #gru2_out = torch.cat([gru2_out, fpitch], 1)
         gru3_state = self.gru3(gru2_out, states[2])
         gru3_out = self.gru3_glu(gru3_state)
 
         gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
         sig_out = torch.tanh(self.sig_dense_out(gru3_out))
-        if self.passthrough_size != 0:
-            passthrough = sig_out[:,self.subframe_size:]
-            sig_out = sig_out[:,:self.subframe_size]
         dump_signal(sig_out, 'exc_out.f32')
         taps = self.ptaps_dense(gru3_out)
         taps = .2*taps + torch.exp(taps)
         taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
         dump_signal(taps, 'taps.f32')
 
-        fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        fpitch = pred[:,2:-2]
 
         if self.has_gain:
             pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
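Two behavioural changes in this hunk: the dense1 layer is replaced by the stateful fwc0 (with states[3] now carrying the FWConv history rather than the passthrough vector), and the 5-tap adaptive filtering of the pitch prediction is disabled: taps are still computed and dumped, but fpitch is now simply the aligned prediction pred[:,2:-2]. A small sketch of that alignment (illustrative only), given that the old 5-tap expression implies pred holds subframe_size+4 samples:

    import torch
    subframe_size = 40
    pred = torch.randn(8, subframe_size + 4)  # 2 extra samples on each side
    fpitch = pred[:, 2:-2]                    # centre-aligned, [8, subframe_size]
    assert fpitch.shape[1] == subframe_size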
@@ -207,7 +229,7 @@ class FARGANSub(nn.Module):
             sig_out = (sig_out + pitch_gain*fpitch) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
-        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
+        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
     def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -239,7 +261,7 @@ class FARGAN(nn.Module):
                   torch.zeros(batch_size, self.cond_size, device=device),
                   torch.zeros(batch_size, self.cond_size, device=device),
                   torch.zeros(batch_size, self.cond_size, device=device),
-                  torch.zeros(batch_size, self.passthrough_size, device=device)
+                  torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
                  )
 
         sig = torch.zeros((batch_size, 0), device=device)
@@ -259,7 +281,7 @@ class FARGAN(nn.Module):
                 pitch = period[:, 3+n]
                 gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                 #gain = gain[:,:,None]
-                out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain)
+                out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
 
                 if n < nb_pre_frames:
                     out = pre[:, pos:pos+self.subframe_size]
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index 852b87be..3904253e 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -121,6 +121,8 @@ if __name__ == '__main__':
                 if (np.random.rand() > 0.1):
                     target = target[:, :sequence_length*160]
                     lpc = lpc[:,:sequence_length,:]
+                    features = features[:,:sequence_length+4,:]
+                    periods = periods[:,:sequence_length+4]
                 else:
                     target=target[::2, :]
                     lpc=lpc[::2,:]
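The training script now also trims features and periods to sequence_length+4 frames alongside the trimmed target, which is consistent with the two kernel-size-3 'valid' convolutions in FARGANCond each dropping 2 frames, so that sequence_length+4 input frames yield sequence_length conditioning frames. A quick check under that assumption:

    sequence_length = 15                                 # illustrative value only
    n_cond_frames = (sequence_length + 4) - 2*(3 - 1)    # two 'valid' convs, kernel 3
    assert n_cond_frames == sequence_length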