diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-07 00:15:24 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-09-13 05:50:48 +0300 |
commit | 72c5ea4129dc6473fb1d82ef3ec3d5714fab8cbd (patch) | |
tree | 15f2b6e26b195671f7cddbb41c9dc570ac343ce2 | |
parent | 108b75c4b189ea2eedd8e9cb0a2b56a9b4424466 (diff) |
Only use one frame of pre-loading
-rw-r--r-- | dnn/torch/fargan/fargan.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index daa13f17..952c1b84 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -235,18 +235,21 @@ class FARGAN(nn.Module): exc_mem = torch.zeros(batch_size, 256, device=device) nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0 - if states is None: - states = ( - torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.cond_size, device=device), - torch.zeros(batch_size, self.passthrough_size, device=device) - ) + states = ( + torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, self.passthrough_size, device=device) + ) sig = torch.zeros((batch_size, 0), device=device) cond = self.cond_net(features, period) passthrough = torch.zeros(batch_size, self.passthrough_size, device=device) - for n in range(nb_frames+nb_pre_frames): + if pre is not None: + prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size] + exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size] + start = 1 if nb_pre_frames>0 else 0 + for n in range(start, nb_frames+nb_pre_frames): for k in range(self.nb_subframes): pos = n*self.frame_size + k*self.subframe_size preal = phase_real[:, pos:pos+self.subframe_size] |