Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/lpcnet/utils')
-rw-r--r--dnn/torch/lpcnet/utils/__init__.py4
-rw-r--r--dnn/torch/lpcnet/utils/data.py112
-rw-r--r--dnn/torch/lpcnet/utils/endoscopy.py205
-rw-r--r--dnn/torch/lpcnet/utils/layers/__init__.py3
-rw-r--r--dnn/torch/lpcnet/utils/layers/dual_fc.py15
-rw-r--r--dnn/torch/lpcnet/utils/layers/pcm_embeddings.py42
-rw-r--r--dnn/torch/lpcnet/utils/layers/subconditioner.py468
-rw-r--r--dnn/torch/lpcnet/utils/misc.py36
-rw-r--r--dnn/torch/lpcnet/utils/pcm.py6
-rw-r--r--dnn/torch/lpcnet/utils/sample.py15
-rw-r--r--dnn/torch/lpcnet/utils/sparsification/__init__.py2
-rw-r--r--dnn/torch/lpcnet/utils/sparsification/common.py92
-rw-r--r--dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py158
-rw-r--r--dnn/torch/lpcnet/utils/templates.py128
-rw-r--r--dnn/torch/lpcnet/utils/ulaw.py29
-rw-r--r--dnn/torch/lpcnet/utils/wav.py14
16 files changed, 1329 insertions, 0 deletions
diff --git a/dnn/torch/lpcnet/utils/__init__.py b/dnn/torch/lpcnet/utils/__init__.py
new file mode 100644
index 00000000..edbbe02c
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/__init__.py
@@ -0,0 +1,4 @@
+from . import sparsification
+from . import data
+from . import pcm
+from . import sample \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/data.py b/dnn/torch/lpcnet/utils/data.py
new file mode 100644
index 00000000..b8e7c612
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/data.py
@@ -0,0 +1,112 @@
+import os
+
+import torch
+import numpy as np
+
+def load_features(feature_file, version=2):
+ if version == 2:
+ layout = {
+ 'cepstrum': [0,18],
+ 'periods': [18, 19],
+ 'pitch_corr': [19, 20],
+ 'lpc': [20, 36]
+ }
+ frame_length = 36
+
+ elif version == 1:
+ layout = {
+ 'cepstrum': [0,18],
+ 'periods': [36, 37],
+ 'pitch_corr': [37, 38],
+ 'lpc': [39, 55],
+ }
+ frame_length = 55
+ else:
+ raise ValueError(f'unknown feature version: {version}')
+
+
+ raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
+ raw_features = raw_features.reshape((-1, frame_length))
+
+ features = torch.cat(
+ [
+ raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
+ raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
+ ],
+ dim=1
+ )
+
+ lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
+ periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()
+
+ return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}
+
+
+
+def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
+ ref_data = np.memmap(reference_data_path, dtype=np.int16)
+ signal = np.memmap(signal_path, dtype=np.int16)
+
+ signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
+ signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)
+
+
+ assert len(signal) % 160 == 0
+ num_frames = len(signal) // 160
+ mem = np.zeros(1)
+ for fr in range(len(signal)//160):
+ signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid')
+ mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160]
+
+ new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
+
+ new_data[:] = 0
+ N = len(signal) - offset
+ new_data[1 : 2*N + 1: 2] = signal_preemph[offset:]
+ new_data[2 : 2*N + 2: 2] = signal_preemph[offset:]
+
+
+def parse_warpq_scores(output_file):
+ """ extracts warpq scores from output file """
+
+ with open(output_file, "r") as f:
+ lines = f.readlines()
+
+ scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]
+
+ return scores
+
+
+def parse_stats_file(file):
+
+ with open(file, "r") as f:
+ lines = f.readlines()
+
+ mean = float(lines[0].split(":")[-1])
+ bt_mean = float(lines[1].split(":")[-1])
+ top_mean = float(lines[2].split(":")[-1])
+
+ return mean, bt_mean, top_mean
+
+def collect_test_stats(test_folder):
+ """ collects statistics for all discovered metrics from test folder """
+
+ metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
+
+ results = dict()
+
+ content = os.listdir(test_folder)
+
+ stats_files = [file for file in content if file.startswith('stats_')]
+
+ for file in stats_files:
+ metric = file[len("stats_") : -len(".txt")]
+
+ if metric not in metrics:
+ print(f"warning: unknown metric {metric}")
+
+ mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))
+
+ results[metric] = [mean, bt_mean, top_mean]
+
+ return results
diff --git a/dnn/torch/lpcnet/utils/endoscopy.py b/dnn/torch/lpcnet/utils/endoscopy.py
new file mode 100644
index 00000000..05dd4750
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/endoscopy.py
@@ -0,0 +1,205 @@
+""" module for inspecting models during inference """
+
+import os
+
+import yaml
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+
+import torch
+import numpy as np
+
+# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
+_state = dict()
+_folder = 'endoscopy'
+
+def get_gru_gates(gru, input, state):
+ hidden_size = gru.hidden_size
+
+ direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
+ recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())
+
+ # reset gate
+ start, stop = 0 * hidden_size, 1 * hidden_size
+ reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
+
+ # update gate
+ start, stop = 1 * hidden_size, 2 * hidden_size
+ update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop])
+
+ # new gate
+ start, stop = 2 * hidden_size, 3 * hidden_size
+ new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop]))
+
+ return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
+
+
+def init(folder='endoscopy'):
+ """ sets up output folder for endoscopy data """
+
+ global _folder
+ _folder = folder
+
+ if not os.path.exists(folder):
+ os.makedirs(folder)
+ else:
+ print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
+
+def write_data(key, data, fs):
+ """ appends data to previous data written under key """
+
+ global _state
+
+ # convert to numpy if torch.Tensor is given
+ if isinstance(data, torch.Tensor):
+ data = data.detach().numpy()
+
+ if not key in _state:
+ _state[key] = {
+ 'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
+ 'fs' : fs,
+ 'dim' : tuple(data.shape),
+ 'dtype' : str(data.dtype)
+ }
+
+ with open(os.path.join(_folder, key + '.yml'), 'w') as f:
+ f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
+ else:
+ if _state[key]['fs'] != fs:
+ raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
+ if _state[key]['dtype'] != str(data.dtype):
+ raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
+ if _state[key]['dim'] != tuple(data.shape):
+ raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")
+
+ _state[key]['fid'].write(data.tobytes())
+
+def close(folder='endoscopy'):
+ """ clean up """
+ for key in _state.keys():
+ _state[key]['fid'].close()
+
+
+def read_data(folder='endoscopy'):
+ """ retrieves written data as numpy arrays """
+
+
+ keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]
+
+ return_dict = dict()
+
+ for key in keys:
+ with open(os.path.join(folder, key + '.yml'), 'r') as f:
+ value = yaml.load(f.read(), yaml.FullLoader)
+
+ with open(os.path.join(folder, key + '.bin'), 'rb') as f:
+ data = np.frombuffer(f.read(), dtype=value['dtype'])
+
+ value['data'] = data.reshape((-1,) + value['dim'])
+
+ return_dict[key] = value
+
+ return return_dict
+
+def get_best_reshape(shape, target_ratio=1):
+ """ calculated the best 2d reshape of shape given the target ratio (rows/cols)"""
+
+ if len(shape) > 1:
+ pixel_count = 1
+ for s in shape:
+ pixel_count *= s
+ else:
+ pixel_count = shape[0]
+
+ if pixel_count == 1:
+ return (1,)
+
+ num_columns = int((pixel_count / target_ratio)**.5)
+
+ while (pixel_count % num_columns):
+ num_columns -= 1
+
+ num_rows = pixel_count // num_columns
+
+ return (num_rows, num_columns)
+
+def get_type_and_shape(shape):
+
+ # can happen if data is one dimensional
+ if len(shape) == 0:
+ shape = (1,)
+
+ # calculate pixel count
+ if len(shape) > 1:
+ pixel_count = 1
+ for s in shape:
+ pixel_count *= s
+ else:
+ pixel_count = shape[0]
+
+ if pixel_count == 1:
+ return 'plot', (1, )
+
+ # stay with shape if already 2-dimensional
+ if len(shape) == 2:
+ if (shape[0] != pixel_count) or (shape[1] != pixel_count):
+ return 'image', shape
+
+ return 'image', get_best_reshape(shape)
+
+def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
+
+ # determine plot setup
+ num_keys = len(data.keys())
+
+ num_rows = int((num_keys * 3/4) ** .5)
+
+ num_cols = (num_keys + num_rows - 1) // num_rows
+
+ fig, axs = plt.subplots(num_rows, num_cols)
+ fig.set_size_inches(num_cols * 5, num_rows * 5)
+
+ display = dict()
+
+ fs_max = max([val['fs'] for val in data.values()])
+
+ num_samples = max([val['data'].shape[0] for val in data.values()])
+
+ keys = sorted(data.keys())
+
+ # inspect data
+ for i, key in enumerate(keys):
+ axs[i // num_cols, i % num_cols].title.set_text(key)
+
+ display[key] = dict()
+
+ display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
+ display[key]['down_factor'] = data[key]['fs'] / fs_max
+
+ start_index = max(start_index, half_signal_window_length)
+ while stop_index < 0:
+ stop_index += num_samples
+
+ stop_index = min(stop_index, num_samples - half_signal_window_length)
+
+ # actual plotting
+ frames = []
+ for index in range(start_index, stop_index):
+ ims = []
+ for i, key in enumerate(keys):
+ feature_index = int(round(index * display[key]['down_factor']))
+
+ if display[key]['type'] == 'plot':
+ ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])
+
+ elif display[key]['type'] == 'image':
+ ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True))
+
+ frames.append(ims)
+
+ ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
+
+ if not filename.endswith('.mp4'):
+ filename += '.mp4'
+
+ ani.save(filename) \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/layers/__init__.py b/dnn/torch/lpcnet/utils/layers/__init__.py
new file mode 100644
index 00000000..4a58f221
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/__init__.py
@@ -0,0 +1,3 @@
+from .dual_fc import DualFC
+from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner
+from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/layers/dual_fc.py b/dnn/torch/lpcnet/utils/layers/dual_fc.py
new file mode 100644
index 00000000..ed10a5c6
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/dual_fc.py
@@ -0,0 +1,15 @@
+import torch
+from torch import nn
+
+class DualFC(nn.Module):
+ def __init__(self, input_dim, output_dim):
+ super(DualFC, self).__init__()
+
+ self.dense1 = nn.Linear(input_dim, output_dim)
+ self.dense2 = nn.Linear(input_dim, output_dim)
+
+ self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
+ self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True)
+
+ def forward(self, x):
+ return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x))
diff --git a/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py b/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
new file mode 100644
index 00000000..12835f89
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py
@@ -0,0 +1,42 @@
+""" module implementing PCM embeddings for LPCNet """
+
+import math as m
+
+import torch
+from torch import nn
+
+
+class PCMEmbedding(nn.Module):
+ def __init__(self, embed_dim=128, num_levels=256):
+ super(PCMEmbedding, self).__init__()
+
+ self.embed_dim = embed_dim
+ self.num_levels = num_levels
+
+ self.embedding = nn.Embedding(self.num_levels, self.num_dim)
+
+ # initialize
+ with torch.no_grad():
+ num_rows, num_cols = self.num_levels, self.embed_dim
+ a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5)
+ for i in range(num_rows):
+ a[i, :] += m.sqrt(12) * (i - num_rows / 2)
+ self.embedding.weight[:, :] = 0.1 * a
+
+ def forward(self, x):
+ return self.embeddint(x)
+
+
+class DifferentiablePCMEmbedding(PCMEmbedding):
+ def __init__(self, embed_dim, num_levels=256):
+ super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels)
+
+ def forward(self, x):
+ x_int = (x - torch.floor(x)).detach().long()
+ x_frac = x - x_int
+ x_next = torch.minimum(x_int + 1, self.num_levels)
+
+ embed_0 = self.embedding(x_int)
+ embed_1 = self.embedding(x_next)
+
+ return (1 - x_frac) * embed_0 + x_frac * embed_1
diff --git a/dnn/torch/lpcnet/utils/layers/subconditioner.py b/dnn/torch/lpcnet/utils/layers/subconditioner.py
new file mode 100644
index 00000000..87189cd5
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/layers/subconditioner.py
@@ -0,0 +1,468 @@
+from re import sub
+import torch
+from torch import nn
+
+
+
+
+def get_subconditioner( method,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ **kwargs):
+
+ subconditioner_dict = {
+ 'additive' : AdditiveSubconditioner,
+ 'concatenative' : ConcatenativeSubconditioner,
+ 'modulative' : ModulativeSubconditioner
+ }
+
+ return subconditioner_dict[method](number_of_subsamples,
+ pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)
+
+
+class Subconditioner(nn.Module):
+ def __init__(self):
+ """ upsampling by subconditioning
+
+ Upsamples a sequence of states conditioning on pcm signals and
+ optionally a feature vector.
+ """
+ super(Subconditioner, self).__init__()
+
+ def forward(self, states, signals, features=None):
+ raise Exception("Base class should not be called")
+
+ def single_step(self, index, state, signals, features):
+ raise Exception("Base class should not be called")
+
+ def get_output_dim(self, index):
+ raise Exception("Base class should not be called")
+
+
+class AdditiveSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ **kwargs):
+ """ subconditioning by addition """
+
+ super(AdditiveSubconditioner, self).__init__()
+
+ self.number_of_subsamples = number_of_subsamples
+ self.pcm_embedding_size = pcm_embedding_size
+ self.state_size = state_size
+ self.pcm_levels = pcm_levels
+ self.number_of_signals = number_of_signals
+
+ if self.pcm_embedding_size != self.state_size:
+ raise ValueError('For additive subconditioning state and embedding '
+ + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}')
+
+ self.embeddings = [None]
+ for i in range(1, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ def forward(self, states, signals):
+ """ creates list of subconditioned states
+
+ Parameters:
+ -----------
+ states : torch.tensor
+ states of shape (batch, seq_length // s, state_size)
+ signals : torch.tensor
+ signals of shape (batch, seq_length, number_of_signals)
+
+ Returns:
+ --------
+ c_states : list of torch.tensor
+ list of s subconditioned states
+ """
+
+ s = self.number_of_subsamples
+
+ c_states = [states]
+ new_states = states
+ for i in range(1, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.sum(embed, dim=2)
+
+ new_states = new_states + embed
+ c_states.append(new_states)
+
+ return c_states
+
+ def single_step(self, index, state, signals):
+ """ carry out single step for inference
+
+ Parameters:
+ -----------
+ index : int
+ position in subconditioning batch
+
+ state : torch.tensor
+ state to sub-condition
+
+ signals : torch.tensor
+ signals for subconditioning, all but the last dimensions
+ must match those of state
+
+ Returns:
+ c_state : torch.tensor
+ subconditioned state
+ """
+
+ if index == 0:
+ c_state = state
+ else:
+ embed_signals = self.embeddings[index](signals)
+ c = torch.sum(embed_signals, dim=-2)
+ c_state = state + c
+
+ return c_state
+
+ def get_output_dim(self, index):
+ return self.state_size
+
+ def get_average_flops_per_step(self):
+ s = self.number_of_subsamples
+ flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
+ return flops
+
+
+class ConcatenativeSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ recurrent=True,
+ **kwargs):
+ """ subconditioning by concatenation """
+
+ super(ConcatenativeSubconditioner, self).__init__()
+
+ self.number_of_subsamples = number_of_subsamples
+ self.pcm_embedding_size = pcm_embedding_size
+ self.state_size = state_size
+ self.pcm_levels = pcm_levels
+ self.number_of_signals = number_of_signals
+ self.recurrent = recurrent
+
+ self.embeddings = []
+ start_index = 0
+ if self.recurrent:
+ start_index = 1
+ self.embeddings.append(None)
+
+ for i in range(start_index, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ def forward(self, states, signals):
+ """ creates list of subconditioned states
+
+ Parameters:
+ -----------
+ states : torch.tensor
+ states of shape (batch, seq_length // s, state_size)
+ signals : torch.tensor
+ signals of shape (batch, seq_length, number_of_signals)
+
+ Returns:
+ --------
+ c_states : list of torch.tensor
+ list of s subconditioned states
+ """
+ s = self.number_of_subsamples
+
+ if self.recurrent:
+ c_states = [states]
+ start = 1
+ else:
+ c_states = []
+ start = 0
+
+ new_states = states
+ for i in range(start, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.flatten(embed, -2)
+
+ if self.recurrent:
+ new_states = torch.cat((new_states, embed), dim=-1)
+ else:
+ new_states = torch.cat((states, embed), dim=-1)
+
+ c_states.append(new_states)
+
+ return c_states
+
+ def single_step(self, index, state, signals):
+ """ carry out single step for inference
+
+ Parameters:
+ -----------
+ index : int
+ position in subconditioning batch
+
+ state : torch.tensor
+ state to sub-condition
+
+ signals : torch.tensor
+ signals for subconditioning, all but the last dimensions
+ must match those of state
+
+ Returns:
+ c_state : torch.tensor
+ subconditioned state
+ """
+
+ if index == 0 and self.recurrent:
+ c_state = state
+ else:
+ embed_signals = self.embeddings[index](signals)
+ c = torch.flatten(embed_signals, -2)
+ if not self.recurrent and index > 0:
+ # overwrite previous conditioning vector
+ c_state = torch.cat((state[...,:self.state_size], c), dim=-1)
+ else:
+ c_state = torch.cat((state, c), dim=-1)
+ return c_state
+
+ return c_state
+
+ def get_average_flops_per_step(self):
+ return 0
+
+ def get_output_dim(self, index):
+ if self.recurrent:
+ return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
+ else:
+ return self.state_size + self.pcm_embedding_size * self.number_of_signals
+
+class ModulativeSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ state_recurrent=False,
+ **kwargs):
+ """ subconditioning by modulation """
+
+ super(ModulativeSubconditioner, self).__init__()
+
+ self.number_of_subsamples = number_of_subsamples
+ self.pcm_embedding_size = pcm_embedding_size
+ self.state_size = state_size
+ self.pcm_levels = pcm_levels
+ self.number_of_signals = number_of_signals
+ self.state_recurrent = state_recurrent
+
+ self.hidden_size = self.pcm_embedding_size * self.number_of_signals
+
+ if self.state_recurrent:
+ self.hidden_size += self.pcm_embedding_size
+ self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)
+
+ self.embeddings = [None]
+ self.alphas = [None]
+ self.betas = [None]
+
+ for i in range(1, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
+ self.add_module('alpha_dense_' + str(i), self.alphas[-1])
+
+ self.betas.append(nn.Linear(self.hidden_size, self.state_size))
+ self.add_module('beta_dense_' + str(i), self.betas[-1])
+
+
+
+ def forward(self, states, signals):
+ """ creates list of subconditioned states
+
+ Parameters:
+ -----------
+ states : torch.tensor
+ states of shape (batch, seq_length // s, state_size)
+ signals : torch.tensor
+ signals of shape (batch, seq_length, number_of_signals)
+
+ Returns:
+ --------
+ c_states : list of torch.tensor
+ list of s subconditioned states
+ """
+ s = self.number_of_subsamples
+
+ c_states = [states]
+ new_states = states
+ for i in range(1, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.flatten(embed, -2)
+
+ if self.state_recurrent:
+ comp_states = self.state_transform(new_states)
+ embed = torch.cat((embed, comp_states), dim=-1)
+
+ alpha = torch.tanh(self.alphas[i](embed))
+ beta = torch.tanh(self.betas[i](embed))
+
+ # new state obtained by modulating previous state
+ new_states = torch.tanh((1 + alpha) * new_states + beta)
+
+ c_states.append(new_states)
+
+ return c_states
+
+ def single_step(self, index, state, signals):
+ """ carry out single step for inference
+
+ Parameters:
+ -----------
+ index : int
+ position in subconditioning batch
+
+ state : torch.tensor
+ state to sub-condition
+
+ signals : torch.tensor
+ signals for subconditioning, all but the last dimensions
+ must match those of state
+
+ Returns:
+ c_state : torch.tensor
+ subconditioned state
+ """
+
+ if index == 0:
+ c_state = state
+ else:
+ embed_signals = self.embeddings[index](signals)
+ c = torch.flatten(embed_signals, -2)
+ if self.state_recurrent:
+ r_state = self.state_transform(state)
+ c = torch.cat((c, r_state), dim=-1)
+ alpha = torch.tanh(self.alphas[index](c))
+ beta = torch.tanh(self.betas[index](c))
+ c_state = torch.tanh((1 + alpha) * state + beta)
+ return c_state
+
+ return c_state
+
+ def get_output_dim(self, index):
+ return self.state_size
+
+ def get_average_flops_per_step(self):
+ s = self.number_of_subsamples
+
+ # estimate activation by 10 flops
+ # c_state = torch.tanh((1 + alpha) * state + beta)
+ flops = 13 * self.state_size
+
+ # hidden size
+ hidden_size = self.number_of_signals * self.pcm_embedding_size
+ if self.state_recurrent:
+ hidden_size += self.pcm_embedding_size
+
+ # counting 2 * A * B flops for Linear(A, B)
+ # alpha = torch.tanh(self.alphas[index](c))
+ # beta = torch.tanh(self.betas[index](c))
+ flops += 4 * hidden_size * self.state_size + 20 * self.state_size
+
+ # r_state = self.state_transform(state)
+ if self.state_recurrent:
+ flops += 2 * self.state_size * self.pcm_embedding_size
+
+ # average over steps
+ flops *= (s - 1) / s
+
+ return flops
+
+class ComparitiveSubconditioner(Subconditioner):
+ def __init__(self,
+ number_of_subsamples,
+ pcm_embedding_size,
+ state_size,
+ pcm_levels,
+ number_of_signals,
+ error_index=-1,
+ apply_gate=True,
+ normalize=False):
+ """ subconditioning by comparison """
+
+ super(ComparitiveSubconditioner, self).__init__()
+
+ self.comparison_size = self.pcm_embedding_size
+ self.error_position = error_index
+ self.apply_gate = apply_gate
+ self.normalize = normalize
+
+ self.state_transform = nn.Linear(self.state_size, self.comparison_size)
+
+ self.alpha_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
+ self.beta_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)
+
+ if self.apply_gate:
+ self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)
+
+ # embeddings and state transforms
+ self.embeddings = [None]
+ self.alpha_denses = [None]
+ self.beta_denses = [None]
+ self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
+ self.add_module('state_transform_0', self.state_transforms[0])
+
+ for i in range(1, self.number_of_subsamples):
+ embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
+ self.add_module('pcm_embedding_' + str(i), embedding)
+ self.embeddings.append(embedding)
+
+ state_transform = nn.Linear(self.state_size, self.comparison_size)
+ self.add_module('state_transform_' + str(i), state_transform)
+ self.state_transforms.append(state_transform)
+
+ self.alpha_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
+ self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])
+
+ self.beta_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size))
+ self.add_module('beta_dense_' + str(i), self.beta_denses[-1])
+
+ def forward(self, states, signals):
+ s = self.number_of_subsamples
+
+ c_states = [states]
+ new_states = states
+ for i in range(1, self.number_of_subsamples):
+ embed = self.embeddings[i](signals[:, i::s])
+ # reduce signal dimension
+ embed = torch.flatten(embed, -2)
+
+ comp_states = self.state_transforms[i](new_states)
+
+ alpha = torch.tanh(self.alpha_dense(embed))
+ beta = torch.tanh(self.beta_dense(embed))
+
+ # new state obtained by modulating previous state
+ new_states = torch.tanh((1 + alpha) * comp_states + beta)
+
+ c_states.append(new_states)
+
+ return c_states
diff --git a/dnn/torch/lpcnet/utils/misc.py b/dnn/torch/lpcnet/utils/misc.py
new file mode 100644
index 00000000..dab4837f
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/misc.py
@@ -0,0 +1,36 @@
+import torch
+
+
+def find(a, v):
+ try:
+ idx = a.index(v)
+ except:
+ idx = -1
+ return idx
+
+def interleave_tensors(tensors, dim=-2):
+ """ interleave list of tensors along sequence dimension """
+
+ x = torch.cat([x.unsqueeze(dim) for x in tensors], dim=dim)
+ x = torch.flatten(x, dim - 1, dim)
+
+ return x
+
+def _interleave(x, pcm_levels=256):
+
+ repeats = pcm_levels // (2*x.size(-1))
+ x = x.unsqueeze(-1)
+ p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2)
+
+ return p
+
+def get_pdf_from_tree(x):
+ pcm_levels = x.size(-1)
+
+ p = _interleave(x[..., 1:2])
+ n = 4
+ while n <= pcm_levels:
+ p = p * _interleave(x[..., n//2:n])
+ n *= 2
+
+ return p \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/pcm.py b/dnn/torch/lpcnet/utils/pcm.py
new file mode 100644
index 00000000..608e40d7
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/pcm.py
@@ -0,0 +1,6 @@
+
+def clip_to_int16(x):
+ int_min = -2**15
+ int_max = 2**15 - 1
+ x_clipped = max(int_min, min(x, int_max))
+ return x_clipped
diff --git a/dnn/torch/lpcnet/utils/sample.py b/dnn/torch/lpcnet/utils/sample.py
new file mode 100644
index 00000000..14e1cd19
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sample.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def sample_excitation(probs, pitch_corr):
+
+ norm = lambda x : x / (x.sum() + 1e-18)
+
+ # lowering the temperature
+ probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))
+ # cut-off tails
+ probs = norm(torch.maximum(probs - 0.002 , torch.FloatTensor([0])))
+ # sample
+ exc = torch.multinomial(probs.squeeze(), 1)
+
+ return exc
diff --git a/dnn/torch/lpcnet/utils/sparsification/__init__.py b/dnn/torch/lpcnet/utils/sparsification/__init__.py
new file mode 100644
index 00000000..ebfa9d9a
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/__init__.py
@@ -0,0 +1,2 @@
+from .gru_sparsifier import GRUSparsifier
+from .common import sparsify_matrix, calculate_gru_flops_per_step \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py
new file mode 100644
index 00000000..34989d4b
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/common.py
@@ -0,0 +1,92 @@
+import torch
+
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
+ """ sparsifies matrix with specified block size
+
+ Parameters:
+ -----------
+ matrix : torch.tensor
+ matrix to sparsify
+ density : int
+ target density
+ block_size : [int, int]
+ block size dimensions
+ keep_diagonal : bool
+ If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
+ """
+
+ m, n = matrix.shape
+ m1, n1 = block_size
+
+ if m % m1 or n % n1:
+ raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
+
+ # extract diagonal if keep_diagonal = True
+ if keep_diagonal:
+ if m != n:
+ raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
+
+ to_spare = torch.diag(torch.diag(matrix))
+ matrix = matrix - to_spare
+ else:
+ to_spare = torch.zeros_like(matrix)
+
+ # calculate energy in sub-blocks
+ x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
+ x = x ** 2
+ block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
+
+ number_of_blocks = (m * n) // (m1 * n1)
+ number_of_survivors = round(number_of_blocks * density)
+
+ # masking threshold
+ if number_of_survivors == 0:
+ threshold = 0
+ else:
+ threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
+
+ # create mask
+ mask = torch.ones_like(block_energies)
+ mask[block_energies < threshold] = 0
+ mask = torch.repeat_interleave(mask, m1, dim=0)
+ mask = torch.repeat_interleave(mask, n1, dim=1)
+
+ # perform masking
+ masked_matrix = mask * matrix + to_spare
+
+ if return_mask:
+ return masked_matrix, mask
+ else:
+ return masked_matrix
+
+def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
+ input_size = gru.input_size
+ hidden_size = gru.hidden_size
+ flops = 0
+
+ input_density = (
+ sparsification_dict.get('W_ir', [1])[0]
+ + sparsification_dict.get('W_in', [1])[0]
+ + sparsification_dict.get('W_iz', [1])[0]
+ ) / 3
+
+ recurrent_density = (
+ sparsification_dict.get('W_hr', [1])[0]
+ + sparsification_dict.get('W_hn', [1])[0]
+ + sparsification_dict.get('W_hz', [1])[0]
+ ) / 3
+
+ # input matrix vector multiplications
+ if not drop_input:
+ flops += 2 * 3 * input_size * hidden_size * input_density
+
+ # recurrent matrix vector multiplications
+ flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
+
+ # biases
+ flops += 6 * hidden_size
+
+ # activations estimated by 10 flops per activation
+ flops += 30 * hidden_size
+
+ return flops \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
new file mode 100644
index 00000000..865f3a7d
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
@@ -0,0 +1,158 @@
+import torch
+
+from .common import sparsify_matrix
+
+
+class GRUSparsifier:
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+ task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
+ of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
+ 'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
+ update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
+ sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
+ should be kept.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop sparsification will be
+ carried out after every call to GRUSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
+ with density (alpha + target_density * (1 * alpha)), where
+ alpha = ((stop - i) / (start - stop)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> gru = torch.nn.GRU(10, 20)
+ >>> sparsify_dict = {
+ ... 'W_ir' : (0.5, [2, 2], False),
+ ... 'W_iz' : (0.6, [2, 2], False),
+ ... 'W_in' : (0.7, [2, 2], False),
+ ... 'W_hr' : (0.1, [4, 4], True),
+ ... 'W_hz' : (0.2, [4, 4], True),
+ ... 'W_hn' : (0.3, [4, 4], True),
+ ... }
+ >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+ # just copying parameters...
+ self.start = start
+ self.stop = stop
+ self.interval = interval
+ self.exponent = exponent
+ self.task_list = task_list
+
+ # ... and setting counter to 0
+ self.step_counter = 0
+
+ self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
+
+ def step(self, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+ # compute current interpolation factor
+ self.step_counter += 1
+
+ if self.step_counter < self.start:
+ return
+ elif self.step_counter < self.stop:
+ # update only every self.interval-th interval
+ if self.step_counter % self.interval:
+ return
+
+ alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+ else:
+ alpha = 0
+
+
+ with torch.no_grad():
+ for gru, params in self.task_list:
+ hidden_size = gru.hidden_size
+
+ # input weights
+ for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
+ if key in params:
+ density = alpha + (1 - alpha) * params[key][0]
+ if verbose:
+ print(f"[{self.step_counter}]: {key} density: {density}")
+
+ gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
+ gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
+ density, # density
+ params[key][1], # block_size
+ params[key][2], # keep_diagonal (might want to set this to False)
+ return_mask=True
+ )
+
+ if type(self.last_masks[key]) != type(None):
+ if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+ print(f"sparsification mask {key} changed for gru {gru}")
+
+ self.last_masks[key] = new_mask
+
+ # recurrent weights
+ for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
+ if key in params:
+ density = alpha + (1 - alpha) * params[key][0]
+ if verbose:
+ print(f"[{self.step_counter}]: {key} density: {density}")
+ gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
+ gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
+ density,
+ params[key][1], # block_size
+ params[key][2], # keep_diagonal (might want to set this to False)
+ return_mask=True
+ )
+
+ if type(self.last_masks[key]) != type(None):
+ if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+ print(f"sparsification mask {key} changed for gru {gru}")
+
+ self.last_masks[key] = new_mask
+
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ gru = torch.nn.GRU(10, 20)
+ sparsify_dict = {
+ 'W_ir' : (0.5, [2, 2], False),
+ 'W_iz' : (0.6, [2, 2], False),
+ 'W_in' : (0.7, [2, 2], False),
+ 'W_hr' : (0.1, [4, 4], True),
+ 'W_hz' : (0.2, [4, 4], True),
+ 'W_hn' : (0.3, [4, 4], True),
+ }
+
+ sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
diff --git a/dnn/torch/lpcnet/utils/templates.py b/dnn/torch/lpcnet/utils/templates.py
new file mode 100644
index 00000000..d399f57c
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/templates.py
@@ -0,0 +1,128 @@
+from models import multi_rate_lpcnet
+import copy
+
+setup_dict = dict()
+
+dataset_template_v2 = {
+ 'version' : 2,
+ 'feature_file' : 'features.f32',
+ 'signal_file' : 'data.s16',
+ 'frame_length' : 160,
+ 'feature_frame_length' : 36,
+ 'signal_frame_length' : 2,
+ 'feature_dtype' : 'float32',
+ 'signal_dtype' : 'int16',
+ 'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]},
+ 'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction
+}
+
+dataset_template_v1 = {
+ 'version' : 1,
+ 'feature_file' : 'features.f32',
+ 'signal_file' : 'data.u8',
+ 'frame_length' : 160,
+ 'feature_frame_length' : 55,
+ 'signal_frame_length' : 4,
+ 'feature_dtype' : 'float32',
+ 'signal_dtype' : 'uint8',
+ 'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]},
+ 'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction
+}
+
+# lpcnet
+
+lpcnet_config = {
+ 'frame_size' : 160,
+ 'gru_a_units' : 384,
+ 'gru_b_units' : 64,
+ 'feature_conditioning_dim' : 128,
+ 'feature_conv_kernel_size' : 3,
+ 'period_levels' : 257,
+ 'period_embedding_dim' : 64,
+ 'signal_embedding_dim' : 128,
+ 'signal_levels' : 256,
+ 'feature_dimension' : 19,
+ 'output_levels' : 256,
+ 'lpc_gamma' : 0.9,
+ 'features' : ['cepstrum', 'periods', 'pitch_corr'],
+ 'signals' : ['last_signal', 'prediction', 'last_error'],
+ 'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2},
+ 'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} },
+ 'target' : 'error',
+ 'feature_history' : 2,
+ 'feature_lookahead' : 2,
+ 'sparsification' : {
+ 'gru_a' : {
+ 'start' : 10000,
+ 'stop' : 30000,
+ 'interval' : 100,
+ 'exponent' : 3,
+ 'params' : {
+ 'W_hr' : (0.05, [4, 8], True),
+ 'W_hz' : (0.05, [4, 8], True),
+ 'W_hn' : (0.2, [4, 8], True)
+ },
+ },
+ 'gru_b' : {
+ 'start' : 10000,
+ 'stop' : 30000,
+ 'interval' : 100,
+ 'exponent' : 3,
+ 'params' : {
+ 'W_ir' : (0.5, [4, 8], False),
+ 'W_iz' : (0.5, [4, 8], False),
+ 'W_in' : (0.5, [4, 8], False)
+ },
+ }
+ },
+ 'add_reference_phase' : False,
+ 'reference_phase_dim' : 0
+}
+
+
+
+# multi rate
+subconditioning = {
+ 'subconditioning_a' : {
+ 'number_of_subsamples' : 2,
+ 'method' : 'modulative',
+ 'signals' : ['last_signal', 'prediction', 'last_error'],
+ 'pcm_embedding_size' : 64,
+ 'kwargs' : dict()
+
+ },
+ 'subconditioning_b' : {
+ 'number_of_subsamples' : 2,
+ 'method' : 'modulative',
+ 'signals' : ['last_signal', 'prediction', 'last_error'],
+ 'pcm_embedding_size' : 64,
+ 'kwargs' : dict()
+ }
+}
+
+multi_rate_lpcnet_config = lpcnet_config.copy()
+multi_rate_lpcnet_config['subconditioning'] = subconditioning
+
+training_default = {
+ 'batch_size' : 256,
+ 'epochs' : 20,
+ 'lr' : 1e-3,
+ 'lr_decay_factor' : 2.5e-5,
+ 'adam_betas' : [0.9, 0.99],
+ 'frames_per_sample' : 15
+}
+
+lpcnet_setup = {
+ 'dataset' : '/local/datasets/lpcnet_training',
+ 'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'},
+ 'training' : training_default
+}
+
+multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup)
+multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config
+multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate'
+
+setup_dict = {
+ 'lpcnet' : lpcnet_setup,
+ 'multi_rate' : multi_rate_lpcnet_setup
+}
diff --git a/dnn/torch/lpcnet/utils/ulaw.py b/dnn/torch/lpcnet/utils/ulaw.py
new file mode 100644
index 00000000..1a9f9e47
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/ulaw.py
@@ -0,0 +1,29 @@
+import math as m
+
+import torch
+
+
+
+def ulaw2lin(u):
+ scale_1 = 32768.0 / 255.0
+ u = u - 128
+ s = torch.sign(u)
+ u = torch.abs(u)
+ return s * scale_1 * (torch.exp(u / 128. * m.log(256)) - 1)
+
+
+def lin2ulawq(x):
+ scale = 255.0 / 32768.0
+ s = torch.sign(x)
+ x = torch.abs(x)
+ u = s * (128 * torch.log(1 + scale * x) / m.log(256))
+ u = torch.clip(128 + torch.round(u), 0, 255)
+ return u
+
+def lin2ulaw(x):
+ scale = 255.0 / 32768.0
+ s = torch.sign(x)
+ x = torch.abs(x)
+ u = s * (128 * torch.log(1 + scale * x) / torch.log(256))
+ u = torch.clip(128 + u, 0, 255)
+ return u \ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/wav.py b/dnn/torch/lpcnet/utils/wav.py
new file mode 100644
index 00000000..3ed811f5
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/wav.py
@@ -0,0 +1,14 @@
+import wave
+
+def wavwrite16(filename, x, fs):
+ """ writes x as int16 to file with name filename
+
+ If x.dtype is int16 x is written as is. Otherwise,
+ it is scaled by 2**15 - 1 and converted to int16.
+ """
+ if x.dtype != 'int16':
+ x = ((2**15 - 1) * x).astype('int16')
+
+ with wave.open(filename, 'wb') as f:
+ f.setparams((1, 2, fs, len(x), 'NONE', ""))
+ f.writeframes(x.tobytes()) \ No newline at end of file