gitlab.xiph.org/xiph/opus.git
author     Jan Buethe <jbuethe@amazon.de>  2023-09-05 13:29:38 +0300
committer  Jan Buethe <jbuethe@amazon.de>  2023-09-05 13:29:38 +0300
commit     35ee397e060283d30c098ae5e17836316bbec08b
tree       4a81b86f8c0738bbdc7147214c53fda54cd0f3f3 /dnn/torch/lpcnet/data
parent     90a171c1c2c9839b561f8446ad2bbfe48eacf255

    added LPCNet torch implementation

    Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Diffstat (limited to 'dnn/torch/lpcnet/data'):
 dnn/torch/lpcnet/data/__init__.py       |   1 +
 dnn/torch/lpcnet/data/lpcnet_dataset.py | 198 +++++++++++++++++++++++++++++++
 2 files changed, 199 insertions(+), 0 deletions(-)
diff --git a/dnn/torch/lpcnet/data/__init__.py b/dnn/torch/lpcnet/data/__init__.py
new file mode 100644
index 00000000..50bad871
--- /dev/null
+++ b/dnn/torch/lpcnet/data/__init__.py
@@ -0,0 +1 @@
+from .lpcnet_dataset import LPCNetDataset
\ No newline at end of file
diff --git a/dnn/torch/lpcnet/data/lpcnet_dataset.py b/dnn/torch/lpcnet/data/lpcnet_dataset.py
new file mode 100644
index 00000000..e37fa385
--- /dev/null
+++ b/dnn/torch/lpcnet/data/lpcnet_dataset.py
@@ -0,0 +1,198 @@
+""" Dataset for LPCNet training """
+import os
+
+import yaml
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+
+
+scale = 255.0/32768.0
+scale_1 = 32768.0/255.0
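+# 8-bit mu-law companding as used by LPCNet: 16-bit PCM is mapped onto 256
+# levels via u = 128 * log(1 + (255/32768) * |x|) / log(256) (plus sign and a
+# 128 offset); ulaw2lin inverts lin2ulaw up to the 8-bit quantization error.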
+def ulaw2lin(u):
+    u = u - 128
+    s = np.sign(u)
+    u = np.abs(u)
+    return s * scale_1 * (np.exp(u / 128. * np.log(256)) - 1)
+
+
+def lin2ulaw(x):
+    s = np.sign(x)
+    x = np.abs(x)
+    u = s * (128 * np.log(1 + scale * x) / np.log(256))
+    u = np.clip(128 + np.round(u), 0, 255)
+    return u
+
+
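+# per-frame LPC analysis: each frame's prediction is the negated 'valid'
+# convolution of the frame (plus lpc_order - 1 trailing samples of context)
+# with that frame's coefficients, so the prediction for a sample depends on
+# the preceding lpc_order samples; error is the residual signal - prediction.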
+def run_lpc(signal, lpcs, frame_length=160):
+    num_frames, lpc_order = lpcs.shape
+
+    prediction = np.concatenate(
+        [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid')
+         for i in range(num_frames)]
+    )
+    error = signal[lpc_order :] - prediction
+
+    return prediction, error
+
+class LPCNetDataset(Dataset):
+    def __init__(self,
+                 path_to_dataset,
+                 features=['cepstrum', 'periods', 'pitch_corr'],
+                 input_signals=['last_signal', 'prediction', 'last_error'],
+                 target='error',
+                 frames_per_sample=15,
+                 feature_history=2,
+                 feature_lookahead=2,
+                 lpc_gamma=1):
+
+        super(LPCNetDataset, self).__init__()
+
+        # load dataset info
+        self.path_to_dataset = path_to_dataset
+        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
+            dataset = yaml.load(f, yaml.FullLoader)
+
+        # dataset version
+        self.version = dataset['version']
+        if self.version == 1:
+            self.getitem = self.getitem_v1
+        elif self.version == 2:
+            self.getitem = self.getitem_v2
+        else:
+            raise ValueError(f"dataset version {self.version} unknown")
+
+        # features
+        self.feature_history = feature_history
+        self.feature_lookahead = feature_lookahead
+        self.frame_offset = 1 + self.feature_history
+        self.frames_per_sample = frames_per_sample
+        self.input_features = features
+        self.feature_frame_layout = dataset['feature_frame_layout']
+        self.lpc_gamma = lpc_gamma
+
+        # load feature file
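+        # (np.memmap keeps the potentially large feature and signal files on
+        # disk and reads slices on demand instead of loading them into memory)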
+        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
+        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
+        self.feature_frame_length = dataset['feature_frame_length']
+
+        assert len(self.features) % self.feature_frame_length == 0
+        self.features = self.features.reshape((-1, self.feature_frame_length))
+
+        # derive number of samples in dataset
+        self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample
+
+        # signals
+        self.frame_length = dataset['frame_length']
+        self.signal_frame_layout = dataset['signal_frame_layout']
+        self.input_signals = input_signals
+        self.target = target
+
+        # load signals
+        self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
+        self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
+        self.signal_frame_length = dataset['signal_frame_length']
+        self.signals = self.signals.reshape((-1, self.signal_frame_length))
+        assert len(self.signals) == len(self.features) * self.frame_length
+
+    def __getitem__(self, index):
+        return self.getitem(index)
+
+    def getitem_v2(self, index):
+        sample = dict()
+
+        # extract features
+        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
+        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
+
+        for feature in self.input_features:
+            feature_start, feature_stop = self.feature_frame_layout[feature]
+            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
+
+        # convert periods
+        if 'periods' in self.input_features:
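+            # (assumption: this inverts a (period - 100) / 50 feature
+            # normalization; the 0.1 guards the truncation to int16)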
+            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
+
+        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
+        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
+
+        # last_signal and signal are always expected to be there
+        sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
+        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]
+
+        # calculate prediction and error if lpc coefficients present and prediction not given
+        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
+            # lpc coefficients with one frame lookahead
+            # frame positions (start one frame early for past excitation)
+            frame_start = self.frame_offset + self.frames_per_sample * index - 1
+            frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)
+
+            # feature positions
+            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
+            lpc_order = lpc_stop - lpc_start
+            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]
+
+            # LPC weighting
+            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
+            lpcs = lpcs * weights
+
+            # signal position (lpc_order samples as history)
+            signal_start = frame_start * self.frame_length - lpc_order + 1
+            signal_stop = frame_stop * self.frame_length + 1
+            noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
+            clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]
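+            # (note: clean_signal is not used below; the target error is
+            # computed from sample['signal'] further down)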
+
+            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)
+
+            # extract signals (offset drops the extra leading frame, which only
+            # provides history; 'offset - 1' makes last_error the one-sample-delayed error)
+            offset = self.frame_length
+            sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
+            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
+
+            # calculate error between real signal and noisy prediction
+            sample['error'] = sample['signal'] - sample['prediction']
+
+        # concatenate features
+        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
+        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
+        signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
+        target = torch.LongTensor(lin2ulaw(sample[self.target]))
+        periods = torch.LongTensor(sample['periods'])
+
+        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
+
+    def getitem_v1(self, index):
+        sample = dict()
+
+        # extract features
+        frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
+        frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
+
+        for feature in self.input_features:
+            feature_start, feature_stop = self.feature_frame_layout[feature]
+            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
+
+        # convert periods
+        if 'periods' in self.input_features:
+            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
+
+        signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
+        signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
+
+        # last_signal and signal are always expected to be there
+        for signal_name, signal_index in self.signal_frame_layout.items():
+            sample[signal_name] = self.signals[signal_start : signal_stop, signal_index]
+
+        # concatenate features
+        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
+        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
+        signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
+        target = torch.LongTensor(sample[self.target])
+        periods = torch.LongTensor(sample['periods'])
+
+        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
+
+    def __len__(self):
+        return self.dataset_length
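
A minimal usage sketch (not part of the commit): it wraps LPCNetDataset in a
standard PyTorch DataLoader, assuming it is run from dnn/torch/lpcnet and that
'my_dataset' is a placeholder for a directory containing the info.yml, feature
file and signal file the loader expects.

    from torch.utils.data import DataLoader

    from data import LPCNetDataset  # as exported by data/__init__.py

    # hypothetical dataset directory with info.yml plus feature/signal files
    dataset = LPCNetDataset('my_dataset', frames_per_sample=15)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

    batch = next(iter(loader))
    # each batch is a dict with keys 'features', 'periods', 'signals', 'target';
    # 'signals' stacks the u-law encoded input signals along the last dimension
    print({key: value.shape for key, value in batch.items()})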