diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-11-20 13:28:24 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-11-20 13:28:24 +0300 |
commit | 10d0b6fa66715f85ae87d0de7da6d90c58fd7f4a (patch) | |
tree | b6c3e4d52cb4192bd5cc049274833236c464694b | |
parent | 1a97c168a8bbf5ea9bc0d873b934e8a571736e20 (diff) |
bugfixes exp-lace-variable-lookahead
-rw-r--r-- | dnn/torch/osce/models/lace.py | 13 | ||||
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 4 |
2 files changed, 15 insertions, 2 deletions
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 1d443c21..e8166295 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -122,6 +122,19 @@ class LACE(NNSBase):
     def forward(self, x, features, periods, numbits, debug=False):
+        if features.size(1) % self.max_lookahead:
+            if self.training:
+                raise ValueError(f"number of frames must be divisible by {self.max_lookahead}")
+            else:
+                # truncate input
+                print(features.shape, periods.shape, numbits.shape, x.shape)
+                num_frames = self.max_lookahead * (features.size(1) // self.max_lookahead)
+                features = features[:, :num_frames]
+                periods = periods[:, :num_frames]
+                numbits = numbits[:, :num_frames]
+                x = x[..., :num_frames * self.FRAME_SIZE]
+                print(features.shape, periods.shape, numbits.shape, x.shape)
+
         periods = periods.squeeze(-1)
         pitch_embedding = self.pitch_embedding(periods)
         numbits_embedding = self.numbits_embedding(numbits).flatten(2)
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 95439d17..8b400b9b 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -155,7 +155,7 @@ class NoLACE(NNSBase):
     def forward(self, x, features, periods, numbits, debug=False):
         if features.size(1) % self.max_lookahead:
-            if self.train:
+            if self.training:
                 raise ValueError(f"number of frames must be divisible by {self.max_lookahead}")
             else:
                 # truncate input
@@ -163,7 +163,7 @@ class NoLACE(NNSBase):
                 features = features[:, :num_frames]
                 periods = periods[:, :num_frames]
                 numbits = numbits[:, :num_frames]
-                x = x[:, :num_frames * self.FRAME_SIZE]
+                x = x[..., :num_frames * self.FRAME_SIZE]
         periods = periods.squeeze(-1)
         pitch_embedding = self.pitch_embedding(periods)