
gitlab.xiph.org/xiph/opus.git - Opus audio codec repository
author     Jean-Marc Valin <jmvalin@amazon.com>    2023-09-01 20:14:51 +0300
committer  Jean-Marc Valin <jmvalin@amazon.com>    2023-09-13 05:50:47 +0300
commit     2e0c1ad3aefdc6bf4c30dc6cba44d52e5567cb68 (patch)
tree       a187d110e3c76a7e39385d203b573e6b0d2d2e5c
parent     4f63743f8f814d608df4b249ac796f8edd15ade0 (diff)
Also use previous frame
-rw-r--r--  dnn/torch/fargan/fargan.py        | 9
-rw-r--r--  dnn/torch/fargan/train_fargan.py  | 4
2 files changed, 7 insertions(+), 6 deletions(-)
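
What the change does: the FARGANSub subframe network now conditions on both the pitch-predicted excitation gathered from the excitation history and the previous subframe output, which is why the sig_dense1 input grows from 3 to 4 subframe-sized signals. Below is a minimal, illustrative sketch of the new concatenation; the batch size, subframe_size and period values are assumptions for the example, not values taken from this commit.

import torch

subframe_size = 40                        # assumed subframe length, for illustration
exc_mem = torch.randn(1, 256)             # excitation history buffer (256 samples)
prev = torch.randn(1, subframe_size)      # previous subframe output, now fed in as well
period = torch.tensor([120])              # assumed pitch period in samples

# Same lookup as the patched forward(): read one pitch period back in exc_mem.
idx = 256 - torch.maximum(torch.tensor(subframe_size), period[:, None])
idx = idx + torch.arange(subframe_size)[None, :]
pred = torch.gather(exc_mem, 1, idx)      # pitch-predicted excitation
prev = torch.cat([pred, prev], 1)         # shape (1, 2*subframe_size): the extra subframe
                                          # accounts for the 3x -> 4x sig_dense1 input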
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index d895eba0..fdea4d6b 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -9,7 +9,7 @@ Fs = 16000
 fid_dict = {}
 def dump_signal(x, filename):
-    #return
+    return
     if filename in fid_dict:
         fid = fid_dict[filename]
     else:
@@ -143,7 +143,7 @@ class FARGANSub(nn.Module):
gain_param = 1 if self.has_gain else 0
- self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
+ self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
@@ -171,7 +171,8 @@ class FARGANSub(nn.Module):
         idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None])
         rng = torch.arange(self.subframe_size, device=device)
         idx = idx + rng[None,:]
-        prev = torch.gather(exc_mem, 1, idx)
+        pred = torch.gather(exc_mem, 1, idx)
+        prev = torch.cat([pred, prev], 1)
         #prev = prev*0
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
@@ -196,7 +197,7 @@ class FARGANSub(nn.Module):
         dump_signal(sig_out, 'exc_out.f32')
         if self.has_gain:
             pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
-            sig_out = (sig_out + pitch_gain*prev[:,:-1]) * gain
+            sig_out = (sig_out + pitch_gain*prev[:,:self.subframe_size]) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
         return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
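
With prev now laid out as [pred, previous subframe], the gain path scales only the pitch-prediction part, hence the slice change to prev[:,:self.subframe_size]. A tiny sketch of that layout; the shapes are illustrative assumptions, not values from this commit.

import torch

subframe_size = 40                          # assumed
pred = torch.randn(1, subframe_size)        # pitch-predicted excitation
prev_frame = torch.randn(1, subframe_size)  # previous subframe output
prev = torch.cat([pred, prev_frame], 1)

# The first subframe_size columns recover the pitch prediction,
# which is what pitch_gain multiplies in the patched gain path.
assert torch.equal(prev[:, :subframe_size], pred)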
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index b58eabb3..b4eef6fd 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -54,7 +54,7 @@ epochs = args.epochs
 sequence_length = args.sequence_length
 lr_decay = args.lr_decay
-adam_betas = [0.9, 0.99]
+adam_betas = [0.8, 0.99]
 adam_eps = 1e-8
 features_file = args.features
 signal_file = args.signal
@@ -92,7 +92,7 @@ dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_len
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
-optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
 # learning rate scheduler
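
On the training side, the commit swaps Adam for AdamW, which applies decoupled weight decay (PyTorch's default weight_decay for AdamW is 0.01), and lowers beta1 from 0.9 to 0.8, shortening the first-moment averaging window so the momentum estimate reacts faster. A standalone sketch of the new optimizer configuration; the placeholder model and learning rate are assumptions for illustration only.

import torch

model = torch.nn.Linear(8, 8)               # placeholder model, for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
                              betas=(0.8, 0.99), eps=1e-8)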