diff options
Diffstat (limited to 'dnn/torch/osce/data/lpcnet_vocoding_dataset.py')
-rw-r--r-- | dnn/torch/osce/data/lpcnet_vocoding_dataset.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/dnn/torch/osce/data/lpcnet_vocoding_dataset.py b/dnn/torch/osce/data/lpcnet_vocoding_dataset.py index 36c8c724..d9b5c6b8 100644 --- a/dnn/torch/osce/data/lpcnet_vocoding_dataset.py +++ b/dnn/torch/osce/data/lpcnet_vocoding_dataset.py @@ -86,6 +86,8 @@ class LPCNetVocodingDataset(Dataset): self.getitem = self.getitem_v1 elif self.version == 2: self.getitem = self.getitem_v2 + elif self.version == 3: + self.getitem = self.getitem_v2 else: raise ValueError(f"dataset version {self.version} unknown") @@ -138,7 +140,10 @@ class LPCNetVocodingDataset(Dataset): # convert periods if 'periods' in self.input_features: - sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16') + if self.version < 3: + sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16') + else: + sample['periods'] = np.round(np.clip(256./2**(sample['periods']+1.5), 32, 256)).astype('int') signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length |