Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-09-05 13:29:38 +0300
committerJan Buethe <jbuethe@amazon.de>2023-09-05 13:29:38 +0300
commit35ee397e060283d30c098ae5e17836316bbec08b (patch)
tree4a81b86f8c0738bbdc7147214c53fda54cd0f3f3
parent90a171c1c2c9839b561f8446ad2bbfe48eacf255 (diff)
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
-rw-r--r--dnn/torch/lpcnet/README.md27
-rw-r--r--dnn/torch/lpcnet/add_dataset_config.py48
-rw-r--r--dnn/torch/lpcnet/data/__init__.py1
-rw-r--r--dnn/torch/lpcnet/data/lpcnet_dataset.py198
-rw-r--r--dnn/torch/lpcnet/engine/lpcnet_engine.py112
-rw-r--r--dnn/torch/lpcnet/make_default_setup.py27
-rw-r--r--dnn/torch/lpcnet/make_test_config.py49
-rw-r--r--dnn/torch/lpcnet/models/__init__.py8
-rw-r--r--dnn/torch/lpcnet/models/lpcnet.py274
-rw-r--r--dnn/torch/lpcnet/models/multi_rate_lpcnet.py408
-rw-r--r--dnn/torch/lpcnet/print_lpcnet_complexity.py35
-rw-r--r--dnn/torch/lpcnet/scripts/collect_multi_run_results.py161
-rw-r--r--dnn/torch/lpcnet/scripts/loop_run.sh52
-rw-r--r--dnn/torch/lpcnet/scripts/make_animation.py37
-rw-r--r--dnn/torch/lpcnet/scripts/modify_dataset_target.py17
-rw-r--r--dnn/torch/lpcnet/scripts/multi_run.sh17
-rw-r--r--dnn/torch/lpcnet/scripts/run_inference_test.sh22
-rw-r--r--dnn/torch/lpcnet/scripts/update_checkpoints.py25
-rw-r--r--dnn/torch/lpcnet/scripts/update_output_folder.sh22
-rw-r--r--dnn/torch/lpcnet/scripts/update_setups.py28
-rw-r--r--dnn/torch/lpcnet/test_lpcnet.py60
-rw-r--r--dnn/torch/lpcnet/train_lpcnet.py243
-rw-r--r--dnn/torch/lpcnet/utils/__init__.py4
-rw-r--r--dnn/torch/lpcnet/utils/data.py112
-rw-r--r--dnn/torch/lpcnet/utils/endoscopy.py205
-rw-r--r--dnn/torch/lpcnet/utils/layers/__init__.py3
-rw-r--r--dnn/torch/lpcnet/utils/layers/dual_fc.py15
-rw-r--r--dnn/torch/lpcnet/utils/layers/pcm_embeddings.py42
-rw-r--r--dnn/torch/lpcnet/utils/layers/subconditioner.py468
-rw-r--r--dnn/torch/lpcnet/utils/misc.py36
-rw-r--r--dnn/torch/lpcnet/utils/pcm.py6
-rw-r--r--dnn/torch/lpcnet/utils/sample.py15
-rw-r--r--dnn/torch/lpcnet/utils/sparsification/__init__.py2
-rw-r--r--dnn/torch/lpcnet/utils/sparsification/common.py92
-rw-r--r--dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py158
-rw-r--r--dnn/torch/lpcnet/utils/templates.py128
-rw-r--r--dnn/torch/lpcnet/utils/ulaw.py29
-rw-r--r--dnn/torch/lpcnet/utils/wav.py14
38 files changed, 3200 insertions, 0 deletions
diff --git a/dnn/torch/lpcnet/README.md b/dnn/torch/lpcnet/README.md
new file mode 100644
index 00000000..26d9ea19
--- /dev/null
+++ b/dnn/torch/lpcnet/README.md
@@ -0,0 +1,27 @@
+# LPCNet
+
+Incomplete pytorch implementation of LPCNet
+
+## Data preparation
+For data preparation use dump_data in github.com/xiph/LPCNet. To turn this into
+a training dataset, copy data and feature file to a folder and run
+
+python add_dataset_config.py my_dataset_folder
+
+
+## Training
+To train a model, create and adjust a setup file, e.g. with
+
+python make_default_setup.py my_setup.yml --path2dataset my_dataset_folder
+
+Then simply run
+
+python train_lpcnet.py my_setup.yml my_output
+
+## Inference
+Create feature file with dump_data from github.com/xiph/LPCNet. Then run e.g.
+
+python test_lpcnet.py features.f32 my_output/checkpoints/checkpoint_ep_10.pth out.wav
+
+Inference runs on CPU and takes usually between 3 and 20 seconds per generated second of audio,
+depending on the CPU.
diff --git a/dnn/torch/lpcnet/add_dataset_config.py b/dnn/torch/lpcnet/add_dataset_config.py
new file mode 100644
index 00000000..2dba0030
--- /dev/null
+++ b/dnn/torch/lpcnet/add_dataset_config.py
@@ -0,0 +1,48 @@
+import argparse
+import os
+
+import yaml
+
+
+from utils.templates import dataset_template_v1, dataset_template_v2
+
+
+
+
+parser = argparse.ArgumentParser("add_dataset_config.py")
+
+parser.add_argument('path', type=str, help='path to folder containing feature and data file')
+parser.add_argument('--version', type=int, help="dataset version, 1 for classic LPCNet with 55 feature slots, 2 for new format with 36 feature slots.", default=2)
+parser.add_argument('--description', type=str, help='brief dataset description', default="I will add a description later")
+args = parser.parse_args()
+
+
+if args.version == 1:
+ template = dataset_template_v1
+ data_extension = '.u8'
+elif args.version == 2:
+ template = dataset_template_v2
+ data_extension = '.s16'
+else:
+ raise ValueError(f"unknown dataset version {args.version}")
+
+# get folder content
+content = os.listdir(args.path)
+
+features = [c for c in content if c.endswith('.f32')]
+
+if len(features) != 1:
+ print("could not determine feature file")
+else:
+ template['feature_file'] = features[0]
+
+data = [c for c in content if c.endswith(data_extension)]
+if len(data) != 1:
+ print("could not determine data file")
+else:
+ template['signal_file'] = data[0]
+
+template['description'] = args.description
+
+with open(os.path.join(args.path, 'info.yml'), 'w') as f:
+ yaml.dump(template, f)
diff --git a/dnn/torch/lpcnet/data/__init__.py b/dnn/torch/lpcnet/data/__init__.py
new file mode 100644
index 00000000..50bad871
--- /dev/null
+++ b/dnn/torch/lpcnet/data/__init__.py
@@ -0,0 +1 @@
+from .lpcnet_dataset import LPCNetDataset \ No newline at end of file
diff --git a/dnn/torch/lpcnet/data/lpcnet_dataset.py b/dnn/torch/lpcnet/data/lpcnet_dataset.py
new file mode 100644
index 00000000..e37fa385
--- /dev/null
+++ b/dnn/torch/lpcnet/data/lpcnet_dataset.py
@@ -0,0 +1,198 @@
+""" Dataset for LPCNet training """
+import os
+
+import yaml
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+
+
+scale = 255.0/32768.0
+scale_1 = 32768.0/255.0
+def ulaw2lin(u):
+ u = u - 128
+ s = np.sign(u)
+ u = np.abs(u)
+ return s*scale_1*(np.exp(u/128.*np.log(256))-1)
+
+
+def lin2ulaw(x):
+ s = np.sign(x)
+ x = np.abs(x)
+ u = (s*(128*np.log(1+scale*x)/np.log(256)))
+ u = np.clip(128 + np.round(u), 0, 255)
+ return u
+
+
+def run_lpc(signal, lpcs, frame_length=160):
+ num_frames, lpc_order = lpcs.shape
+
+ prediction = np.concatenate(
+ [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]
+ )
+ error = signal[lpc_order :] - prediction
+
+ return prediction, error
+
+class LPCNetDataset(Dataset):
+ def __init__(self,
+ path_to_dataset,
+ features=['cepstrum', 'periods', 'pitch_corr'],
+ input_signals=['last_signal', 'prediction', 'last_error'],
+ target='error',
+ frames_per_sample=15,
+ feature_history=2,
+ feature_lookahead=2,
+ lpc_gamma=1):
+
+ super(LPCNetDataset, self).__init__()
+
+ # load dataset info
+ self.path_to_dataset = path_to_dataset
+ with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
+ dataset = yaml.load(f, yaml.FullLoader)
+
+ # dataset version
+ self.version = dataset['version']
+ if self.version == 1:
+ self.getitem = self.getitem_v1
+ elif self.version == 2:
+ self.getitem = self.getitem_v2
+ else:
+ raise ValueError(f"dataset version {self.version} unknown")
+
+ # features
+ self.feature_history = feature_history
+ self.feature_lookahead = feature_lookahead
+ self.frame_offset = 1 + self.feature_history
+ self.frames_per_sample = frames_per_sample
+ self.input_features = features
+ self.feature_frame_layout = dataset['feature_frame_layout']
+ self.lpc_gamma = lpc_gamma
+
+ # load feature file
+ self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
+ self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
+ self.feature_frame_length = dataset['feature_frame_length']
+
+ assert len(self.features) % self.feature_frame_length == 0
+ self.features = self.features.reshape((-1, self.feature_frame_length))
+
+ # derive number of samples is dataset
+ self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1) // self.frames_per_sample
+
+ # signals
+ self.frame_length = dataset['frame_length']
+ self.signal_frame_layout = dataset['signal_frame_layout']
+ self.input_signals = input_signals
+ self.target = target
+
+ # load signals
+ self.signal_file = os.path.join(path_to_dataset, dataset['signal_file'])
+ self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
+ self.signal_frame_length = dataset['signal_frame_length']
+ self.signals = self.signals.reshape((-1, self.signal_frame_length))
+ assert len(self.signals) == len(self.features) * self.frame_length
+
+ def __getitem__(self, index):
+ return self.getitem(index)
+
+ def getitem_v2(self, index):
+ sample = dict()
+
+ # extract features
+ frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
+ frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
+
+ for feature in self.input_features:
+ feature_start, feature_stop = self.feature_frame_layout[feature]
+ sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
+
+ # convert periods
+ if 'periods' in self.input_features:
+ sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
+
+ signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
+ signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
+
+ # last_signal and signal are always expected to be there
+ sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
+ sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]
+
+ # calculate prediction and error if lpc coefficients present and prediction not given
+ if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
+ # lpc coefficients with one frame lookahead
+ # frame positions (start one frame early for past excitation)
+ frame_start = self.frame_offset + self.frames_per_sample * index - 1
+ frame_stop = self.frame_offset + self.frames_per_sample * (index + 1)
+
+ # feature positions
+ lpc_start, lpc_stop = self.feature_frame_layout['lpc']
+ lpc_order = lpc_stop - lpc_start
+ lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]
+
+ # LPC weighting
+ lpc_order = lpc_stop - lpc_start
+ weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
+ lpcs = lpcs * weights
+
+ # signal position (lpc_order samples as history)
+ signal_start = frame_start * self.frame_length - lpc_order + 1
+ signal_stop = frame_stop * self.frame_length + 1
+ noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']]
+ clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']]
+
+ noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)
+
+ # extract signals
+ offset = self.frame_length
+ sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample]
+ sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample]
+ # calculate error between real signal and noisy prediction
+
+
+ sample['error'] = sample['signal'] - sample['prediction']
+
+
+ # concatenate features
+ feature_keys = [key for key in self.input_features if not key.startswith("periods")]
+ features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
+ signals = torch.cat([torch.LongTensor(lin2ulaw(sample[key])).unsqueeze(-1) for key in self.input_signals], dim=-1)
+ target = torch.LongTensor(lin2ulaw(sample[self.target]))
+ periods = torch.LongTensor(sample['periods'])
+
+ return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
+
+ def getitem_v1(self, index):
+ sample = dict()
+
+ # extract features
+ frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history
+ frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead
+
+ for feature in self.input_features:
+ feature_start, feature_stop = self.feature_frame_layout[feature]
+ sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]
+
+ # convert periods
+ if 'periods' in self.input_features:
+ sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')
+
+ signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length
+ signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length
+
+ # last_signal and signal are always expected to be there
+ for signal_name, index in self.signal_frame_layout.items():
+ sample[signal_name] = self.signals[signal_start : signal_stop, index]
+
+ # concatenate features
+ feature_keys = [key for key in self.input_features if not key.startswith("periods")]
+ features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
+ signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
+ target = torch.LongTensor(sample[self.target])
+ periods = torch.LongTensor(sample['periods'])
+
+ return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}
+
+ def __len__(self):
+ return self.dataset_length
diff --git a/dnn/torch/lpcnet/engine/lpcnet_engine.py b/dnn/torch/lpcnet/engine/lpcnet_engine.py
new file mode 100644
index 00000000..d78c8266
--- /dev/null
+++ b/dnn/torch/lpcnet/engine/lpcnet_engine.py
@@ -0,0 +1,112 @@
+import torch
+from tqdm import tqdm
+import sys
+
+def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
+
+ model.to(device)
+ model.train()
+
+ running_loss = 0
+ previous_running_loss = 0
+
+ # gru states
+ gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device).to(device)
+ gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device).to(device)
+ gru_states = [gru_a_state, gru_b_state]
+
+ with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+ for i, batch in enumerate(tepoch):
+
+ # set gradients to zero
+ optimizer.zero_grad()
+
+ # zero out initial gru states
+ gru_a_state.zero_()
+ gru_b_state.zero_()
+
+ # push batch to device
+ for key in batch:
+ batch[key] = batch[key].to(device)
+
+ target = batch['target']
+
+ # calculate model output
+ output = model(batch['features'], batch['periods'], batch['signals'], gru_states)
+
+ # calculate loss
+ loss = criterion(output.permute(0, 2, 1), target)
+
+ # calculate gradients
+ loss.backward()
+
+ # update weights
+ optimizer.step()
+
+ # update learning rate
+ scheduler.step()
+
+ # call sparsifier
+ model.sparsify()
+
+ # update running loss
+ running_loss += float(loss.cpu())
+
+ # update status bar
+ if i % log_interval == 0:
+ tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
+ previous_running_loss = running_loss
+
+
+ running_loss /= len(dataloader)
+
+ return running_loss
+
+def evaluate(model, criterion, dataloader, device, log_interval=10):
+
+ model.to(device)
+ model.eval()
+
+ running_loss = 0
+ previous_running_loss = 0
+
+ # gru states
+ gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device).to(device)
+ gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device).to(device)
+ gru_states = [gru_a_state, gru_b_state]
+
+ with torch.no_grad():
+ with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:
+
+ for i, batch in enumerate(tepoch):
+
+
+ # zero out initial gru states
+ gru_a_state.zero_()
+ gru_b_state.zero_()
+
+ # push batch to device
+ for key in batch:
+ batch[key] = batch[key].to(device)
+
+ target = batch['target']
+
+ # calculate model output
+ output = model(batch['features'], batch['periods'], batch['signals'], gru_states)
+
+ # calculate loss
+ loss = criterion(output.permute(0, 2, 1), target)
+
+ # update running loss
+ running_loss += float(loss.cpu())
+
+ # update status bar
+ if i % log_interval == 0:
+ tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}")
+ previous_running_loss = running_loss
+
+
+ running_loss /= len(dataloader)
+
+ return running_loss \ No newline at end of file
diff --git a/dnn/torch/lpcnet/make_default_setup.py b/dnn/torch/lpcnet/make_default_setup.py
new file mode 100644
index 00000000..bfe18380
--- /dev/null
+++ b/dnn/torch/lpcnet/make_default_setup.py
@@ -0,0 +1,27 @@
+import argparse
+
+import yaml
+
+from utils.templates import setup_dict
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('name', type=str, help='name of default setup file')
+parser.add_argument('--model', choices=['lpcnet', 'multi_rate'], help='LPCNet model name', default='lpcnet')
+parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
+
+args = parser.parse_args()
+
+setup = setup_dict[args.model]
+
+# update dataset if given
+if type(args.path2dataset) != type(None):
+ setup['dataset'] = args.path2dataset
+
+name = args.name
+if not name.endswith('.yml'):
+ name += '.yml'
+
+if __name__ == '__main__':
+ with open(name, 'w') as f:
+ f.write(yaml.dump(setup))
diff --git a/dnn/torch/lpcnet/make_test_config.py b/dnn/torch/lpcnet/make_test_config.py
new file mode 100644
index 00000000..18a38ae0
--- /dev/null
+++ b/dnn/torch/lpcnet/make_test_config.py
@@ -0,0 +1,49 @@
+import argparse
+import os
+import sys
+
+parser = argparse.ArgumentParser()
+parser.add_argument("config_name", type=str, help="name of config file (.yml will be appended)")
+parser.add_argument("test_name", type=str, help="name for test result display")
+parser.add_argument("checkpoint", type=str, help="checkpoint to test")
+parser.add_argument("--lpcnet-demo", type=str, help="path to lpcnet_demo binary, default: /local/code/LPCNet/lpcnet_demo", default="/local/code/LPCNet/lpcnet_demo")
+parser.add_argument("--lpcnext-path", type=str, help="path to lpcnext folder, defalut: dirname(__file__)", default=os.path.dirname(__file__))
+parser.add_argument("--python-exe", type=str, help='python executable path, default: sys.executable', default=sys.executable)
+parser.add_argument("--pad", type=str, help="left pad of output in seconds, default: 0.015", default="0.015")
+parser.add_argument("--trim", type=str, help="left trim of output in seconds, default: 0", default="0")
+
+
+
+template='''
+test: "{NAME}"
+processing:
+ - "sox {{INPUT}} {{INPUT}}.raw"
+ - "{LPCNET_DEMO} -features {{INPUT}}.raw {{INPUT}}.features.f32"
+ - "{PYTHON} {WORKING}/test_lpcnet.py {{INPUT}}.features.f32 {CHECKPOINT} {{OUTPUT}}.ua.wav"
+ - "sox {{OUTPUT}}.ua.wav {{OUTPUT}}.uap.wav pad {PAD}"
+ - "sox {{OUTPUT}}.uap.wav {{OUTPUT}} trim {TRIM}"
+ - "rm {{INPUT}}.raw {{OUTPUT}}.uap.wav {{OUTPUT}}.ua.wav {{INPUT}}.features.f32"
+'''
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+
+ file_content = template.format(
+ NAME=args.test_name,
+ LPCNET_DEMO=os.path.abspath(args.lpcnet_demo),
+ PYTHON=os.path.abspath(args.python_exe),
+ PAD=args.pad,
+ TRIM=args.trim,
+ WORKING=os.path.abspath(args.lpcnext_path),
+ CHECKPOINT=os.path.abspath(args.checkpoint)
+ )
+
+ print(file_content)
+
+ filename = args.config_name
+ if not filename.endswith(".yml"):
+ filename += ".yml"
+
+ with open(filename, "w") as f:
+ f.write(file_content)
diff --git a/dnn/torch/lpcnet/models/__init__.py b/dnn/torch/lpcnet/models/__init__.py
new file mode 100644
index 00000000..a26bc1cd
--- /dev/null
+++ b/dnn/torch/lpcnet/models/__init__.py
@@ -0,0 +1,8 @@
+from .lpcnet import LPCNet
+from .multi_rate_lpcnet import MultiRateLPCNet
+
+
+model_dict = {
+ 'lpcnet' : LPCNet,
+ 'multi_rate' : MultiRateLPCNet
+} \ No newline at end of file
diff --git a/dnn/torch/lpcnet/models/lpcnet.py b/dnn/torch/lpcnet/models/lpcnet.py
new file mode 100644
index 00000000..e20ae68d
--- /dev/null
+++ b/dnn/torch/lpcnet/models/lpcnet.py
@@ -0,0 +1,274 @@
+import torch
+from torch import nn
+import numpy as np
+
+from utils.ulaw import lin2ulawq, ulaw2lin
+from utils.sample import sample_excitation
+from utils.pcm import clip_to_int16
+from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
+from utils.layers import DualFC
+from utils.misc import get_pdf_from_tree
+
+
+class LPCNet(nn.Module):
+ def __init__(self, config):
+ super(LPCNet, self).__init__()
+
+ #
+ self.input_layout = config['input_layout']
+ self.feature_history = config['feature_history']
+ self.feature_lookahead = config['feature_lookahead']
+
+ # frame rate network parameters
+ self.feature_dimension = config['feature_dimension']
+ self.period_embedding_dim = config['period_embedding_dim']
+ self.period_levels = config['period_levels']
+ self.feature_channels = self.feature_dimension + self.period_embedding_dim
+ self.feature_conditioning_dim = config['feature_conditioning_dim']
+ self.feature_conv_kernel_size = config['feature_conv_kernel_size']
+
+
+ # frame rate network layers
+ self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
+ self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
+ self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim]))
+
+ # sample rate network parameters
+ self.frame_size = config['frame_size']
+ self.signal_levels = config['signal_levels']
+ self.signal_embedding_dim = config['signal_embedding_dim']
+ self.gru_a_units = config['gru_a_units']
+ self.gru_b_units = config['gru_b_units']
+ self.output_levels = config['output_levels']
+ self.hsampling = config.get('hsampling', False)
+
+ self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
+ self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim
+
+ # sample rate network layers
+ self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
+ self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
+ self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
+ self.dual_fc = DualFC(self.gru_b_units, self.output_levels)
+
+ # sparsification
+ self.sparsifier = []
+
+ # GRU A
+ if 'gru_a' in config['sparsification']:
+ gru_config = config['sparsification']['gru_a']
+ task_list = [(self.gru_a, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
+ gru_config['params'], drop_input=True)
+ else:
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
+
+ # GRU B
+ if 'gru_b' in config['sparsification']:
+ gru_config = config['sparsification']['gru_b']
+ task_list = [(self.gru_b, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
+ gru_config['params'])
+ else:
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
+
+ # inference parameters
+ self.lpc_gamma = config.get('lpc_gamma', 1)
+
+ def sparsify(self):
+ for sparsifier in self.sparsifier:
+ sparsifier.step()
+
+ def get_gflops(self, fs, verbose=False):
+ gflops = 0
+
+ # frame rate network
+ conditioning_dim = self.feature_conditioning_dim
+ feature_channels = self.feature_channels
+ frame_rate = fs / self.frame_size
+ frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
+ if verbose:
+ print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
+ gflops += frame_rate_network_complexity
+
+ # gru a
+ gru_a_rate = fs
+ gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
+ if verbose:
+ print(f"gru A: {gru_a_complexity} GFLOPS")
+ gflops += gru_a_complexity
+
+ # gru b
+ gru_b_rate = fs
+ gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
+ if verbose:
+ print(f"gru B: {gru_b_complexity} GFLOPS")
+ gflops += gru_b_complexity
+
+
+ # dual fcs
+ fc = self.dual_fc
+ rate = fs
+ input_size = fc.dense1.in_features
+ output_size = fc.dense1.out_features
+ dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
+ if self.hsampling:
+ dual_fc_complexity /= 8
+ if verbose:
+ print(f"dual_fc: {dual_fc_complexity} GFLOPS")
+ gflops += dual_fc_complexity
+
+ if verbose:
+ print(f'total: {gflops} GFLOPS')
+
+ return gflops
+
+ def frame_rate_network(self, features, periods):
+
+ embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
+ features = torch.concat((features, embedded_periods), dim=-1)
+
+ # convert to channels first and calculate conditioning vector
+ c = torch.permute(features, [0, 2, 1])
+
+ c = torch.tanh(self.feature_conv1(c))
+ c = torch.tanh(self.feature_conv2(c))
+ # back to channels last
+ c = torch.permute(c, [0, 2, 1])
+ c = torch.tanh(self.feature_dense1(c))
+ c = torch.tanh(self.feature_dense2(c))
+
+ return c
+
+ def sample_rate_network(self, signals, c, gru_states):
+ embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
+ c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)
+
+ y = torch.concat((embedded_signals, c_upsampled), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ y = torch.concat((y, c_upsampled), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+
+ y = self.dual_fc(y)
+
+ if self.hsampling:
+ y = torch.sigmoid(y)
+ log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
+ else:
+ log_probs = torch.log_softmax(y, dim=-1)
+
+ return log_probs, (gru_a_state, gru_b_state)
+
+ def decoder(self, signals, c, gru_states):
+ embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
+
+ y = torch.concat((embedded_signals, c), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ y = torch.concat((y, c), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+
+ y = self.dual_fc(y)
+
+ if self.hsampling:
+ y = torch.sigmoid(y)
+ probs = get_pdf_from_tree(y)
+ else:
+ probs = torch.softmax(y, dim=-1)
+
+ return probs, (gru_a_state, gru_b_state)
+
+ def forward(self, features, periods, signals, gru_states):
+
+ c = self.frame_rate_network(features, periods)
+ log_probs, _ = self.sample_rate_network(signals, c, gru_states)
+
+ return log_probs
+
+ def generate(self, features, periods, lpcs):
+
+ with torch.no_grad():
+ device = self.parameters().__next__().device
+
+ num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
+ lpc_order = lpcs.shape[-1]
+ num_input_signals = len(self.input_layout['signals'])
+ pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
+
+ # signal buffers
+ pcm = torch.zeros((num_frames * self.frame_size + lpc_order))
+ output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
+ mem = 0
+
+ # state buffers
+ gru_a_state = torch.zeros((1, 1, self.gru_a_units))
+ gru_b_state = torch.zeros((1, 1, self.gru_b_units))
+ gru_states = [gru_a_state, gru_b_state]
+
+ input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128
+
+ # push data to device
+ features = features.to(device)
+ periods = periods.to(device)
+ lpcs = lpcs.to(device)
+
+ # lpc weighting
+ weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
+ lpcs = lpcs * weights
+
+ # run feature encoding
+ c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
+
+ for frame_index in range(num_frames):
+ frame_start = frame_index * self.frame_size
+ pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
+ a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
+ current_c = c[:, frame_index : frame_index + 1, :]
+
+ for i in range(self.frame_size):
+ pcm_position = frame_start + i + lpc_order
+ output_position = frame_start + i
+
+ # prepare input
+ pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
+ if 'prediction' in self.input_layout['signals']:
+ input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)
+
+ # run single step of sample rate network
+ probs, gru_states = self.decoder(
+ input_signals,
+ current_c,
+ gru_states
+ )
+
+ # sample from output
+ exc_ulaw = sample_excitation(probs, pitch_corr)
+
+ # signal generation
+ exc = ulaw2lin(exc_ulaw)
+ sig = exc + pred
+ pcm[pcm_position] = sig
+ mem = 0.85 * mem + float(sig)
+ output[output_position] = clip_to_int16(round(mem))
+
+ # buffer update
+ if 'last_signal' in self.input_layout['signals']:
+ input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)
+
+ if 'last_error' in self.input_layout['signals']:
+ input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)
+
+ return output
diff --git a/dnn/torch/lpcnet/models/multi_rate_lpcnet.py b/dnn/torch/lpcnet/models/multi_rate_lpcnet.py
new file mode 100644
index 00000000..c6850101
--- /dev/null
+++ b/dnn/torch/lpcnet/models/multi_rate_lpcnet.py
@@ -0,0 +1,408 @@
+import torch
+from torch import nn
+from utils.layers.subconditioner import get_subconditioner
+from utils.layers import DualFC
+
+from utils.ulaw import lin2ulawq, ulaw2lin
+from utils.sample import sample_excitation
+from utils.pcm import clip_to_int16
+from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
+
+from utils.misc import interleave_tensors
+
+
+
+
+# MultiRateLPCNet
+class MultiRateLPCNet(nn.Module):
+ def __init__(self, config):
+ super(MultiRateLPCNet, self).__init__()
+
+ # general parameters
+ self.input_layout = config['input_layout']
+ self.feature_history = config['feature_history']
+ self.feature_lookahead = config['feature_lookahead']
+ self.signals = config['signals']
+
+ # frame rate network parameters
+ self.feature_dimension = config['feature_dimension']
+ self.period_embedding_dim = config['period_embedding_dim']
+ self.period_levels = config['period_levels']
+ self.feature_channels = self.feature_dimension + self.period_embedding_dim
+ self.feature_conditioning_dim = config['feature_conditioning_dim']
+ self.feature_conv_kernel_size = config['feature_conv_kernel_size']
+
+ # frame rate network layers
+ self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
+ self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
+ self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim]))
+
+ # sample rate network parameters
+ self.frame_size = config['frame_size']
+ self.signal_levels = config['signal_levels']
+ self.signal_embedding_dim = config['signal_embedding_dim']
+ self.gru_a_units = config['gru_a_units']
+ self.gru_b_units = config['gru_b_units']
+ self.output_levels = config['output_levels']
+
+ # subconditioning B
+ sub_config = config['subconditioning']['subconditioning_b']
+ self.substeps_b = sub_config['number_of_subsamples']
+ self.subcondition_signals_b = sub_config['signals']
+ self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
+ method = sub_config['method']
+ kwargs = sub_config['kwargs']
+ if type(kwargs) == type(None):
+ kwargs = dict()
+
+ state_size = self.gru_b_units
+ self.subconditioner_b = get_subconditioner(method,
+ sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
+ state_size, self.signal_levels, len(sub_config['signals']),
+ **sub_config['kwargs'])
+
+ # subconditioning A
+ sub_config = config['subconditioning']['subconditioning_a']
+ self.substeps_a = sub_config['number_of_subsamples']
+ self.subcondition_signals_a = sub_config['signals']
+ self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
+ method = sub_config['method']
+ kwargs = sub_config['kwargs']
+ if type(kwargs) == type(None):
+ kwargs = dict()
+
+ state_size = self.gru_a_units
+ self.subconditioner_a = get_subconditioner(method,
+ sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
+ state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
+ **sub_config['kwargs'])
+
+
+ # wrap up subconditioning, group_size_gru_a holds the number
+ # of timesteps that are grouped as sample input for GRU A
+ # input and group_size_subcondition_a holds the number of samples that are
+ # grouped as input to pre-GRU B subconditioning
+ self.group_size_gru_a = self.substeps_a * self.substeps_b
+ self.group_size_subcondition_a = self.substeps_b
+ self.gru_a_rate_divider = self.group_size_gru_a
+ self.gru_b_rate_divider = self.substeps_b
+
+ # gru sizes
+ self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
+ self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
+ self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]
+
+ # sample rate network layers
+ self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
+ self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
+ self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
+
+ # sparsification
+ self.sparsifier = []
+
+ # GRU A
+ if 'gru_a' in config['sparsification']:
+ gru_config = config['sparsification']['gru_a']
+ task_list = [(self.gru_a, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
+ gru_config['params'], drop_input=True)
+ else:
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
+
+ # GRU B
+ if 'gru_b' in config['sparsification']:
+ gru_config = config['sparsification']['gru_b']
+ task_list = [(self.gru_b, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
+ gru_config['params'])
+ else:
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
+
+
+
+ # dual FCs
+ self.dual_fc = []
+ for i in range(self.substeps_b):
+ dim = self.subconditioner_b.get_output_dim(i)
+ self.dual_fc.append(DualFC(dim, self.output_levels))
+ self.add_module(f"dual_fc_{i}", self.dual_fc[-1])
+
+ def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
+ gflops = 0
+
+ # frame rate network
+ conditioning_dim = self.feature_conditioning_dim
+ feature_channels = self.feature_channels
+ frame_rate = fs / self.frame_size
+ frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
+ if verbose:
+ print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
+ gflops += frame_rate_network_complexity
+
+ # gru a
+ gru_a_rate = fs / self.group_size_gru_a
+ gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
+ if verbose:
+ print(f"gru A: {gru_a_complexity} GFLOPS")
+ gflops += gru_a_complexity
+
+ # subconditioning a
+ subcond_a_rate = fs / self.substeps_b
+ subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
+ if verbose:
+ print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
+ gflops += subconditioning_a_complexity
+
+ # gru b
+ gru_b_rate = fs / self.substeps_b
+ gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
+ if verbose:
+ print(f"gru B: {gru_b_complexity} GFLOPS")
+ gflops += gru_b_complexity
+
+ # subconditioning b
+ subcond_b_rate = fs
+ subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
+ if verbose:
+ print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
+ gflops += subconditioning_b_complexity
+
+ # dual fcs
+ for i, fc in enumerate(self.dual_fc):
+ rate = fs / len(self.dual_fc)
+ input_size = fc.dense1.in_features
+ output_size = fc.dense1.out_features
+ dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
+ if hierarchical_sampling:
+ dual_fc_complexity /= 8
+ if verbose:
+ print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
+ gflops += dual_fc_complexity
+
+ if verbose:
+ print(f'total: {gflops} GFLOPS')
+
+ return gflops
+
+
+
+ def sparsify(self):
+ for sparsifier in self.sparsifier:
+ sparsifier.step()
+
+ def frame_rate_network(self, features, periods):
+
+ embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
+ features = torch.concat((features, embedded_periods), dim=-1)
+
+ # convert to channels first and calculate conditioning vector
+ c = torch.permute(features, [0, 2, 1])
+
+ c = torch.tanh(self.feature_conv1(c))
+ c = torch.tanh(self.feature_conv2(c))
+ # back to channels last
+ c = torch.permute(c, [0, 2, 1])
+ c = torch.tanh(self.feature_dense1(c))
+ c = torch.tanh(self.feature_dense2(c))
+
+ return c
+
+ def prepare_signals(self, signals, group_size, signal_idx):
+ """ extracts, delays and groups signals """
+
+ batch_size, sequence_length, num_signals = signals.shape
+
+ # extract signals according to position
+ signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
+ dim=-1)
+
+ # roll back pcm to account for grouping
+ signals = torch.roll(signals, group_size - 1, -2)
+
+ # reshape
+ signals = torch.reshape(signals,
+ (batch_size, sequence_length // group_size, group_size * len(signal_idx)))
+
+ return signals
+
+
+ def sample_rate_network(self, signals, c, gru_states):
+
+ signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
+ embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
+ # features at GRU A rate
+ c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
+ # features at GRU B rate
+ c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)
+
+ y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ # first round of upsampling and subconditioning
+ c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
+ y = self.subconditioner_a(y, c_signals_a)
+ y = interleave_tensors(y)
+
+ y = torch.concat((y, c_upsampled_b), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+ c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
+ y = self.subconditioner_b(y, c_signals_b)
+
+ y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
+ y = interleave_tensors(y)
+
+ return y, (gru_a_state, gru_b_state)
+
+ def decoder(self, signals, c, gru_states):
+ embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
+
+ y = torch.concat((embedded_signals, c), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ y = torch.concat((y, c), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+
+ y = self.dual_fc(y)
+
+ return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)
+
+ def forward(self, features, periods, signals, gru_states):
+
+ c = self.frame_rate_network(features, periods)
+ y, _ = self.sample_rate_network(signals, c, gru_states)
+ log_probs = torch.log_softmax(y, dim=-1)
+
+ return log_probs
+
+ def generate(self, features, periods, lpcs):
+
+ with torch.no_grad():
+ device = self.parameters().__next__().device
+
+ num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
+ lpc_order = lpcs.shape[-1]
+ num_input_signals = len(self.signals)
+ pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
+
+ # signal buffers
+ last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
+ prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
+ last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
+ output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
+ mem = 0
+
+ # state buffers
+ gru_a_state = torch.zeros((1, 1, self.gru_a_units))
+ gru_b_state = torch.zeros((1, 1, self.gru_b_units))
+
+ input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
+ # conditioning signals for subconditioner a
+ c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
+ # conditioning signals for subconditioner b
+ c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)
+
+ # signal dict
+ signal_dict = {
+ 'prediction' : prediction,
+ 'last_error' : last_error,
+ 'last_signal' : last_signal
+ }
+
+ # push data to device
+ features = features.to(device)
+ periods = periods.to(device)
+ lpcs = lpcs.to(device)
+
+ # run feature encoding
+ c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
+
+ for frame_index in range(num_frames):
+ frame_start = frame_index * self.frame_size
+ pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
+ a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
+ current_c = c[:, frame_index : frame_index + 1, :]
+
+ for i in range(0, self.frame_size, self.group_size_gru_a):
+ pcm_position = frame_start + i + lpc_order
+ output_position = frame_start + i
+
+ # calculate newest prediction
+ prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
+
+ # prepare input
+ for slot in range(self.group_size_gru_a):
+ k = slot - self.group_size_gru_a + 1
+ for idx, name in enumerate(self.signals):
+ input_signals[idx + slot * num_input_signals] = lin2ulawq(
+ signal_dict[name][pcm_position + k]
+ )
+
+
+ # run GRU A
+ embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
+ embed_signals = torch.flatten(embed_signals, 2)
+ y = torch.cat((embed_signals, current_c), dim=-1)
+ h_a, gru_a_state = self.gru_a(y, gru_a_state)
+
+ # loop over substeps_a
+ for step_a in range(self.substeps_a):
+ # prepare conditioning input
+ for slot in range(self.group_size_subcondition_a):
+ k = slot - self.group_size_subcondition_a + 1
+ for idx, name in enumerate(self.subcondition_signals_a):
+ c_signals_a[idx + slot * num_input_signals] = lin2ulawq(
+ signal_dict[name][pcm_position + k]
+ )
+
+ # subconditioning
+ h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))
+
+ # run GRU B
+ y = torch.cat((h_a, current_c), dim=-1)
+ h_b, gru_b_state = self.gru_b(y, gru_b_state)
+
+ # loop over substeps b
+ for step_b in range(self.substeps_b):
+ # prepare subconditioning input
+ for idx, name in enumerate(self.subcondition_signals_b):
+ c_signals_b[idx] = lin2ulawq(
+ signal_dict[name][pcm_position]
+ )
+
+ # subcondition
+ h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))
+
+ # run dual FC
+ probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)
+
+ # sample
+ new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))
+
+ # update signals
+ sig = new_exc + prediction[pcm_position]
+ last_error[pcm_position + 1] = new_exc
+ last_signal[pcm_position + 1] = sig
+
+ mem = 0.85 * mem + float(sig)
+ output[output_position] = clip_to_int16(round(mem))
+
+ # increase positions
+ pcm_position += 1
+ output_position += 1
+
+ # calculate next prediction
+ prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
+
+ return output
diff --git a/dnn/torch/lpcnet/print_lpcnet_complexity.py b/dnn/torch/lpcnet/print_lpcnet_complexity.py
new file mode 100644
index 00000000..a47352be
--- /dev/null
+++ b/dnn/torch/lpcnet/print_lpcnet_complexity.py
@@ -0,0 +1,35 @@
+import argparse
+
+import yaml
+
+from models import model_dict
+
+
+debug = False
+if debug:
+ args = type('dummy', (object,),
+ {
+ 'setup' : 'setups/lpcnet_m/setup_1_4_concatenative.yml',
+ 'hierarchical_sampling' : False
+ })()
+else:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('setup', type=str, help='setup yaml file')
+ parser.add_argument('--hierarchical-sampling', action="store_true", help='whether to assume hierarchical sampling (default=False)', default=False)
+
+ args = parser.parse_args()
+
+with open(args.setup, 'r') as f:
+ setup = yaml.load(f.read(), yaml.FullLoader)
+
+# check model
+if not 'model' in setup['lpcnet']:
+ print(f'warning: did not find model entry in setup, using default lpcnet')
+ model_name = 'lpcnet'
+else:
+ model_name = setup['lpcnet']['model']
+
+# create model
+model = model_dict[model_name](setup['lpcnet']['config'])
+
+gflops = model.get_gflops(16000, verbose=True, hierarchical_sampling=args.hierarchical_sampling)
diff --git a/dnn/torch/lpcnet/scripts/collect_multi_run_results.py b/dnn/torch/lpcnet/scripts/collect_multi_run_results.py
new file mode 100644
index 00000000..b772038d
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/collect_multi_run_results.py
@@ -0,0 +1,161 @@
+import argparse
+import os
+from uuid import UUID
+from collections import OrderedDict
+import pickle
+
+
+import torch
+import numpy as np
+
+import utils
+
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("input", type=str, help="input folder containing multi-run output")
+parser.add_argument("tag", type=str, help="tag for multi-run experiment")
+parser.add_argument("csv", type=str, help="name for output csv")
+
+
+def is_uuid(val):
+ try:
+ UUID(val)
+ return True
+ except:
+ return False
+
+
+def collect_results(folder):
+
+ training_folder = os.path.join(folder, 'training')
+ testing_folder = os.path.join(folder, 'testing')
+
+ # validation loss
+ checkpoint = torch.load(os.path.join(training_folder, 'checkpoints', 'checkpoint_finalize_epoch_1.pth'), map_location='cpu')
+ validation_loss = checkpoint['validation_loss']
+
+ # eval_warpq
+ eval_warpq = utils.data.parse_warpq_scores(os.path.join(training_folder, 'out_finalize.txt'))[-1]
+
+ # testing results
+ testing_results = utils.data.collect_test_stats(os.path.join(testing_folder, 'final'))
+
+ results = OrderedDict()
+ results['eval_loss'] = validation_loss
+ results['eval_warpq'] = eval_warpq
+ results['pesq_mean'] = testing_results['pesq'][0]
+ results['warpq_mean'] = testing_results['warpq'][0]
+ results['pitch_error_mean'] = testing_results['pitch_error'][0]
+ results['voicing_error_mean'] = testing_results['voicing_error'][0]
+
+ return results
+
+def print_csv(path, results, tag, ranks=None, header=True):
+
+ metrics = next(iter(results.values())).keys()
+ if ranks is not None:
+ rank_keys = next(iter(ranks.values())).keys()
+ else:
+ rank_keys = []
+
+ with open(path, 'w') as f:
+ if header:
+ f.write("uuid, tag")
+
+ for metric in metrics:
+ f.write(f", {metric}")
+
+ for rank in rank_keys:
+ f.write(f", {rank}")
+
+ f.write("\n")
+
+
+ for uuid, values in results.items():
+ f.write(f"{uuid}, {tag}")
+
+ for val in values.values():
+ f.write(f", {val:10.8f}")
+
+ for rank in rank_keys:
+ f.write(f", {ranks[uuid][rank]:4d}")
+
+ f.write("\n")
+
+def get_ranks(results):
+
+ metrics = list(next(iter(results.values())).keys())
+
+ positive = {'pesq_mean', 'mix'}
+
+ ranks = OrderedDict()
+ for key in results.keys():
+ ranks[key] = OrderedDict()
+
+ for metric in metrics:
+ sign = -1 if metric in positive else 1
+
+ x = sorted([(key, value[metric]) for key, value in results.items()], key=lambda x: sign * x[1])
+ x = [y[0] for y in x]
+
+ for key in results.keys():
+ ranks[key]['rank_' + metric] = x.index(key) + 1
+
+ return ranks
+
+def analyse_metrics(results):
+ metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
+
+ x = []
+ for metric in metrics:
+ x.append([val[metric] for val in results.values()])
+
+ x = np.array(x)
+
+ print(x)
+
+def add_mix_metric(results):
+ metrics = ['eval_loss', 'pesq_mean', 'warpq_mean', 'pitch_error_mean', 'voicing_error_mean']
+
+ x = []
+ for metric in metrics:
+ x.append([val[metric] for val in results.values()])
+
+ x = np.array(x).transpose() * np.array([-1, 1, -1, -1, -1])
+
+ z = (x - np.mean(x, axis=0)) / np.std(x, axis=0)
+
+ print(f"covariance matrix for normalized scores of {metrics}:")
+ print(np.cov(z.transpose()))
+
+ score = np.mean(z, axis=1)
+
+ for i, key in enumerate(results.keys()):
+ results[key]['mix'] = score[i].item()
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ uuids = sorted([x for x in os.listdir(args.input) if os.path.isdir(os.path.join(args.input, x)) and is_uuid(x)])
+
+
+ results = OrderedDict()
+
+ for uuid in uuids:
+ results[uuid] = collect_results(os.path.join(args.input, uuid))
+
+
+ add_mix_metric(results)
+
+ ranks = get_ranks(results)
+
+
+
+ csv = args.csv if args.csv.endswith('.csv') else args.csv + '.csv'
+
+ print_csv(args.csv, results, args.tag, ranks=ranks)
+
+
+ with open(csv[:-4] + '.pickle', 'wb') as f:
+ pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file
diff --git a/dnn/torch/lpcnet/scripts/loop_run.sh b/dnn/torch/lpcnet/scripts/loop_run.sh
new file mode 100644
index 00000000..7250f639
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/loop_run.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+
+case $# in
+ 9) SETUP=$1; OUTDIR=$2; NAME=$3; DEVICE=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
+ *) echo "loop_run.sh setup outdir name device rounds lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
+esac
+
+
+PYTHON="/home/ubuntu/opt/miniconda3/envs/torch/bin/python"
+TESTFEATURES=${LPCNEXT}/testitems/features/all_0_orig_features.f32
+WARPQREFERENCE=${LPCNEXT}/testitems/wav/all_0_orig.wav
+METRICS="warpq,pesq,pitch_error,voicing_error"
+LPCNETDEMO=${LPCNET}/lpcnet_demo
+
+for ((round = 1; round <= $ROUNDS; round++))
+do
+ echo
+ echo round $round
+
+ UUID=$(uuidgen)
+ TRAINOUT=${OUTDIR}/${UUID}/training
+ TESTOUT=${OUTDIR}/${UUID}/testing
+ CHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_last.pth
+ FINALCHECKPOINT=${TRAINOUT}/checkpoints/checkpoint_finalize_last.pth
+
+ # run training
+ echo "starting training..."
+ $PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT --device $DEVICE --test-features $TESTFEATURES --warpq-reference $WARPQREFERENCE
+
+ # run finalization
+ echo "starting finalization..."
+ $PYTHON $LPCNEXT/train_lpcnet.py $SETUP $TRAINOUT \
+ --device $DEVICE --test-features $TESTFEATURES \
+ --warpq-reference $WARPQREFERENCE \
+ --finalize --initial-checkpoint $CHECKPOINT
+
+ # create test configs
+ $PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig.yml "$NAME $UUID" $CHECKPOINT --lpcnet-demo $LPCNETDEMO
+ $PYTHON $LPCNEXT/make_test_config.py ${OUTDIR}/${UUID}/testconfig_finalize.yml "$NAME $UUID finalized" $FINALCHECKPOINT --lpcnet-demo $LPCNETDEMO
+
+ # run tests
+ echo "starting test 1 (no finalization)..."
+ $PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig.yml \
+ $TESTITEMS ${TESTOUT}/prefinal --num-workers 8 \
+ --num-testitems 400 --metrics $METRICS
+
+ echo "starting test 2 (after finalization)..."
+ $PYTHON $TESTSUITE/run_test.py ${OUTDIR}/${UUID}/testconfig_finalize.yml \
+ $TESTITEMS ${TESTOUT}/final --num-workers 8 \
+ --num-testitems 400 --metrics $METRICS
+done
diff --git a/dnn/torch/lpcnet/scripts/make_animation.py b/dnn/torch/lpcnet/scripts/make_animation.py
new file mode 100644
index 00000000..a6e55472
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/make_animation.py
@@ -0,0 +1,37 @@
+""" script for creating animations from debug data
+
+"""
+
+
+import argparse
+
+
+import sys
+sys.path.append('./')
+
+from utils.endoscopy import make_animation, read_data
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('folder', type=str, help='endoscopy folder with debug output')
+parser.add_argument('output', type=str, help='output file (will be auto-extended with .mp4)')
+
+parser.add_argument('--start-index', type=int, help='index of first sample to be considered', default=0)
+parser.add_argument('--stop-index', type=int, help='index of last sample to be considered', default=-1)
+parser.add_argument('--interval', type=int, help='interval between frames in ms', default=20)
+parser.add_argument('--half-window-length', type=int, help='half size of window for displaying signals', default=80)
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ filename = args.output if args.output.endswith('.mp4') else args.output + '.mp4'
+ data = read_data(args.folder)
+
+ make_animation(
+ data,
+ filename,
+ start_index=args.start_index,
+ stop_index = args.stop_index,
+ half_signal_window_length=args.half_window_length
+ )
diff --git a/dnn/torch/lpcnet/scripts/modify_dataset_target.py b/dnn/torch/lpcnet/scripts/modify_dataset_target.py
new file mode 100644
index 00000000..a70fe169
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/modify_dataset_target.py
@@ -0,0 +1,17 @@
+import argparse
+
+import numpy as np
+
+
+parser = argparse.ArgumentParser(description="sets s_t to augmented_s_t")
+
+parser.add_argument('datafile', type=str, help='data.s16 file path')
+
+args = parser.parse_args()
+
+data = np.memmap(args.datafile, dtype='int16', mode='readwrite')
+
+# signal is in data[1::2]
+# last augmented signal is in data[0::2]
+
+data[1 : - 1 : 2] = data[2 : : 2]
diff --git a/dnn/torch/lpcnet/scripts/multi_run.sh b/dnn/torch/lpcnet/scripts/multi_run.sh
new file mode 100644
index 00000000..fb0fee14
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/multi_run.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+case $# in
+ 9) SETUP=$1; OUTDIR=$2; NAME=$3; NUMDEVICES=$4; ROUNDS=$5; LPCNEXT=$6; LPCNET=$7; TESTSUITE=$8; TESTITEMS=$9;;
+ *) echo "multi_run.sh setup outdir name num_devices rounds_per_device lpcnext_repo lpcnet_repo testsuite_repo testitems"; exit;;
+esac
+
+
+LOOPRUN=${LPCNEXT}/loop_run.sh
+
+mkdir -p $OUTDIR
+
+for ((i = 0; i < $NUMDEVICES; i++))
+do
+ echo "launching job queue for device $i"
+ nohup bash $LOOPRUN $SETUP $OUTDIR "$NAME" "cuda:$i" $ROUNDS $LPCNEXT $LPCNET $TESTSUITE $TESTITEMS > $OUTDIR/job_${i}_out.txt &
+done
diff --git a/dnn/torch/lpcnet/scripts/run_inference_test.sh b/dnn/torch/lpcnet/scripts/run_inference_test.sh
new file mode 100644
index 00000000..9f22b03d
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/run_inference_test.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+
+case $# in
+ 3) FEATURES=$1; FOLDER=$2; PYTHON=$3;;
+ *) echo "run_inference_test.sh <features file> <output folder> <python path>"; exit;;
+esac
+
+
+SCRIPTFOLDER=$(dirname "$0")
+
+mkdir -p $FOLDER/inference_test
+
+# update checkpoints
+for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
+do
+ tmp=$(basename $fn)
+ tmp=${tmp%.pth}
+ epoch=${tmp#checkpoint_epoch_}
+ echo "running inference with checkpoint $fn..."
+ $PYTHON $SCRIPTFOLDER/../test_lpcnet.py $FEATURES $fn $FOLDER/inference_test/output_epoch_${epoch}.wav
+done
diff --git a/dnn/torch/lpcnet/scripts/update_checkpoints.py b/dnn/torch/lpcnet/scripts/update_checkpoints.py
new file mode 100644
index 00000000..989f6164
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/update_checkpoints.py
@@ -0,0 +1,25 @@
+""" script for updating checkpoints with new setup entries
+
+ Use this script to update older outputs with newly introduced
+ parameters. (Saves us the trouble of backward compatibility)
+"""
+
+
+import argparse
+
+import torch
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint_file', type=str, help='checkpoint to be updated')
+parser.add_argument('--model', type=str, help='model update', default=None)
+
+args = parser.parse_args()
+
+checkpoint = torch.load(args.checkpoint_file, map_location='cpu')
+
+# update model entry
+if type(args.model) != type(None):
+ checkpoint['setup']['lpcnet']['model'] = args.model
+
+torch.save(checkpoint, args.checkpoint_file) \ No newline at end of file
diff --git a/dnn/torch/lpcnet/scripts/update_output_folder.sh b/dnn/torch/lpcnet/scripts/update_output_folder.sh
new file mode 100644
index 00000000..487d4a2d
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/update_output_folder.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+
+case $# in
+ 3) FOLDER=$1; MODEL=$2; PYTHON=$3;;
+ *) echo "update_output_folder.sh folder model python"; exit;;
+esac
+
+
+SCRIPTFOLDER=$(dirname "$0")
+
+
+# update setup
+echo "updating $FOLDER/setup.py..."
+$PYTHON $SCRIPTFOLDER/update_setups.py $FOLDER/setup.yml --model $MODEL
+
+# update checkpoints
+for fn in $(find $FOLDER -type f -name "checkpoint*.pth")
+do
+ echo "updating $fn..."
+ $PYTHON $SCRIPTFOLDER/update_checkpoints.py $fn --model $MODEL
+done \ No newline at end of file
diff --git a/dnn/torch/lpcnet/scripts/update_setups.py b/dnn/torch/lpcnet/scripts/update_setups.py
new file mode 100644
index 00000000..7f8261a0
--- /dev/null
+++ b/dnn/torch/lpcnet/scripts/update_setups.py
@@ -0,0 +1,28 @@
+""" script for updating setup files with new setup entries
+
+ Use this script to update older outputs with newly introduced
+ parameters. (Saves us the trouble of backward compatibility)
+"""
+
+import argparse
+
+import yaml
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('setup_file', type=str, help='setup to be updated')
+parser.add_argument('--model', type=str, help='model update', default=None)
+
+args = parser.parse_args()
+
+# load setup
+with open(args.setup_file, 'r') as f:
+ setup = yaml.load(f.read(), yaml.FullLoader)
+
+# update model entry
+if type(args.model) != type(None):
+ setup['lpcnet']['model'] = args.model
+
+# dump result
+with open(args.setup_file, 'w') as f:
+ yaml.dump(setup, f)
diff --git a/dnn/torch/lpcnet/test_lpcnet.py b/dnn/torch/lpcnet/test_lpcnet.py
new file mode 100644
index 00000000..c57266dd
--- /dev/null
+++ b/dnn/torch/lpcnet/test_lpcnet.py
@@ -0,0 +1,60 @@
+import argparse
+
+import torch
+import numpy as np
+
+
+from models import model_dict
+from utils.data import load_features
+from utils.wav import wavwrite16
+
+debug = False
+if debug:
+ args = type('dummy', (object,),
+ {
+ 'features' : 'features.f32',
+ 'checkpoint' : 'checkpoint.pth',
+ 'output' : 'out.wav',
+ 'version' : 2
+ })()
+else:
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('features', type=str, help='feature file')
+ parser.add_argument('checkpoint', type=str, help='checkpoint file')
+ parser.add_argument('output', type=str, help='output file')
+ parser.add_argument('--version', type=int, help='feature version', default=2)
+
+ args = parser.parse_args()
+
+
+torch.set_num_threads(2)
+
+version = args.version
+feature_file = args.features
+checkpoint_file = args.checkpoint
+
+
+
+output_file = args.output
+if not output_file.endswith('.wav'):
+ output_file += '.wav'
+
+checkpoint = torch.load(checkpoint_file, map_location="cpu")
+
+# check model
+if not 'model' in checkpoint['setup']['lpcnet']:
+ print(f'warning: did not find model entry in setup, using default lpcnet')
+ model_name = 'lpcnet'
+else:
+ model_name = checkpoint['setup']['lpcnet']['model']
+
+model = model_dict[model_name](checkpoint['setup']['lpcnet']['config'])
+
+model.load_state_dict(checkpoint['state_dict'])
+
+data = load_features(feature_file)
+
+output = model.generate(data['features'], data['periods'], data['lpcs'])
+
+wavwrite16(output_file, output.numpy(), 16000)
diff --git a/dnn/torch/lpcnet/train_lpcnet.py b/dnn/torch/lpcnet/train_lpcnet.py
new file mode 100644
index 00000000..1bef7112
--- /dev/null
+++ b/dnn/torch/lpcnet/train_lpcnet.py
@@ -0,0 +1,243 @@
+import os
+import argparse
+import sys
+
+try:
+ import git
+ has_git = True
+except:
+ has_git = False
+
+import yaml
+
+
+import torch
+from torch.optim.lr_scheduler import LambdaLR
+
+from data import LPCNetDataset
+from models import model_dict
+from engine.lpcnet_engine import train_one_epoch, evaluate
+from utils.data import load_features
+from utils.wav import wavwrite16
+
+
+debug = False
+if debug:
+ args = type('dummy', (object,),
+ {
+ 'setup' : 'setup.yml',
+ 'output' : 'testout',
+ 'device' : None,
+ 'test_features' : None,
+ 'finalize': False,
+ 'initial_checkpoint': None,
+ 'no-redirect': False
+ })()
+else:
+ parser = argparse.ArgumentParser("train_lpcnet.py")
+ parser.add_argument('setup', type=str, help='setup yaml file')
+ parser.add_argument('output', type=str, help='output path')
+ parser.add_argument('--device', type=str, help='compute device', default=None)
+ parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
+ parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
+ parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
+ parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')
+
+ args = parser.parse_args()
+
+
+torch.set_num_threads(4)
+
+with open(args.setup, 'r') as f:
+ setup = yaml.load(f.read(), yaml.FullLoader)
+
+if args.finalize:
+ if args.initial_checkpoint is None:
+ raise ValueError('finalization requires initial checkpoint')
+
+ if 'sparsification' in setup['lpcnet']['config']:
+ for sp_job in setup['lpcnet']['config']['sparsification'].values():
+ sp_job['start'], sp_job['stop'] = 0, 0
+
+ setup['training']['lr'] = 1.0e-5
+ setup['training']['lr_decay_factor'] = 0.0
+ setup['training']['epochs'] = 1
+
+ checkpoint_prefix = 'checkpoint_finalize'
+ output_prefix = 'output_finalize'
+ setup_name = 'setup_finalize.yml'
+ output_file='out_finalize.txt'
+else:
+ checkpoint_prefix = 'checkpoint'
+ output_prefix = 'output'
+ setup_name = 'setup.yml'
+ output_file='out.txt'
+
+
+# check model
+if not 'model' in setup['lpcnet']:
+ print(f'warning: did not find model entry in setup, using default lpcnet')
+ model_name = 'lpcnet'
+else:
+ model_name = setup['lpcnet']['model']
+
+# prepare output folder
+if os.path.exists(args.output) and not debug and not args.finalize:
+ print("warning: output folder exists")
+
+ reply = input('continue? (y/n): ')
+ while reply not in {'y', 'n'}:
+ reply = input('continue? (y/n): ')
+
+ if reply == 'n':
+ os._exit()
+else:
+ os.makedirs(args.output, exist_ok=True)
+
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+
+# add repo info to setup
+if has_git:
+ working_dir = os.path.split(__file__)[0]
+ try:
+ repo = git.Repo(working_dir)
+ setup['repo'] = dict()
+ hash = repo.head.object.hexsha
+ urls = list(repo.remote().urls)
+ is_dirty = repo.is_dirty()
+
+ if is_dirty:
+ print("warning: repo is dirty")
+
+ setup['repo']['hash'] = hash
+ setup['repo']['urls'] = urls
+ setup['repo']['dirty'] = is_dirty
+ except:
+ has_git = False
+
+# dump setup
+with open(os.path.join(args.output, setup_name), 'w') as f:
+ yaml.dump(setup, f)
+
+# prepare inference test if wanted
+run_inference_test = False
+if type(args.test_features) != type(None):
+ test_features = load_features(args.test_features)
+ inference_test_dir = os.path.join(args.output, 'inference_test')
+ os.makedirs(inference_test_dir, exist_ok=True)
+ run_inference_test = True
+
+# training parameters
+batch_size = setup['training']['batch_size']
+epochs = setup['training']['epochs']
+lr = setup['training']['lr']
+lr_decay_factor = setup['training']['lr_decay_factor']
+
+# load training dataset
+lpcnet_config = setup['lpcnet']['config']
+data = LPCNetDataset( setup['dataset'],
+ features=lpcnet_config['features'],
+ input_signals=lpcnet_config['signals'],
+ target=lpcnet_config['target'],
+ frames_per_sample=setup['training']['frames_per_sample'],
+ feature_history=lpcnet_config['feature_history'],
+ feature_lookahead=lpcnet_config['feature_lookahead'],
+ lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
+
+# load validation dataset if given
+if 'validation_dataset' in setup:
+ validation_data = LPCNetDataset( setup['validation_dataset'],
+ features=lpcnet_config['features'],
+ input_signals=lpcnet_config['signals'],
+ target=lpcnet_config['target'],
+ frames_per_sample=setup['training']['frames_per_sample'],
+ feature_history=lpcnet_config['feature_history'],
+ feature_lookahead=lpcnet_config['feature_lookahead'],
+ lpc_gamma=lpcnet_config.get('lpc_gamma', 1))
+
+ validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)
+
+ run_validation = True
+else:
+ run_validation = False
+
+# create model
+model = model_dict[model_name](setup['lpcnet']['config'])
+
+if args.initial_checkpoint is not None:
+ print(f"loading state dict from {args.initial_checkpoint}...")
+ chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
+ model.load_state_dict(chkpt['state_dict'])
+
+# set compute device
+if type(args.device) == type(None):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+else:
+ device = torch.device(args.device)
+
+# push model to device
+model.to(device)
+
+# dataloader
+dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
+
+# optimizer is introduced to trainable parameters
+parameters = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(parameters, lr=lr)
+
+# learning rate scheduler
+scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+# loss
+criterion = torch.nn.NLLLoss()
+
+# model checkpoint
+checkpoint = {
+ 'setup' : setup,
+ 'state_dict' : model.state_dict(),
+ 'loss' : -1
+}
+
+if not args.no_redirect:
+ print(f"re-directing output to {os.path.join(args.output, output_file)}")
+ sys.stdout = open(os.path.join(args.output, output_file), "w")
+
+best_loss = 1e9
+
+for ep in range(1, epochs + 1):
+ print(f"training epoch {ep}...")
+ new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
+
+
+ # save checkpoint
+ checkpoint['state_dict'] = model.state_dict()
+ checkpoint['loss'] = new_loss
+
+ if run_validation:
+ print("running validation...")
+ validation_loss = evaluate(model, criterion, validation_dataloader, device)
+ checkpoint['validation_loss'] = validation_loss
+
+ if validation_loss < best_loss:
+ torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_best.pth'))
+ best_loss = validation_loss
+
+ torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
+ torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
+
+ # run inference test
+ if run_inference_test:
+ model.to("cpu")
+ print("running inference test...")
+
+ output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])
+
+ testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')
+
+ wavwrite16(testfilename, output.numpy(), 16000)
+
+ model.to(device)
+
+ print()
diff --git a/dnn/torch/lpcnet/utils/__init__.py b/dnn/torch/lpcnet/utils/__init__.py
new file mode 100644
index 00000000..edbbe02c
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/__init__.py
@@ -0,0 +1,4 @@
+from . import sparsification
+from . import data
+from . import pcm
+from . import sample \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/data.py b/dnn/torch/lpcnet/utils/data.py
new file mode 100644
index 00000000..b8e7c612
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/data.py
@@ -0,0 +1,112 @@
+import os
+
+import torch
+import numpy as np
+
+def load_features(feature_file, version=2):
+ if version == 2:
+ layout = {
+ 'cepstrum': [0,18],
+ 'periods': [18, 19],
+ 'pitch_corr': [19, 20],
+ 'lpc': [20, 36]
+ }
+ frame_length = 36
+
+ elif version == 1:
+ layout = {
+ 'cepstrum': [0,18],
+ 'periods': [36, 37],
+ 'pitch_corr': [37, 38],
+ 'lpc': [39, 55],
+ }
+ frame_length = 55
+ else:
+ raise ValueError(f'unknown feature version: {version}')
+
+
+ raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
+ raw_features = raw_features.reshape((-1, frame_length))
+
+ features = torch.cat(
+ [
+ raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
+ raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
+ ],
+ dim=1
+ )
+
+ lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
+ periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
+
+ return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
+
+
+
+def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
+ ref_data = np.memmap(reference_data_path, dtype=np.int16)
+ signal = np.memmap(signal_path, dtype=np.int16)
+
+ signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
+ signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)
+
+
+ assert len(signal) % 160 == 0
+ num_frames = len(signal) // 160
+ mem = np.zeros(1)
+ for fr in range(len(signal)//160):
+ signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
+ mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]
+
+ new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
+
+ new_data[:] = 0
+ N = len(signal) - offset
+ new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
+ new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
+
+
+def parse_warpq_scores(output_file):
+ """ extracts warpq scores from output file """
+
+ with open(output_file, "r") as f:
+ lines = f.readlines()
+
+ scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]
+
+ return scores
+
+
+def parse_stats_file(file):
+
+ with open(file, "r") as f:
+ lines = f.readlines()
+
+ mean = float(lines[0].split(":")[-1])
+ bt_mean = float(lines[1].split(":")[-1])
+ top_mean = float(lines[2].split(":")[-1])
+
+ return mean, bt_mean, top_mean
+
+def collect_test_stats(test_folder):
+ """ collects statistics for all discovered metrics from test folder """
+
+ metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
+
+ results = dict()
+
+ content = os.listdir(test_folder)
+
+ stats_files = [file for file in content if file.startswith('stats_')]
+
+ for file in stats_files:
+ metric = file[len("stats_") : -len(".txt")]
+
+ if metric not in metrics:
+ print(f"warning: unknown metric {metric}")
+
+ mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))
+
+ results[metric] = [mean, bt_mean, top_mean]
+
+ return results
diff --git a/dnn/torch/lpcnet/utils/endoscopy.py b/dnn/torch/lpcnet/utils/endoscopy.py
new file mode 100644
index 00000000..05dd4750
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/endoscopy.py
@@ -0,0 +1,205 @@
+""" module for inspecting models during inference """
+
+import os
+
+import yaml
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+
+import torch
+import numpy as np
+
+# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
+_state = dict()
+_folder = 'endoscopy'
+
+def get_gru_gates(gru, input, state):
+ hidden_size = gru.hidden_size
+
+ direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
+ recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
+
+ # reset gate
+ start, stop = 0 * hidden_size, 1 * hidden_size
+ reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
+
+ # update gate
+ start, stop = 1 * hidden_size, 2 * hidden_size
+ update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
+
+ # new gate
+ start, stop = 2 * hidden_size, 3 * hidden_size
+ new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
+
+ return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
+
+
+def init(folder='endoscopy'):
+ """ sets up output folder for endoscopy data """
+
+ global _folder
+ _folder = folder
+
+ if not os.path.exists(folder):
+ os.makedirs(folder)
+ else:
+ print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
+
+def write_data(key, data, fs):
+ """ appends data to previous data written under key """
+
+ global _state
+
+ # convert to numpy if torch.Tensor is given
+ if isinstance(data, torch.Tensor):
+ data = data.detach().numpy()
+
+ if not key in _state:
+ _state[key] = {
+ 'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
+ 'fs' : fs,
+ 'dim' : tuple(data.shape),
+ 'dtype' : str(data.dtype)
+ }
+
+ with open(os.path.join(_folder, key + '.yml'), 'w') as f:
+ f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
+ else:
+ if _state[key]['fs'] != fs:
+ raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
+ if _state[key]['dtype'] != str(data.dtype):
+ raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
+ if _state[key]['dim'] != tuple(data.shape):
+ raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
+
+ _state[key]['fid'].write(data.tobytes())
+
+def close(folder='endoscopy'):
+ """ clean up """
+ for key in _state.keys():
+ _state[key]['fid'].close()
+
+
+def read_data(folder='endoscopy'):
+ """ retrieves written data as numpy arrays """
+
+
+ keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
+
+ return_dict = dict()
+
+ for key in keys:
+ with open(os.path.join(folder, key + '.yml'), 'r') as f:
+ value = yaml.load(f.read(), yaml.FullLoader)
+
+ with open(os.path.join(folder, key + '.bin'), 'rb') as f:
+ data = np.frombuffer(f.read(), dtype=value['dtype'])
+
+ value['data'] = data.reshape((-1,) + value['dim'])
+
+ return_dict[key] = value
+
+ return return_dict
+
+def get_best_reshape(shape, target_ratio=1):
+ """ calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
+
+ if len(shape) > 1:
+ pixel_count = 1
+ for s in shape:
+ pixel_count *= s
+ else:
+ pixel_count = shape[0]
+
+ if pixel_count == 1:
+ return (1,)
+
+ num_columns = int((pixel_count / target_ratio)**.5)
+
+ while (pixel_count % num_columns):
+ num_columns -= 1
+
+ num_rows = pixel_count // num_columns
+
+ return (num_rows, num_columns)
+
+def get_type_and_shape(shape):
+
+ # can happen if data is one dimensional
+ if len(shape) == 0:
+ shape = (1,)
+
+ # calculate pixel count
+ if len(shape) > 1:
+ pixel_count = 1
+ for s in shape:
+ pixel_count *= s
+ else:
+ pixel_count = shape[0]
+
+ if pixel_count == 1:
+ return 'plot', (1, )
+
+ # stay with shape if already 2-dimensional
+ if len(shape) == 2:
+ if (shape[0] != pixel_count) or (shape[1] != pixel_count):
+ return 'image', shape
+
+ return 'image', get_best_reshape(shape)
+
+def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
+
+ # determine plot setup
+ num_keys = len(data.keys())
+
+ num_rows = int((num_keys * 3/4) ** .5)
+
+ num_cols = (num_keys + num_rows - 1) // num_rows
+
+ fig, axs = plt.subplots(num_rows, num_cols)
+ fig.set_size_inches(num_cols * 5, num_rows * 5)
+
+ display = dict()
+
+ fs_max = max([val['fs'] for val in data.values()])
+
+ num_samples = max([val['data'].shape[0] for val in data.values()])
+
+ keys = sorted(data.keys())
+
+ # inspect data
+ for i, key in enumerate(keys):
+ axs[i // num_cols, i % num_cols].title.set_text(key)
+
+ display[key] = dict()
+
+ display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
+ display[key]['down_factor'] = data[key]['fs'] / fs_max
+
+ start_index = max(start_index, half_signal_window_length)
+ while stop_index < 0:
+ stop_index += num_samples
+
+ stop_index = min(stop_index, num_samples - half_signal_window_length)
+
+ # actual plotting
+ frames = []
+ for index in range(start_index, stop_index):
+ ims = []
+ for i, key in enumerate(keys):
+ feature_index = int(round(index * display[key]['down_factor']))
+
+ if display[key]['type'] == 'plot':
+ ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
+
+ elif display[key]['type'] == 'image':
+ ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
+
+ frames.append(ims)
+
+ ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
+
+ if not filename.endswith('.mp4'):
+ filename += '.mp4'
+
+ ani.save(filename) \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/layers/__init__.py b/dnn/torch/lpcnet/utils/layers/__init__.py
new file mode 100644
index 00000000..4a58f221
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/__init__.py
@@ -0,0 +1,3 @@
+from .dual_fc import DualFC
+from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
+from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/layers/dual_fc.py b/dnn/torch/lpcnet/utils/layers/dual_fc.py
new file mode 100644
index 00000000..ed10a5c6
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/dual_fc.py
@@ -0,0 +1,15 @@
+import torch
+from torch import nn
+
+class DualFC(nn.Module):
+ def __init__(self, input_dim, output_dim):
+ super(DualFC, self).__init__()
+
+ self.dense1 = nn.Linear(input_dim, output_dim)
+ self.dense2 = nn.Linear(input_dim, output_dim)
+
+ self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
+ self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
+
+ def forward(self, x):
+ return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
diff --git a/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py b/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
new file mode 100644
index 00000000..12835f89
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
@@ -0,0 +1,42 @@
+""" module implementing PCM embeddings for LPCNet """
+
+import math as m
+
+import torch
+from torch import nn
+
+
+class PCMEmbedding(nn.Module):
+ def __init__(self, embed_dim=128, num_levels=256):
+ super(PCMEmbedding, self).__init__()
+
+ self.embed_dim = embed_dim
+ self.num_levels = num_levels
+
+ self.embedding = nn.Embedding(self.num_levels, self.num_dim)
+
+ # initialize
+ with torch.no_grad():
+ num_rows, num_cols = self.num_levels, self.embed_dim
+ a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
+ for i in range(num_rows):
+ a[i, :] += m.sqrt(12) * (i - num_rows / 2)
+ self.embedding.weight[:, :] = 0.1 * a
+
+ def forward(self, x):
+ return self.embeddint(x)
+
+
+class DifferentiablePCMEmbedding(PCMEmbedding):
+ def __init__(self, embed_dim, num_levels=256):
+ super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)
+
+ def forward(self, x):
+ x_int = (x - torch.floor(x)).detach().long()
+ x_frac = x - x_int
+ x_next = torch.minimum(x_int + 1, self.num_levels)
+
+ embed_0 = self.embedding(x_int)
+ embed_1 = self.embedding(x_next)
+
+ return (1 - x_frac) * embed_0 + x_frac * embed_1
diff --git a/dnn/torch/lpcnet/utils/layers/subconditioner.py b/dnn/torch/lpcnet/utils/layers/subconditioner.py
new file mode 100644
index 00000000..87189cd5
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/subconditioner.py
@@ -0,0 +1,468 @@
+from re import sub
+import torch
+from torch import nn
+
+
+
+
+def get_subconditioner( method,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ **kwargs):
+
+ subconditioner_dict = {
+ 'additive' : AdditiveSubconditioner,
+ 'concatenative' : ConcatenativeSubconditioner,
+ 'modulative' : ModulativeSubconditioner
+ }
+
+ return subconditioner_dict[method](number_of_subsamples,
+ pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)
+
+
+class Subconditioner(nn.Module):
+ def __init__(self):
+ """ upsampling by subconditioning
+
+ Upsamples a sequence of states conditioning on pcm signals and
+ optionally a feature vector.
+ """
+ super(Subconditioner, self).__init__()
+
+ def forward(self, states, signals, features=None):
+ raise Exception("Base class should not be called")
+
+ def single_step(self, index, state, signals, features):
+ raise Exception("Base class should not be called")
+
+ def get_output_dim(self, index):
+ raise Exception("Base class should not be called")
+
+
+class AdditiveSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ **kwargs):
+ """ subconditioning by addition """
+
+ super(AdditiveSubconditioner, self).__init__()
+
+ self.number_of_subsamples = number_of_subsamples
+ self.pcm_embedding_size = pcm_embedding_size
+ self.state_size = state_size
+ self.pcm_levels = pcm_levels
+ self.number_of_signals = number_of_signals
+
+ if self.pcm_embedding_size != self.state_size:
+ raise ValueError('For additive subconditioning state and embedding '
+ + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}')
+
+ self.embeddings = [None]
+ for i in range(1, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ def forward(self, states, signals):
+ """ creates list of subconditioned states
+
+ Parameters:
+ -----------
+ states : torch.tensor
+ states of shape (batch, seq_length // s, state_size)
+ signals : torch.tensor
+ signals of shape (batch, seq_length, number_of_signals)
+
+ Returns:
+ --------
+ c_states : list of torch.tensor
+ list of s subconditioned states
+ """
+
+ s = self.number_of_subsamples
+
+ c_states = [states]
+ new_states = states
+ for i in range(1, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.sum(embed, dim=2)
+
+ new_states = new_states + embed
+ c_states.append(new_states)
+
+ return c_states
+
+ def single_step(self, index, state, signals):
+ """ carry out single step for inference
+
+ Parameters:
+ -----------
+ index : int
+ position in subconditioning batch
+
+ state : torch.tensor
+ state to sub-condition
+
+ signals : torch.tensor
+ signals for subconditioning, all but the last dimensions
+ must match those of state
+
+ Returns:
+ c_state : torch.tensor
+ subconditioned state
+ """
+
+ if index == 0:
+ c_state = state
+ else:
+ embed_signals = self.embeddings[index](signals)
+ c = torch.sum(embed_signals, dim=-2)
+ c_state = state + c
+
+ return c_state
+
+ def get_output_dim(self, index):
+ return self.state_size
+
+ def get_average_flops_per_step(self):
+ s = self.number_of_subsamples
+ flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
+ return flops
+
+
+class ConcatenativeSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ recurrent=True,
+ **kwargs):
+ """ subconditioning by concatenation """
+
+ super(ConcatenativeSubconditioner, self).__init__()
+
+ self.number_of_subsamples = number_of_subsamples
+ self.pcm_embedding_size = pcm_embedding_size
+ self.state_size = state_size
+ self.pcm_levels = pcm_levels
+ self.number_of_signals = number_of_signals
+ self.recurrent = recurrent
+
+ self.embeddings = []
+ start_index = 0
+ if self.recurrent:
+ start_index = 1
+ self.embeddings.append(None)
+
+ for i in range(start_index, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ def forward(self, states, signals):
+ """ creates list of subconditioned states
+
+ Parameters:
+ -----------
+ states : torch.tensor
+ states of shape (batch, seq_length // s, state_size)
+ signals : torch.tensor
+ signals of shape (batch, seq_length, number_of_signals)
+
+ Returns:
+ --------
+ c_states : list of torch.tensor
+ list of s subconditioned states
+ """
+ s = self.number_of_subsamples
+
+ if self.recurrent:
+ c_states = [states]
+ start = 1
+ else:
+ c_states = []
+ start = 0
+
+ new_states = states
+ for i in range(start, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.flatten(embed, -2)
+
+ if self.recurrent:
+ new_states = torch.cat((new_states, embed), dim=-1)
+ else:
+ new_states = torch.cat((states, embed), dim=-1)
+
+ c_states.append(new_states)
+
+ return c_states
+
+ def single_step(self, index, state, signals):
+ """ carry out single step for inference
+
+ Parameters:
+ -----------
+ index : int
+ position in subconditioning batch
+
+ state : torch.tensor
+ state to sub-condition
+
+ signals : torch.tensor
+ signals for subconditioning, all but the last dimensions
+ must match those of state
+
+ Returns:
+ c_state : torch.tensor
+ subconditioned state
+ """
+
+ if index == 0 and self.recurrent:
+ c_state = state
+ else:
+ embed_signals = self.embeddings[index](signals)
+ c = torch.flatten(embed_signals, -2)
+ if not self.recurrent and index > 0:
+ # overwrite previous conditioning vector
+ c_state = torch.cat((state[...,:self.state_size], c), dim=-1)
+ else:
+ c_state = torch.cat((state, c), dim=-1)
+ return c_state
+
+ return c_state
+
+ def get_average_flops_per_step(self):
+ return 0
+
+ def get_output_dim(self, index):
+ if self.recurrent:
+ return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
+ else:
+ return self.state_size + self.pcm_embedding_size * self.number_of_signals
+
+class ModulativeSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ state_recurrent=False,
+ **kwargs):
+ """ subconditioning by modulation """
+
+ super(ModulativeSubconditioner, self).__init__()
+
+ self.number_of_subsamples = number_of_subsamples
+ self.pcm_embedding_size = pcm_embedding_size
+ self.state_size = state_size
+ self.pcm_levels = pcm_levels
+ self.number_of_signals = number_of_signals
+ self.state_recurrent = state_recurrent
+
+ self.hidden_size = self.pcm_embedding_size * self.number_of_signals
+
+ if self.state_recurrent:
+ self.hidden_size += self.pcm_embedding_size
+ self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)
+
+ self.embeddings = [None]
+ self.alphas = [None]
+ self.betas = [None]
+
+ for i in range(1, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
+ self.add_module('alpha_dense_' + str(i), self.alphas[-1])
+
+ self.betas.append(nn.Linear(self.hidden_size, self.state_size))
+ self.add_module('beta_dense_' + str(i), self.betas[-1])
+
+
+
+ def forward(self, states, signals):
+ """ creates list of subconditioned states
+
+ Parameters:
+ -----------
+ states : torch.tensor
+ states of shape (batch, seq_length // s, state_size)
+ signals : torch.tensor
+ signals of shape (batch, seq_length, number_of_signals)
+
+ Returns:
+ --------
+ c_states : list of torch.tensor
+ list of s subconditioned states
+ """
+ s = self.number_of_subsamples
+
+ c_states = [states]
+ new_states = states
+ for i in range(1, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.flatten(embed, -2)
+
+ if self.state_recurrent:
+ comp_states = self.state_transform(new_states)
+ embed = torch.cat((embed, comp_states), dim=-1)
+
+ alpha = torch.tanh(self.alphas[i](embed))
+ beta = torch.tanh(self.betas[i](embed))
+
+ # new state obtained by modulating previous state
+ new_states = torch.tanh((1 + alpha) * new_states + beta)
+
+ c_states.append(new_states)
+
+ return c_states
+
+ def single_step(self, index, state, signals):
+ """ carry out single step for inference
+
+ Parameters:
+ -----------
+ index : int
+ position in subconditioning batch
+
+ state : torch.tensor
+ state to sub-condition
+
+ signals : torch.tensor
+ signals for subconditioning, all but the last dimensions
+ must match those of state
+
+ Returns:
+ c_state : torch.tensor
+ subconditioned state
+ """
+
+ if index == 0:
+ c_state = state
+ else:
+ embed_signals = self.embeddings[index](signals)
+ c = torch.flatten(embed_signals, -2)
+ if self.state_recurrent:
+ r_state = self.state_transform(state)
+ c = torch.cat((c, r_state), dim=-1)
+ alpha = torch.tanh(self.alphas[index](c))
+ beta = torch.tanh(self.betas[index](c))
+ c_state = torch.tanh((1 + alpha) * state + beta)
+ return c_state
+
+ return c_state
+
+ def get_output_dim(self, index):
+ return self.state_size
+
+ def get_average_flops_per_step(self):
+ s = self.number_of_subsamples
+
+ # estimate activation by 10 flops
+ # c_state = torch.tanh((1 + alpha) * state + beta)
+ flops = 13 * self.state_size
+
+ # hidden size
+ hidden_size = self.number_of_signals * self.pcm_embedding_size
+ if self.state_recurrent:
+ hidden_size += self.pcm_embedding_size
+
+ # counting 2 * A * B flops for Linear(A, B)
+ # alpha = torch.tanh(self.alphas[index](c))
+ # beta = torch.tanh(self.betas[index](c))
+ flops += 4 * hidden_size * self.state_size + 20 * self.state_size
+
+ # r_state = self.state_transform(state)
+ if self.state_recurrent:
+ flops += 2 * self.state_size * self.pcm_embedding_size
+
+ # average over steps
+ flops *= (s - 1) / s
+
+ return flops
+
+class ComparitiveSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ error_index=-1,
+ apply_gate=True,
+ normalize=False):
+ """ subconditioning by comparison """
+
+ super(ComparitiveSubconditioner, self).__init__()
+
+ self.comparison_size = self.pcm_embedding_size
+ self.error_position = error_index
+ self.apply_gate = apply_gate
+ self.normalize = normalize
+
+ self.state_transform = nn.Linear(self.state_size, self.comparison_size)
+
+ self.alpha_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
+ self.beta_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
+
+ if self.apply_gate:
+ self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)
+
+ # embeddings and state transforms
+ self.embeddings = [None]
+ self.alpha_denses = [None]
+ self.beta_denses = [None]
+ self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
+ self.add_module('state_transform_0', self.state_transforms[0])
+
+ for i in range(1, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ state_transform = nn.Linear(self.state_size, self.comparison_size)
+ self.add_module('state_transform_' + str(i), state_transform)
+ self.state_transforms.append(state_transform)
+
+ self.alpha_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
+ self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])
+
+ self.beta_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
+ self.add_module('beta_dense_' + str(i), self.beta_denses[-1])
+
+ def forward(self, states, signals):
+ s = self.number_of_subsamples
+
+ c_states = [states]
+ new_states = states
+ for i in range(1, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.flatten(embed, -2)
+
+ comp_states = self.state_transforms[i](new_states)
+
+ alpha = torch.tanh(self.alpha_dense(embed))
+ beta = torch.tanh(self.beta_dense(embed))
+
+ # new state obtained by modulating previous state
+ new_states = torch.tanh((1 + alpha) * comp_states + beta)
+
+ c_states.append(new_states)
+
+ return c_states
diff --git a/dnn/torch/lpcnet/utils/misc.py b/dnn/torch/lpcnet/utils/misc.py
new file mode 100644
index 00000000..dab4837f
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/misc.py
@@ -0,0 +1,36 @@
+import torch
+
+
+def find(a, v):
+ try:
+ idx = a.index(v)
+ except:
+ idx = -1
+ return idx
+
+def interleave_tensors(tensors, dim=-2):
+ """ interleave list of tensors along sequence dimension """
+
+ x = torch.cat([x.unsqueeze(dim) for x in tensors], dim=dim)
+ x = torch.flatten(x, dim - 1, dim)
+
+ return x
+
+def _interleave(x, pcm_levels=256):
+
+ repeats = pcm_levels // (2*x.size(-1))
+ x = x.unsqueeze(-1)
+ p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
+
+ return p
+
+def get_pdf_from_tree(x):
+ pcm_levels = x.size(-1)
+
+ p = _interleave(x[..., 1:2])
+ n = 4
+ while n <= pcm_levels:
+ p = p * _interleave(x[..., n//2:n])
+ n *= 2
+
+ return p \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/pcm.py b/dnn/torch/lpcnet/utils/pcm.py
new file mode 100644
index 00000000..608e40d7
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/pcm.py
@@ -0,0 +1,6 @@
+
+def clip_to_int16(x):
+ int_min = -2**15
+ int_max = 2**15 - 1
+ x_clipped = max(int_min, min(x, int_max))
+ return x_clipped
diff --git a/dnn/torch/lpcnet/utils/sample.py b/dnn/torch/lpcnet/utils/sample.py
new file mode 100644
index 00000000..14e1cd19
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sample.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def sample_excitation(probs, pitch_corr):
+
+ norm = lambda x : x / (x.sum() + 1e-18)
+
+ # lowering the temperature
+ probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))
+ # cut-off tails
+ probs = norm(torch.maximum(probs - 0.002 , torch.FloatTensor([0])))
+ # sample
+ exc = torch.multinomial(probs.squeeze(), 1)
+
+ return exc
diff --git a/dnn/torch/lpcnet/utils/sparsification/__init__.py b/dnn/torch/lpcnet/utils/sparsification/__init__.py
new file mode 100644
index 00000000..ebfa9d9a
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/__init__.py
@@ -0,0 +1,2 @@
+from .gru_sparsifier import GRUSparsifier
+from .common import sparsify_matrix, calculate_gru_flops_per_step \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py
new file mode 100644
index 00000000..34989d4b
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/common.py
@@ -0,0 +1,92 @@
+import torch
+
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
+ """ sparsifies matrix with specified block size
+
+ Parameters:
+ -----------
+ matrix : torch.tensor
+ matrix to sparsify
+ density : int
+ target density
+ block_size : [int, int]
+ block size dimensions
+ keep_diagonal : bool
+ If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
+ """
+
+ m, n = matrix.shape
+ m1, n1 = block_size
+
+ if m % m1 or n % n1:
+ raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
+
+ # extract diagonal if keep_diagonal = True
+ if keep_diagonal:
+ if m != n:
+ raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
+
+ to_spare = torch.diag(torch.diag(matrix))
+ matrix = matrix - to_spare
+ else:
+ to_spare = torch.zeros_like(matrix)
+
+ # calculate energy in sub-blocks
+ x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
+ x = x ** 2
+ block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
+
+ number_of_blocks = (m * n) // (m1 * n1)
+ number_of_survivors = round(number_of_blocks * density)
+
+ # masking threshold
+ if number_of_survivors == 0:
+ threshold = 0
+ else:
+ threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
+
+ # create mask
+ mask = torch.ones_like(block_energies)
+ mask[block_energies < threshold] = 0
+ mask = torch.repeat_interleave(mask, m1, dim=0)
+ mask = torch.repeat_interleave(mask, n1, dim=1)
+
+ # perform masking
+ masked_matrix = mask * matrix + to_spare
+
+ if return_mask:
+ return masked_matrix, mask
+ else:
+ return masked_matrix
+
+def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
+ input_size = gru.input_size
+ hidden_size = gru.hidden_size
+ flops = 0
+
+ input_density = (
+ sparsification_dict.get('W_ir', [1])[0]
+ + sparsification_dict.get('W_in', [1])[0]
+ + sparsification_dict.get('W_iz', [1])[0]
+ ) / 3
+
+ recurrent_density = (
+ sparsification_dict.get('W_hr', [1])[0]
+ + sparsification_dict.get('W_hn', [1])[0]
+ + sparsification_dict.get('W_hz', [1])[0]
+ ) / 3
+
+ # input matrix vector multiplications
+ if not drop_input:
+ flops += 2 * 3 * input_size * hidden_size * input_density
+
+ # recurrent matrix vector multiplications
+ flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
+
+ # biases
+ flops += 6 * hidden_size
+
+ # activations estimated by 10 flops per activation
+ flops += 30 * hidden_size
+
+ return flops \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
new file mode 100644
index 00000000..865f3a7d
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
@@ -0,0 +1,158 @@
+import torch
+
+from .common import sparsify_matrix
+
+
+class GRUSparsifier:
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+ task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
+ of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
+ 'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
+ update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
+ sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
+ should be kept.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop sparsification will be
+ carried out after every call to GRUSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
+ with density (alpha + target_density * (1 * alpha)), where
+ alpha = ((stop - i) / (start - stop)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> gru = torch.nn.GRU(10, 20)
+ >>> sparsify_dict = {
+ ... 'W_ir' : (0.5, [2, 2], False),
+ ... 'W_iz' : (0.6, [2, 2], False),
+ ... 'W_in' : (0.7, [2, 2], False),
+ ... 'W_hr' : (0.1, [4, 4], True),
+ ... 'W_hz' : (0.2, [4, 4], True),
+ ... 'W_hn' : (0.3, [4, 4], True),
+ ... }
+ >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+ # just copying parameters...
+ self.start = start
+ self.stop = stop
+ self.interval = interval
+ self.exponent = exponent
+ self.task_list = task_list
+
+ # ... and setting counter to 0
+ self.step_counter = 0
+
+ self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
+
+ def step(self, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+ # compute current interpolation factor
+ self.step_counter += 1
+
+ if self.step_counter < self.start:
+ return
+ elif self.step_counter < self.stop:
+ # update only every self.interval-th interval
+ if self.step_counter % self.interval:
+ return
+
+ alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+ else:
+ alpha = 0
+
+
+ with torch.no_grad():
+ for gru, params in self.task_list:
+ hidden_size = gru.hidden_size
+
+ # input weights
+ for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
+ if key in params:
+ density = alpha + (1 - alpha) * params[key][0]
+ if verbose:
+ print(f"[{self.step_counter}]: {key} density: {density}")
+
+ gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
+ gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
+ density, # density
+ params[key][1], # block_size
+ params[key][2], # keep_diagonal (might want to set this to False)
+ return_mask=True
+ )
+
+ if type(self.last_masks[key]) != type(None):
+ if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+ print(f"sparsification mask {key} changed for gru {gru}")
+
+ self.last_masks[key] = new_mask
+
+ # recurrent weights
+ for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
+ if key in params:
+ density = alpha + (1 - alpha) * params[key][0]
+ if verbose:
+ print(f"[{self.step_counter}]: {key} density: {density}")
+ gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
+ gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
+ density,
+ params[key][1], # block_size
+ params[key][2], # keep_diagonal (might want to set this to False)
+ return_mask=True
+ )
+
+ if type(self.last_masks[key]) != type(None):
+ if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+ print(f"sparsification mask {key} changed for gru {gru}")
+
+ self.last_masks[key] = new_mask
+
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ gru = torch.nn.GRU(10, 20)
+ sparsify_dict = {
+ 'W_ir' : (0.5, [2, 2], False),
+ 'W_iz' : (0.6, [2, 2], False),
+ 'W_in' : (0.7, [2, 2], False),
+ 'W_hr' : (0.1, [4, 4], True),
+ 'W_hz' : (0.2, [4, 4], True),
+ 'W_hn' : (0.3, [4, 4], True),
+ }
+
+ sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
diff --git a/dnn/torch/lpcnet/utils/templates.py b/dnn/torch/lpcnet/utils/templates.py
new file mode 100644
index 00000000..d399f57c
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/templates.py
@@ -0,0 +1,128 @@
+from models import multi_rate_lpcnet
+import copy
+
+setup_dict = dict()
+
+dataset_template_v2 = {
+ 'version' : 2,
+ 'feature_file' : 'features.f32',
+ 'signal_file' : 'data.s16',
+ 'frame_length' : 160,
+ 'feature_frame_length' : 36,
+ 'signal_frame_length' : 2,
+ 'feature_dtype' : 'float32',
+ 'signal_dtype' : 'int16',
+ 'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
+ 'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction
+}
+
+dataset_template_v1 = {
+ 'version' : 1,
+ 'feature_file' : 'features.f32',
+ 'signal_file' : 'data.u8',
+ 'frame_length' : 160,
+ 'feature_frame_length' : 55,
+ 'signal_frame_length' : 4,
+ 'feature_dtype' : 'float32',
+ 'signal_dtype' : 'uint8',
+ 'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
+ 'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction
+}
+
+# lpcnet
+
+lpcnet_config = {
+ 'frame_size' : 160,
+ 'gru_a_units' : 384,
+ 'gru_b_units' : 64,
+ 'feature_conditioning_dim' : 128,
+ 'feature_conv_kernel_size' : 3,
+ 'period_levels' : 257,
+ 'period_embedding_dim' : 64,
+ 'signal_embedding_dim' : 128,
+ 'signal_levels' : 256,
+ 'feature_dimension' : 19,
+ 'output_levels' : 256,
+ 'lpc_gamma' : 0.9,
+ 'features' : ['cepstrum', 'periods', 'pitch_corr'],
+ 'signals' : ['last_signal', 'prediction', 'last_error'],
+ 'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
+ 'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
+ 'target' : 'error',
+ 'feature_history' : 2,
+ 'feature_lookahead' : 2,
+ 'sparsification' : {
+ 'gru_a' : {
+ 'start' : 10000,
+ 'stop' : 30000,
+ 'interval' : 100,
+ 'exponent' : 3,
+ 'params' : {
+ 'W_hr' : (0.05, [4, 8], True),
+ 'W_hz' : (0.05, [4, 8], True),
+ 'W_hn' : (0.2, [4, 8], True)
+ },
+ },
+ 'gru_b' : {
+ 'start' : 10000,
+ 'stop' : 30000,
+ 'interval' : 100,
+ 'exponent' : 3,
+ 'params' : {
+ 'W_ir' : (0.5, [4, 8], False),
+ 'W_iz' : (0.5, [4, 8], False),
+ 'W_in' : (0.5, [4, 8], False)
+ },
+ }
+ },
+ 'add_reference_phase' : False,
+ 'reference_phase_dim' : 0
+}
+
+
+
+# multi rate
+subconditioning = {
+ 'subconditioning_a' : {
+ 'number_of_subsamples' : 2,
+ 'method' : 'modulative',
+ 'signals' : ['last_signal', 'prediction', 'last_error'],
+ 'pcm_embedding_size' : 64,
+ 'kwargs' : dict()
+
+ },
+ 'subconditioning_b' : {
+ 'number_of_subsamples' : 2,
+ 'method' : 'modulative',
+ 'signals' : ['last_signal', 'prediction', 'last_error'],
+ 'pcm_embedding_size' : 64,
+ 'kwargs' : dict()
+ }
+}
+
+multi_rate_lpcnet_config = lpcnet_config.copy()
+multi_rate_lpcnet_config['subconditioning'] = subconditioning
+
+training_default = {
+ 'batch_size' : 256,
+ 'epochs' : 20,
+ 'lr' : 1e-3,
+ 'lr_decay_factor' : 2.5e-5,
+ 'adam_betas' : [0.9, 0.99],
+ 'frames_per_sample' : 15
+}
+
+lpcnet_setup = {
+ 'dataset' : '/local/datasets/lpcnet_training',
+ 'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'},
+ 'training' : training_default
+}
+
+multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
+multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
+multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'
+
+setup_dict = {
+ 'lpcnet' : lpcnet_setup,
+ 'multi_rate' : multi_rate_lpcnet_setup
+}
diff --git a/dnn/torch/lpcnet/utils/ulaw.py b/dnn/torch/lpcnet/utils/ulaw.py
new file mode 100644
index 00000000..1a9f9e47
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/ulaw.py
@@ -0,0 +1,29 @@
+import math as m
+
+import torch
+
+
+
+def ulaw2lin(u):
+ scale_1 = 32768.0 / 255.0
+ u = u - 128
+ s = torch.sign(u)
+ u = torch.abs(u)
+ return s * scale_1 * (torch.exp(u / 128. * m.log(256)) - 1)
+
+
+def lin2ulawq(x):
+ scale = 255.0 / 32768.0
+ s = torch.sign(x)
+ x = torch.abs(x)
+ u = s * (128 * torch.log(1 + scale * x) / m.log(256))
+ u = torch.clip(128 + torch.round(u), 0, 255)
+ return u
+
+def lin2ulaw(x):
+ scale = 255.0 / 32768.0
+ s = torch.sign(x)
+ x = torch.abs(x)
+ u = s * (128 * torch.log(1 + scale * x) / torch.log(256))
+ u = torch.clip(128 + u, 0, 255)
+ return u \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/wav.py b/dnn/torch/lpcnet/utils/wav.py
new file mode 100644
index 00000000..3ed811f5
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/wav.py
@@ -0,0 +1,14 @@
+import wave
+
+def wavwrite16(filename, x, fs):
+ """ writes x as int16 to file with name filename
+
+ If x.dtype is int16 x is written as is. Otherwise,
+ it is scaled by 2**15 - 1 and converted to int16.
+ """
+ if x.dtype != 'int16':
+ x = ((2**15 - 1) * x).astype('int16')
+
+ with wave.open(filename, 'wb') as f:
+ f.setparams((1, 2, fs, len(x), 'NONE', ""))
+ f.writeframes(x.tobytes()) \ No newline at end of file