diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-09-05 13:29:38 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-09-05 13:29:38 +0300 |
commit | 35ee397e060283d30c098ae5e17836316bbec08b (patch) | |
tree | 4a81b86f8c0738bbdc7147214c53fda54cd0f3f3 /dnn/torch/lpcnet/data | |
parent | 90a171c1c2c9839b561f8446ad2bbfe48eacf255 (diff) |
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Diffstat (limited to 'dnn/torch/lpcnet/data')
-rw-r--r-- | dnn/torch/lpcnet/data/__init__.py | 1 | ||||
-rw-r--r-- | dnn/torch/lpcnet/data/lpcnet_dataset.py | 198 |
2 files changed, 199 insertions, 0 deletions
diff --git a/dnn/torch/lpcnet/data/__init__.py b/dnn/torch/lpcnet/data/__init__.py new file mode 100644 index 00000000..50bad871 --- /dev/null +++ b/dnn/torch/lpcnet/data/__init__.py @@ -0,0 +1 @@ +from .lpcnet_dataset import LPCNetDataset
""" Dataset for LPCNet training """
import os

import yaml
import torch
import numpy as np
from torch.utils.data import Dataset


# mu-law companding constants for mu = 255 over the 16-bit PCM range
scale = 255.0 / 32768.0
scale_1 = 32768.0 / 255.0


def ulaw2lin(u):
    """Expand 8-bit mu-law codes in [0, 255] to linear amplitudes (16-bit range)."""
    u = u - 128
    s = np.sign(u)
    u = np.abs(u)
    return s * scale_1 * (np.exp(u / 128. * np.log(256)) - 1)


def lin2ulaw(x):
    """Compress linear amplitudes (16-bit range) to 8-bit mu-law codes in [0, 255]."""
    s = np.sign(x)
    x = np.abs(x)
    u = s * (128 * np.log(1 + scale * x) / np.log(256))
    u = np.clip(128 + np.round(u), 0, 255)
    return u


def run_lpc(signal, lpcs, frame_length=160):
    """Apply frame-wise linear prediction to a signal.

    Args:
        signal: 1-d array whose first lpc_order samples serve as history for
            the first frame; assumes len(signal) >= num_frames * frame_length
            + lpc_order so that prediction and error line up (the v2 caller
            passes exactly lpc_order extra samples).
        lpcs: array of shape (num_frames, lpc_order), one coefficient set
            per frame.
        frame_length: samples per frame.

    Returns:
        Tuple (prediction, error) with error = signal[lpc_order:] - prediction.
    """
    num_frames, lpc_order = lpcs.shape

    # A 'valid' convolution of each frame (plus lpc_order - 1 history samples)
    # with the per-frame coefficient vector yields frame_length outputs; the
    # leading minus gives the conventional LPC prediction sign.
    prediction = np.concatenate(
        [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid')
         for i in range(num_frames)]
    )
    error = signal[lpc_order :] - prediction

    return prediction, error


class LPCNetDataset(Dataset):
    """Serves fixed-length feature/signal samples for LPCNet training.

    Reads a dataset folder containing an 'info.yml' descriptor plus memmapped
    feature and signal files. Each sample covers frames_per_sample frames of
    signals, with feature_history extra feature frames of past context and
    feature_lookahead frames of future context.
    """

    def __init__(self,
                 path_to_dataset,
                 features=None,
                 input_signals=None,
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):
        """Open the dataset folder and memmap its feature and signal files.

        Args:
            path_to_dataset: folder containing 'info.yml' and the data files.
            features: feature names to extract (defaults to
                ['cepstrum', 'periods', 'pitch_corr']).
            input_signals: signal names fed to the model (defaults to
                ['last_signal', 'prediction', 'last_error']).
            target: name of the signal used as training target.
            frames_per_sample: number of frames per training sample.
            feature_history: past feature context frames.
            feature_lookahead: future feature context frames.
            lpc_gamma: bandwidth-expansion factor applied to LPCs (v2 only).

        Raises:
            ValueError: if the dataset version is not 1 or 2.
        """
        super().__init__()

        # avoid mutable default arguments; None stands in for the defaults
        if features is None:
            features = ['cepstrum', 'periods', 'pitch_corr']
        if input_signals is None:
            input_signals = ['last_signal', 'prediction', 'last_error']

        # load dataset descriptor
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version selects the sample-extraction strategy
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history = feature_history
        self.feature_lookahead = feature_lookahead
        # skip one frame plus the history frames at the start of the stream
        self.frame_offset = 1 + self.feature_history
        self.frames_per_sample = frames_per_sample
        self.input_features = features
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma = lpc_gamma

        # memmap feature file and reshape to (num_frames, feature_frame_length)
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset
        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample

        # signals
        self.frame_length = dataset['frame_length']
        self.signal_frame_layout = dataset['signal_frame_layout']
        self.input_signals = input_signals
        self.target = target

        # memmap signal file and reshape to (num_samples, signal_frame_length)
        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length

    def __getitem__(self, index):
        # dispatch to the version-specific implementation chosen in __init__
        return self.getitem(index)

    def getitem_v2(self, index):
        """Extract one training sample (version-2 datasets).

        Computes prediction/error signals on the fly from stored LPCs when
        they are not stored explicitly, and mu-law encodes all signals.
        """
        sample = dict()

        # extract features (with history and lookahead context frames)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods to integer indices
        # NOTE(review): 0.1 looks like a rounding fudge before the int16 cast — confirm
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # last_signal and signal are always expected to be there
        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)

            # per-frame LPC coefficients with one extra leading frame
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC bandwidth-expansion weighting: a_k <- gamma^k * a_k
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            signal_start = frame_start * self.frame_length - lpc_order + 1
            signal_stop = frame_stop * self.frame_length + 1
            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # drop the extra leading frame again
            offset = self.frame_length
            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]

            # error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        # concatenate features; periods are kept separate as integer indices
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(lin2ulaw(sample[self.target]))
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def getitem_v1(self, index):
        """Extract one training sample (version-1 datasets).

        All signals are read directly from the signal file; no on-the-fly
        LPC computation and no mu-law conversion (signals are assumed to be
        stored in their final representation).
        """
        sample = dict()

        # extract features (with history and lookahead context frames)
        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead

        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods to integer indices
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length

        # read every stored signal column (loop variable renamed so it no
        # longer shadows the `index` parameter)
        for signal_name, column in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, column]

        # concatenate features; periods are kept separate as integer indices
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target = torch.LongTensor(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def __len__(self):
        return self.dataset_length