Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Jan Buethe <jbuethe@amazon.de> 2023-11-20 13:28:24 +0300
committer: Jan Buethe <jbuethe@amazon.de> 2023-11-20 13:28:24 +0300
commit: 10d0b6fa66715f85ae87d0de7da6d90c58fd7f4a (patch)
tree: b6c3e4d52cb4192bd5cc049274833236c464694b
parent: 1a97c168a8bbf5ea9bc0d873b934e8a571736e20 (diff)
-rw-r--r-- dnn/torch/osce/models/lace.py 13
-rw-r--r-- dnn/torch/osce/models/no_lace.py 4
2 files changed, 15 insertions, 2 deletions
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 1d443c21..e8166295 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -122,6 +122,19 @@ class LACE(NNSBase):
def forward(self, x, features, periods, numbits, debug=False):
+ if features.size(1) % self.max_lookahead:
+ if self.training:
+ raise ValueError(f"number of frames must be divisible by {self.max_lookahead}")
+ else:
+ # truncate input
+ print(features.shape, periods.shape, numbits.shape, x.shape)
+ num_frames = self.max_lookahead * (features.size(1) // self.max_lookahead)
+ features = features[:, :num_frames]
+ periods = periods[:, :num_frames]
+ numbits = numbits[:, :num_frames]
+ x = x[..., :num_frames * self.FRAME_SIZE]
+ print(features.shape, periods.shape, numbits.shape, x.shape)
+
periods = periods.squeeze(-1)
pitch_embedding = self.pitch_embedding(periods)
numbits_embedding = self.numbits_embedding(numbits).flatten(2)
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 95439d17..8b400b9b 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -155,7 +155,7 @@ class NoLACE(NNSBase):
def forward(self, x, features, periods, numbits, debug=False):
if features.size(1) % self.max_lookahead:
- if self.train:
+ if self.training:
raise ValueError(f"number of frames must be divisible by {self.max_lookahead}")
else:
# truncate input
@@ -163,7 +163,7 @@ class NoLACE(NNSBase):
features = features[:, :num_frames]
periods = periods[:, :num_frames]
numbits = numbits[:, :num_frames]
- x = x[:, :num_frames * self.FRAME_SIZE]
+ x = x[..., :num_frames * self.FRAME_SIZE]
periods = periods.squeeze(-1)
pitch_embedding = self.pitch_embedding(periods)