author    | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-01 20:14:51 +0300
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-13 05:50:47 +0300
commit    | 2e0c1ad3aefdc6bf4c30dc6cba44d52e5567cb68 (patch)
tree      | a187d110e3c76a7e39385d203b573e6b0d2d2e5c
parent    | 4f63743f8f814d608df4b249ac796f8edd15ade0 (diff)
Also use previous frame
-rw-r--r-- | dnn/torch/fargan/fargan.py       | 9
-rw-r--r-- | dnn/torch/fargan/train_fargan.py | 4
2 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index d895eba0..fdea4d6b 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -9,7 +9,7 @@ Fs = 16000
 
 fid_dict = {}
 def dump_signal(x, filename):
-    #return
+    return
     if filename in fid_dict:
         fid = fid_dict[filename]
     else:
@@ -143,7 +143,7 @@ class FARGANSub(nn.Module):
 
         gain_param = 1 if self.has_gain else 0
 
-        self.sig_dense1 = nn.Linear(3*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
+        self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size+gain_param, self.cond_size, bias=False)
         self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
         self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
         self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
@@ -171,7 +171,8 @@ class FARGANSub(nn.Module):
         idx = 256-torch.maximum(torch.tensor(self.subframe_size, device=device), period[:,None])
         rng = torch.arange(self.subframe_size, device=device)
         idx = idx + rng[None,:]
-        prev = torch.gather(exc_mem, 1, idx)
+        pred = torch.gather(exc_mem, 1, idx)
+        prev = torch.cat([pred, prev], 1)
         #prev = prev*0
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
@@ -196,7 +197,7 @@ class FARGANSub(nn.Module):
         dump_signal(sig_out, 'exc_out.f32')
         if self.has_gain:
             pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
-            sig_out = (sig_out + pitch_gain*prev) * gain
+            sig_out = (sig_out + pitch_gain*prev[:,:self.subframe_size]) * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
         return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
diff --git a/dnn/torch/fargan/train_fargan.py b/dnn/torch/fargan/train_fargan.py
index b58eabb3..b4eef6fd 100644
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -54,7 +54,7 @@ epochs = args.epochs
 sequence_length = args.sequence_length
 lr_decay = args.lr_decay
 
-adam_betas = [0.9, 0.99]
+adam_betas = [0.8, 0.99]
 adam_eps = 1e-8
 features_file = args.features
 signal_file = args.signal
@@ -92,7 +92,7 @@ dataset = FARGANDataset(features_file, signal_file, sequence_length=sequence_len
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
 
-optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
 
 # learning rate scheduler
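What the fargan.py half of this patch does: in addition to the pitch-delayed prediction gathered from the excitation memory, the subframe network is now also fed the previous output subframe. That is why `prev` becomes the concatenation of `pred` and the old `prev`, the signal portion of the first dense layer's input grows from 3 to 4 subframes, and the gain path must select the prediction half explicitly as `prev[:,:self.subframe_size]`. Below is a minimal, self-contained sketch of just this tensor plumbing, not the full FARGANSub module; the sizes (`subframe_size=40`, a 256-sample excitation memory, `cond_size=256`, `passthrough_size=0`), the `tanh` stand-in for the GRU stack, and the `other` stand-in for the remaining signal inputs are illustrative assumptions, not values taken from this diff.

```python
import torch
import torch.nn as nn

batch = 2
subframe_size = 40     # assumed; the real value comes from the model config
cond_size = 256        # assumed conditioning width
passthrough_size = 0   # assumed zero to keep the size arithmetic readable

# After the patch the signal part of the input is 4 subframes wide
# (it was 3*subframe_size before), plus conditioning and one gain unit.
sig_dense1 = nn.Linear(4*subframe_size + passthrough_size + cond_size + 1,
                       cond_size, bias=False)

exc_mem = torch.randn(batch, 256)            # past excitation samples
prev    = torch.randn(batch, subframe_size)  # previous output subframe
period  = torch.randint(32, 255, (batch,))   # pitch period per batch item

# Pitch-based prediction, as in the patch: gather subframe_size samples
# starting one pitch period (at least one subframe) back from the end.
idx  = 256 - torch.maximum(torch.tensor(subframe_size), period[:, None])
idx  = idx + torch.arange(subframe_size)[None, :]
pred = torch.gather(exc_mem, 1, idx)

# The patched lines: keep the prediction AND the previous subframe.
prev = torch.cat([pred, prev], 1)            # (batch, 2*subframe_size)

# Stand-ins for the rest of the network input (two more signal
# subframes, the conditioning vector, and the gain parameter).
other = torch.randn(batch, 2*subframe_size)
cond  = torch.randn(batch, cond_size)
gain  = torch.rand(batch, 1)
x = torch.cat([prev, other, cond, gain], 1)  # 4 subframes + cond + 1
h = torch.tanh(sig_dense1(x))                # stand-in for the GRU stack

# Gain path, as patched: with prev now holding two subframes, the pitch
# prediction is its first half rather than the whole tensor.
gain_dense_out = nn.Linear(cond_size, 1, bias=False)
pitch_gain = torch.exp(gain_dense_out(h))
sig_out = torch.randn(batch, subframe_size)  # stand-in network output
sig_out = (sig_out + pitch_gain * prev[:, :subframe_size]) * gain

# Excitation memory update (unchanged by the patch): shift in sig_out.
exc_mem = torch.cat([exc_mem[:, subframe_size:], sig_out], 1)
print(sig_out.shape, exc_mem.shape)          # (2, 40) and (2, 256)
```

The train_fargan.py half lowers beta1 from 0.9 to 0.8 and swaps `torch.optim.Adam` for `torch.optim.AdamW`. One point worth noting if you reproduce this: `AdamW` applies decoupled weight decay with a default `weight_decay=0.01`, while `Adam` defaults to `weight_decay=0`, so the one-line swap also introduces regularization. A sketch with the patch's values (the learning rate here is a placeholder; the real script reads it from `args`):

```python
model = nn.Linear(417, 256)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
                              betas=(0.8, 0.99), eps=1e-8)
```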