Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-09-07 00:15:24 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-09-13 05:50:48 +0300
commit72c5ea4129dc6473fb1d82ef3ec3d5714fab8cbd (patch)
tree15f2b6e26b195671f7cddbb41c9dc570ac343ce2
parent108b75c4b189ea2eedd8e9cb0a2b56a9b4424466 (diff)
Only use one frame of pre-loading
-rw-r--r--dnn/torch/fargan/fargan.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py
index daa13f17..952c1b84 100644
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -235,18 +235,21 @@ class FARGAN(nn.Module):
exc_mem = torch.zeros(batch_size, 256, device=device)
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
- if states is None:
- states = (
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.cond_size, device=device),
- torch.zeros(batch_size, self.passthrough_size, device=device)
- )
+ states = (
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.cond_size, device=device),
+ torch.zeros(batch_size, self.passthrough_size, device=device)
+ )
sig = torch.zeros((batch_size, 0), device=device)
cond = self.cond_net(features, period)
passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
- for n in range(nb_frames+nb_pre_frames):
+ if pre is not None:
+ prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
+ exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
+ start = 1 if nb_pre_frames>0 else 0
+ for n in range(start, nb_frames+nb_pre_frames):
for k in range(self.nb_subframes):
pos = n*self.frame_size + k*self.subframe_size
preal = phase_real[:, pos:pos+self.subframe_size]