Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-09-25 18:20:44 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-09-28 23:15:03 +0300
commit7e770ffb3ae1931185b3563831868ea946a330d0 (patch)
tree1f27abf621664861d02c515c0a1cb1bad2a59405
parent5fd10ee92112f6bfc62c599578f91717ab915d9b (diff)
remove phase
-rw-r--r--dnn/torch/fargan/fargan.py26
1 files changed, 10 insertions, 16 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index 66a20fbc..8988148f 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -185,7 +185,7 @@ class FARGANSub(nn.Module):
self.cond_size = cond_size
#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
- self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
+ self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
@@ -206,11 +206,10 @@ class FARGANSub(nn.Module):
self.apply(init_weights)
- def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
+ def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
device = exc_mem.device
#print(cond.shape, prev.shape)
- dump_signal(prev, 'prev_in.f32')
idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254)
rng = torch.arange(self.subframe_size+4, device=device)
@@ -218,11 +217,13 @@ class FARGANSub(nn.Module):
pred = torch.gather(exc_mem, 1, idx)
pred = pred/(1e-5+gain)
+ prev = exc_mem[:,-self.subframe_size:]
+ dump_signal(prev, 'prev_in.f32')
prev = prev/(1e-5+gain)
dump_signal(prev, 'pitch_exc.f32')
dump_signal(exc_mem, 'exc_mem.f32')
- tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
+ tmp = torch.cat((cond, pred, prev), 1)
#fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
fpitch = pred[:,2:-2]
@@ -251,8 +252,9 @@ class FARGANSub(nn.Module):
#sig_out = (sig_out + pitch_gain*fpitch) * gain
sig_out = sig_out * gain
exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
+ prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
dump_signal(sig_out, 'sig_out.f32')
- return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
+ return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
class FARGAN(nn.Module):
def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -271,10 +273,7 @@ class FARGAN(nn.Module):
device = features.device
batch_size = features.size(0)
- phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
- #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
-
- prev = torch.zeros(batch_size, self.subframe_size, device=device)
+ prev = torch.zeros(batch_size, 256, device=device)
exc_mem = torch.zeros(batch_size, 256, device=device)
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
@@ -282,26 +281,22 @@ class FARGAN(nn.Module):
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
+ torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
)
sig = torch.zeros((batch_size, 0), device=device)
cond = self.cond_net(features, period)
if pre is not None:
- prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
start = 1 if nb_pre_frames>0 else 0
for n in range(start, nb_frames+nb_pre_frames):
for k in range(self.nb_subframes):
pos = n*self.frame_size + k*self.subframe_size
- preal = phase_real[:, pos:pos+self.subframe_size]
- pimag = phase_imag[:, pos:pos+self.subframe_size]
- phase = torch.cat([preal, pimag], 1)
#print("now: ", preal.shape, prev.shape, sig_in.shape)
pitch = period[:, 3+n]
gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
#gain = gain[:,:,None]
- out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
+ out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)
if n < nb_pre_frames:
out = pre[:, pos:pos+self.subframe_size]
@@ -309,6 +304,5 @@ class FARGAN(nn.Module):
else:
sig = torch.cat([sig, out], 1)
- prev = out
states = [s.detach() for s in states]
return sig, states