From 35ee397e060283d30c098ae5e17836316bbec08b Mon Sep 17 00:00:00 2001 From: Jan Buethe Date: Tue, 5 Sep 2023 12:29:38 +0200 Subject: added LPCNet torch implementation Signed-off-by: Jan Buethe --- dnn/torch/lpcnet/utils/__init__.py | 4 + dnn/torch/lpcnet/utils/data.py | 112 +++++ dnn/torch/lpcnet/utils/endoscopy.py | 205 +++++++++ dnn/torch/lpcnet/utils/layers/__init__.py | 3 + dnn/torch/lpcnet/utils/layers/dual_fc.py | 15 + dnn/torch/lpcnet/utils/layers/pcm_embeddings.py | 42 ++ dnn/torch/lpcnet/utils/layers/subconditioner.py | 468 +++++++++++++++++++++ dnn/torch/lpcnet/utils/misc.py | 36 ++ dnn/torch/lpcnet/utils/pcm.py | 6 + dnn/torch/lpcnet/utils/sample.py | 15 + dnn/torch/lpcnet/utils/sparsification/__init__.py | 2 + dnn/torch/lpcnet/utils/sparsification/common.py | 92 ++++ .../lpcnet/utils/sparsification/gru_sparsifier.py | 158 +++++++ dnn/torch/lpcnet/utils/templates.py | 128 ++++++ dnn/torch/lpcnet/utils/ulaw.py | 29 ++ dnn/torch/lpcnet/utils/wav.py | 14 + 16 files changed, 1329 insertions(+) create mode 100644 dnn/torch/lpcnet/utils/__init__.py create mode 100644 dnn/torch/lpcnet/utils/data.py create mode 100644 dnn/torch/lpcnet/utils/endoscopy.py create mode 100644 dnn/torch/lpcnet/utils/layers/__init__.py create mode 100644 dnn/torch/lpcnet/utils/layers/dual_fc.py create mode 100644 dnn/torch/lpcnet/utils/layers/pcm_embeddings.py create mode 100644 dnn/torch/lpcnet/utils/layers/subconditioner.py create mode 100644 dnn/torch/lpcnet/utils/misc.py create mode 100644 dnn/torch/lpcnet/utils/pcm.py create mode 100644 dnn/torch/lpcnet/utils/sample.py create mode 100644 dnn/torch/lpcnet/utils/sparsification/__init__.py create mode 100644 dnn/torch/lpcnet/utils/sparsification/common.py create mode 100644 dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py create mode 100644 dnn/torch/lpcnet/utils/templates.py create mode 100644 dnn/torch/lpcnet/utils/ulaw.py create mode 100644 dnn/torch/lpcnet/utils/wav.py (limited to 'dnn/torch/lpcnet/utils') diff --git a/dnn/torch/lpcnet/utils/__init__.py b/dnn/torch/lpcnet/utils/__init__.py new file mode 100644 index 00000000..edbbe02c --- /dev/null +++ b/dnn/torch/lpcnet/utils/__init__.py @@ -0,0 +1,4 @@ +from . import sparsification +from . import data +from . import pcm +from . 
import sample \ No newline at end of file diff --git a/dnn/torch/lpcnet/utils/data.py b/dnn/torch/lpcnet/utils/data.py new file mode 100644 index 00000000..b8e7c612 --- /dev/null +++ b/dnn/torch/lpcnet/utils/data.py @@ -0,0 +1,112 @@ +import os + +import torch +import numpy as np + +def load_features(feature_file, version=2): + if version == 2: + layout = { + 'cepstrum': [0,18], + 'periods': [18, 19], + 'pitch_corr': [19, 20], + 'lpc': [20, 36] + } + frame_length = 36 + + elif version == 1: + layout = { + 'cepstrum': [0,18], + 'periods': [36, 37], + 'pitch_corr': [37, 38], + 'lpc': [39, 55], + } + frame_length = 55 + else: + raise ValueError(f'unknown feature version: {version}') + + + raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32')) + raw_features = raw_features.reshape((-1, frame_length)) + + features = torch.cat( + [ + raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]], + raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]] + ], + dim=1 + ) + + lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]] + periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long() + + return {'features' : features, 'periods' : periods, 'lpcs' : lpcs} + + + +def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85): + ref_data = np.memmap(reference_data_path, dtype=np.int16) + signal = np.memmap(signal_path, dtype=np.int16) + + signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw' + signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape) + + + assert len(signal) % 160 == 0 + num_frames = len(signal) // 160 + mem = np.zeros(1) + for fr in range(len(signal)//160): + signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid') + mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160] + + new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape) + + new_data[:] = 0 + N = len(signal) - offset + new_data[1 : 2*N + 1: 2] = signal_preemph[offset:] + new_data[2 : 2*N + 2: 2] = signal_preemph[offset:] + + +def parse_warpq_scores(output_file): + """ extracts warpq scores from output file """ + + with open(output_file, "r") as f: + lines = f.readlines() + + scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")] + + return scores + + +def parse_stats_file(file): + + with open(file, "r") as f: + lines = f.readlines() + + mean = float(lines[0].split(":")[-1]) + bt_mean = float(lines[1].split(":")[-1]) + top_mean = float(lines[2].split(":")[-1]) + + return mean, bt_mean, top_mean + +def collect_test_stats(test_folder): + """ collects statistics for all discovered metrics from test folder """ + + metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'} + + results = dict() + + content = os.listdir(test_folder) + + stats_files = [file for file in content if file.startswith('stats_')] + + for file in stats_files: + metric = file[len("stats_") : -len(".txt")] + + if metric not in metrics: + print(f"warning: unknown metric {metric}") + + mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file)) + + results[metric] = [mean, bt_mean, top_mean] + + return results diff --git a/dnn/torch/lpcnet/utils/endoscopy.py b/dnn/torch/lpcnet/utils/endoscopy.py new file mode 100644 index 00000000..05dd4750 --- /dev/null +++ b/dnn/torch/lpcnet/utils/endoscopy.py @@ -0,0 
+1,205 @@ +""" module for inspecting models during inference """ + +import os + +import yaml +import matplotlib.pyplot as plt +import matplotlib.animation as animation + +import torch +import numpy as np + +# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}} +_state = dict() +_folder = 'endoscopy' + +def get_gru_gates(gru, input, state): + hidden_size = gru.hidden_size + + direct = torch.matmul(gru.weight_ih_l0, input.squeeze()) + recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze()) + + # reset gate + start, stop = 0 * hidden_size, 1 * hidden_size + reset_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop]) + + # update gate + start, stop = 1 * hidden_size, 2 * hidden_size + update_gate = torch.sigmoid(direct[start : stop] + gru.bias_ih_l0[start : stop] + recurrent[start : stop] + gru.bias_hh_l0[start : stop]) + + # new gate + start, stop = 2 * hidden_size, 3 * hidden_size + new_gate = torch.tanh(direct[start : stop] + gru.bias_ih_l0[start : stop] + reset_gate * (recurrent[start : stop] + gru.bias_hh_l0[start : stop])) + + return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate} + + +def init(folder='endoscopy'): + """ sets up output folder for endoscopy data """ + + global _folder + _folder = folder + + if not os.path.exists(folder): + os.makedirs(folder) + else: + print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.") + +def write_data(key, data, fs): + """ appends data to previous data written under key """ + + global _state + + # convert to numpy if torch.Tensor is given + if isinstance(data, torch.Tensor): + data = data.detach().numpy() + + if not key in _state: + _state[key] = { + 'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'), + 'fs' : fs, + 'dim' : tuple(data.shape), + 'dtype' : str(data.dtype) + } + + with open(os.path.join(_folder, key + '.yml'), 'w') as f: + f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]})) + else: + if _state[key]['fs'] != fs: + raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}") + if _state[key]['dtype'] != str(data.dtype): + raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}") + if _state[key]['dim'] != tuple(data.shape): + raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. 
{tuple(data.shape)}") + + _state[key]['fid'].write(data.tobytes()) + +def close(folder='endoscopy'): + """ clean up """ + for key in _state.keys(): + _state[key]['fid'].close() + + +def read_data(folder='endoscopy'): + """ retrieves written data as numpy arrays """ + + + keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')] + + return_dict = dict() + + for key in keys: + with open(os.path.join(folder, key + '.yml'), 'r') as f: + value = yaml.load(f.read(), yaml.FullLoader) + + with open(os.path.join(folder, key + '.bin'), 'rb') as f: + data = np.frombuffer(f.read(), dtype=value['dtype']) + + value['data'] = data.reshape((-1,) + value['dim']) + + return_dict[key] = value + + return return_dict + +def get_best_reshape(shape, target_ratio=1): + """ calculated the best 2d reshape of shape given the target ratio (rows/cols)""" + + if len(shape) > 1: + pixel_count = 1 + for s in shape: + pixel_count *= s + else: + pixel_count = shape[0] + + if pixel_count == 1: + return (1,) + + num_columns = int((pixel_count / target_ratio)**.5) + + while (pixel_count % num_columns): + num_columns -= 1 + + num_rows = pixel_count // num_columns + + return (num_rows, num_columns) + +def get_type_and_shape(shape): + + # can happen if data is one dimensional + if len(shape) == 0: + shape = (1,) + + # calculate pixel count + if len(shape) > 1: + pixel_count = 1 + for s in shape: + pixel_count *= s + else: + pixel_count = shape[0] + + if pixel_count == 1: + return 'plot', (1, ) + + # stay with shape if already 2-dimensional + if len(shape) == 2: + if (shape[0] != pixel_count) or (shape[1] != pixel_count): + return 'image', shape + + return 'image', get_best_reshape(shape) + +def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80): + + # determine plot setup + num_keys = len(data.keys()) + + num_rows = int((num_keys * 3/4) ** .5) + + num_cols = (num_keys + num_rows - 1) // num_rows + + fig, axs = plt.subplots(num_rows, num_cols) + fig.set_size_inches(num_cols * 5, num_rows * 5) + + display = dict() + + fs_max = max([val['fs'] for val in data.values()]) + + num_samples = max([val['data'].shape[0] for val in data.values()]) + + keys = sorted(data.keys()) + + # inspect data + for i, key in enumerate(keys): + axs[i // num_cols, i % num_cols].title.set_text(key) + + display[key] = dict() + + display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim']) + display[key]['down_factor'] = data[key]['fs'] / fs_max + + start_index = max(start_index, half_signal_window_length) + while stop_index < 0: + stop_index += num_samples + + stop_index = min(stop_index, num_samples - half_signal_window_length) + + # actual plotting + frames = [] + for index in range(start_index, stop_index): + ims = [] + for i, key in enumerate(keys): + feature_index = int(round(index * display[key]['down_factor'])) + + if display[key]['type'] == 'plot': + ims.append(axs[i // num_cols, i % num_cols].plot(data[key]['data'][index - half_signal_window_length : index + half_signal_window_length], marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0]) + + elif display[key]['type'] == 'image': + ims.append(axs[i // num_cols, i % num_cols].imshow(data[key]['data'][index].reshape(display[key]['shape']), animated=True)) + + frames.append(ims) + + ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000) + + if not filename.endswith('.mp4'): + filename += '.mp4' + + ani.save(filename) \ No newline at end of 
file diff --git a/dnn/torch/lpcnet/utils/layers/__init__.py b/dnn/torch/lpcnet/utils/layers/__init__.py new file mode 100644 index 00000000..4a58f221 --- /dev/null +++ b/dnn/torch/lpcnet/utils/layers/__init__.py @@ -0,0 +1,3 @@ +from .dual_fc import DualFC +from .subconditioner import AdditiveSubconditioner, ModulativeSubconditioner, ConcatenativeSubconditioner +from .pcm_embeddings import PCMEmbedding, DifferentiablePCMEmbedding \ No newline at end of file diff --git a/dnn/torch/lpcnet/utils/layers/dual_fc.py b/dnn/torch/lpcnet/utils/layers/dual_fc.py new file mode 100644 index 00000000..ed10a5c6 --- /dev/null +++ b/dnn/torch/lpcnet/utils/layers/dual_fc.py @@ -0,0 +1,15 @@ +import torch +from torch import nn + +class DualFC(nn.Module): + def __init__(self, input_dim, output_dim): + super(DualFC, self).__init__() + + self.dense1 = nn.Linear(input_dim, output_dim) + self.dense2 = nn.Linear(input_dim, output_dim) + + self.alpha = nn.Parameter(torch.tensor([0.5]), requires_grad=True) + self.beta = nn.Parameter(torch.tensor([0.5]), requires_grad=True) + + def forward(self, x): + return self.alpha * torch.tanh(self.dense1(x)) + self.beta * torch.tanh(self.dense2(x)) diff --git a/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py b/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py new file mode 100644 index 00000000..12835f89 --- /dev/null +++ b/dnn/torch/lpcnet/utils/layers/pcm_embeddings.py @@ -0,0 +1,42 @@ +""" module implementing PCM embeddings for LPCNet """ + +import math as m + +import torch +from torch import nn + + +class PCMEmbedding(nn.Module): + def __init__(self, embed_dim=128, num_levels=256): + super(PCMEmbedding, self).__init__() + + self.embed_dim = embed_dim + self.num_levels = num_levels + + self.embedding = nn.Embedding(self.num_levels, self.embed_dim) + + # initialize + with torch.no_grad(): + num_rows, num_cols = self.num_levels, self.embed_dim + a = m.sqrt(12) * (torch.rand(num_rows, num_cols) - 0.5) + for i in range(num_rows): + a[i, :] += m.sqrt(12) * (i - num_rows / 2) + self.embedding.weight[:, :] = 0.1 * a + + def forward(self, x): + return self.embedding(x) + + +class DifferentiablePCMEmbedding(PCMEmbedding): + def __init__(self, embed_dim, num_levels=256): + super(DifferentiablePCMEmbedding, self).__init__(embed_dim, num_levels) + + def forward(self, x): + # linear interpolation between the embeddings of the two adjacent levels + x_int = torch.floor(x).detach().long() + x_frac = x - x_int + x_next = torch.clamp(x_int + 1, max=self.num_levels - 1) + + embed_0 = self.embedding(x_int) + embed_1 = self.embedding(x_next) + + return (1 - x_frac) * embed_0 + x_frac * embed_1 diff --git a/dnn/torch/lpcnet/utils/layers/subconditioner.py b/dnn/torch/lpcnet/utils/layers/subconditioner.py new file mode 100644 index 00000000..87189cd5 --- /dev/null +++ b/dnn/torch/lpcnet/utils/layers/subconditioner.py @@ -0,0 +1,468 @@ +import torch +from torch import nn + + + + +def get_subconditioner( method, + number_of_subsamples, + pcm_embedding_size, + state_size, + pcm_levels, + number_of_signals, + **kwargs): + + subconditioner_dict = { + 'additive' : AdditiveSubconditioner, + 'concatenative' : ConcatenativeSubconditioner, + 'modulative' : ModulativeSubconditioner + } + + return subconditioner_dict[method](number_of_subsamples, + pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs) + + +class Subconditioner(nn.Module): + def __init__(self): + """ upsampling by subconditioning + + Upsamples a sequence of states conditioning on pcm signals and + optionally a feature vector. 
+ """ + super(Subconditioner, self).__init__() + + def forward(self, states, signals, features=None): + raise Exception("Base class should not be called") + + def single_step(self, index, state, signals, features): + raise Exception("Base class should not be called") + + def get_output_dim(self, index): + raise Exception("Base class should not be called") + + +class AdditiveSubconditioner(Subconditioner): + def __init__(self, + number_of_subsamples, + pcm_embedding_size, + state_size, + pcm_levels, + number_of_signals, + **kwargs): + """ subconditioning by addition """ + + super(AdditiveSubconditioner, self).__init__() + + self.number_of_subsamples = number_of_subsamples + self.pcm_embedding_size = pcm_embedding_size + self.state_size = state_size + self.pcm_levels = pcm_levels + self.number_of_signals = number_of_signals + + if self.pcm_embedding_size != self.state_size: + raise ValueError('For additive subconditioning state and embedding ' + + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}') + + self.embeddings = [None] + for i in range(1, self.number_of_subsamples): + embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) + self.add_module('pcm_embedding_' + str(i), embedding) + self.embeddings.append(embedding) + + def forward(self, states, signals): + """ creates list of subconditioned states + + Parameters: + ----------- + states : torch.tensor + states of shape (batch, seq_length // s, state_size) + signals : torch.tensor + signals of shape (batch, seq_length, number_of_signals) + + Returns: + -------- + c_states : list of torch.tensor + list of s subconditioned states + """ + + s = self.number_of_subsamples + + c_states = [states] + new_states = states + for i in range(1, self.number_of_subsamples): + embed = self.embeddings[i](signals[:, i::s]) + # reduce signal dimension + embed = torch.sum(embed, dim=2) + + new_states = new_states + embed + c_states.append(new_states) + + return c_states + + def single_step(self, index, state, signals): + """ carry out single step for inference + + Parameters: + ----------- + index : int + position in subconditioning batch + + state : torch.tensor + state to sub-condition + + signals : torch.tensor + signals for subconditioning, all but the last dimensions + must match those of state + + Returns: + c_state : torch.tensor + subconditioned state + """ + + if index == 0: + c_state = state + else: + embed_signals = self.embeddings[index](signals) + c = torch.sum(embed_signals, dim=-2) + c_state = state + c + + return c_state + + def get_output_dim(self, index): + return self.state_size + + def get_average_flops_per_step(self): + s = self.number_of_subsamples + flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size + return flops + + +class ConcatenativeSubconditioner(Subconditioner): + def __init__(self, + number_of_subsamples, + pcm_embedding_size, + state_size, + pcm_levels, + number_of_signals, + recurrent=True, + **kwargs): + """ subconditioning by concatenation """ + + super(ConcatenativeSubconditioner, self).__init__() + + self.number_of_subsamples = number_of_subsamples + self.pcm_embedding_size = pcm_embedding_size + self.state_size = state_size + self.pcm_levels = pcm_levels + self.number_of_signals = number_of_signals + self.recurrent = recurrent + + self.embeddings = [] + start_index = 0 + if self.recurrent: + start_index = 1 + self.embeddings.append(None) + + for i in range(start_index, self.number_of_subsamples): + embedding = nn.Embedding(self.pcm_levels, 
self.pcm_embedding_size) + self.add_module('pcm_embedding_' + str(i), embedding) + self.embeddings.append(embedding) + + def forward(self, states, signals): + """ creates list of subconditioned states + + Parameters: + ----------- + states : torch.tensor + states of shape (batch, seq_length // s, state_size) + signals : torch.tensor + signals of shape (batch, seq_length, number_of_signals) + + Returns: + -------- + c_states : list of torch.tensor + list of s subconditioned states + """ + s = self.number_of_subsamples + + if self.recurrent: + c_states = [states] + start = 1 + else: + c_states = [] + start = 0 + + new_states = states + for i in range(start, self.number_of_subsamples): + embed = self.embeddings[i](signals[:, i::s]) + # reduce signal dimension + embed = torch.flatten(embed, -2) + + if self.recurrent: + new_states = torch.cat((new_states, embed), dim=-1) + else: + new_states = torch.cat((states, embed), dim=-1) + + c_states.append(new_states) + + return c_states + + def single_step(self, index, state, signals): + """ carry out single step for inference + + Parameters: + ----------- + index : int + position in subconditioning batch + + state : torch.tensor + state to sub-condition + + signals : torch.tensor + signals for subconditioning, all but the last dimensions + must match those of state + + Returns: + c_state : torch.tensor + subconditioned state + """ + + if index == 0 and self.recurrent: + c_state = state + else: + embed_signals = self.embeddings[index](signals) + c = torch.flatten(embed_signals, -2) + if not self.recurrent and index > 0: + # overwrite previous conditioning vector + c_state = torch.cat((state[...,:self.state_size], c), dim=-1) + else: + c_state = torch.cat((state, c), dim=-1) + return c_state + + return c_state + + def get_average_flops_per_step(self): + return 0 + + def get_output_dim(self, index): + if self.recurrent: + return self.state_size + index * self.pcm_embedding_size * self.number_of_signals + else: + return self.state_size + self.pcm_embedding_size * self.number_of_signals + +class ModulativeSubconditioner(Subconditioner): + def __init__(self, + number_of_subsamples, + pcm_embedding_size, + state_size, + pcm_levels, + number_of_signals, + state_recurrent=False, + **kwargs): + """ subconditioning by modulation """ + + super(ModulativeSubconditioner, self).__init__() + + self.number_of_subsamples = number_of_subsamples + self.pcm_embedding_size = pcm_embedding_size + self.state_size = state_size + self.pcm_levels = pcm_levels + self.number_of_signals = number_of_signals + self.state_recurrent = state_recurrent + + self.hidden_size = self.pcm_embedding_size * self.number_of_signals + + if self.state_recurrent: + self.hidden_size += self.pcm_embedding_size + self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size) + + self.embeddings = [None] + self.alphas = [None] + self.betas = [None] + + for i in range(1, self.number_of_subsamples): + embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) + self.add_module('pcm_embedding_' + str(i), embedding) + self.embeddings.append(embedding) + + self.alphas.append(nn.Linear(self.hidden_size, self.state_size)) + self.add_module('alpha_dense_' + str(i), self.alphas[-1]) + + self.betas.append(nn.Linear(self.hidden_size, self.state_size)) + self.add_module('beta_dense_' + str(i), self.betas[-1]) + + + + def forward(self, states, signals): + """ creates list of subconditioned states + + Parameters: + ----------- + states : torch.tensor + states of shape (batch, seq_length // s, 
state_size) + signals : torch.tensor + signals of shape (batch, seq_length, number_of_signals) + + Returns: + -------- + c_states : list of torch.tensor + list of s subconditioned states + """ + s = self.number_of_subsamples + + c_states = [states] + new_states = states + for i in range(1, self.number_of_subsamples): + embed = self.embeddings[i](signals[:, i::s]) + # reduce signal dimension + embed = torch.flatten(embed, -2) + + if self.state_recurrent: + comp_states = self.state_transform(new_states) + embed = torch.cat((embed, comp_states), dim=-1) + + alpha = torch.tanh(self.alphas[i](embed)) + beta = torch.tanh(self.betas[i](embed)) + + # new state obtained by modulating previous state + new_states = torch.tanh((1 + alpha) * new_states + beta) + + c_states.append(new_states) + + return c_states + + def single_step(self, index, state, signals): + """ carry out single step for inference + + Parameters: + ----------- + index : int + position in subconditioning batch + + state : torch.tensor + state to sub-condition + + signals : torch.tensor + signals for subconditioning, all but the last dimensions + must match those of state + + Returns: + c_state : torch.tensor + subconditioned state + """ + + if index == 0: + c_state = state + else: + embed_signals = self.embeddings[index](signals) + c = torch.flatten(embed_signals, -2) + if self.state_recurrent: + r_state = self.state_transform(state) + c = torch.cat((c, r_state), dim=-1) + alpha = torch.tanh(self.alphas[index](c)) + beta = torch.tanh(self.betas[index](c)) + c_state = torch.tanh((1 + alpha) * state + beta) + return c_state + + return c_state + + def get_output_dim(self, index): + return self.state_size + + def get_average_flops_per_step(self): + s = self.number_of_subsamples + + # estimate activation by 10 flops + # c_state = torch.tanh((1 + alpha) * state + beta) + flops = 13 * self.state_size + + # hidden size + hidden_size = self.number_of_signals * self.pcm_embedding_size + if self.state_recurrent: + hidden_size += self.pcm_embedding_size + + # counting 2 * A * B flops for Linear(A, B) + # alpha = torch.tanh(self.alphas[index](c)) + # beta = torch.tanh(self.betas[index](c)) + flops += 4 * hidden_size * self.state_size + 20 * self.state_size + + # r_state = self.state_transform(state) + if self.state_recurrent: + flops += 2 * self.state_size * self.pcm_embedding_size + + # average over steps + flops *= (s - 1) / s + + return flops + +class ComparitiveSubconditioner(Subconditioner): + def __init__(self, + number_of_subsamples, + pcm_embedding_size, + state_size, + pcm_levels, + number_of_signals, + error_index=-1, + apply_gate=True, + normalize=False): + """ subconditioning by comparison """ + + super(ComparitiveSubconditioner, self).__init__() + + self.number_of_subsamples = number_of_subsamples + self.pcm_embedding_size = pcm_embedding_size + self.state_size = state_size + self.pcm_levels = pcm_levels + self.number_of_signals = number_of_signals + + self.comparison_size = self.pcm_embedding_size + self.error_position = error_index + self.apply_gate = apply_gate + self.normalize = normalize + + self.state_transform = nn.Linear(self.state_size, self.comparison_size) + + self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size) + self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size) + + if self.apply_gate: + self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size) + + # embeddings and state transforms + self.embeddings = [None] + self.alpha_denses = [None] + self.beta_denses = [None] + self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)] + self.add_module('state_transform_0', self.state_transforms[0]) + + for i in range(1, self.number_of_subsamples): + embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) + self.add_module('pcm_embedding_' + str(i), embedding) + self.embeddings.append(embedding) + + state_transform = nn.Linear(self.state_size, self.comparison_size) + self.add_module('state_transform_' + str(i), state_transform) + self.state_transforms.append(state_transform) + + self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)) + self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1]) + + self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)) + self.add_module('beta_dense_' + str(i), self.beta_denses[-1]) + + def forward(self, states, signals): + s = self.number_of_subsamples + + c_states = [states] + new_states = states + for i in range(1, self.number_of_subsamples): + embed = self.embeddings[i](signals[:, i::s]) + # reduce signal dimension + embed = torch.flatten(embed, -2) + + comp_states = self.state_transforms[i](new_states) + + alpha = torch.tanh(self.alpha_dense(embed)) + beta = torch.tanh(self.beta_dense(embed)) + + # new state obtained by modulating previous state + new_states = torch.tanh((1 + alpha) * comp_states + beta) + + c_states.append(new_states) + + return c_states diff --git a/dnn/torch/lpcnet/utils/misc.py b/dnn/torch/lpcnet/utils/misc.py new file mode 100644 index 00000000..dab4837f --- /dev/null +++ b/dnn/torch/lpcnet/utils/misc.py @@ -0,0 +1,36 @@ +import torch + + +def find(a, v): + try: + idx = a.index(v) + except ValueError: + idx = -1 + return idx + +def interleave_tensors(tensors, dim=-2): + """ interleave list of tensors along sequence dimension """ + + x = torch.cat([x.unsqueeze(dim) for x in tensors], dim=dim) + x = torch.flatten(x, dim - 1, dim) + + return x + +def _interleave(x, pcm_levels=256): + + repeats = pcm_levels // (2*x.size(-1)) + x = x.unsqueeze(-1) + p = torch.flatten(torch.repeat_interleave(torch.cat((x, 1 - x), dim=-1), repeats, dim=-1), -2) + + return p + +def get_pdf_from_tree(x): + pcm_levels = x.size(-1) + + p = _interleave(x[..., 1:2]) + n = 4 + while n <= pcm_levels: + p = p * _interleave(x[..., n//2:n]) + n *= 2 + + return p \ No newline at end of file diff --git a/dnn/torch/lpcnet/utils/pcm.py b/dnn/torch/lpcnet/utils/pcm.py new file mode 100644 index 00000000..608e40d7 --- /dev/null +++ b/dnn/torch/lpcnet/utils/pcm.py @@ -0,0 +1,6 @@ + +def clip_to_int16(x): + int_min = -2**15 + int_max = 2**15 - 1 + x_clipped = max(int_min, min(x, int_max)) + return x_clipped diff --git a/dnn/torch/lpcnet/utils/sample.py b/dnn/torch/lpcnet/utils/sample.py new file mode 100644 index 00000000..14e1cd19 --- /dev/null +++ b/dnn/torch/lpcnet/utils/sample.py @@ -0,0 +1,15 @@ +import torch + + +def sample_excitation(probs, pitch_corr): + + norm = lambda x : x / (x.sum() + 1e-18) + + # lowering the temperature + probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5))) + # cut-off tails + probs = norm(torch.maximum(probs - 0.002, torch.FloatTensor([0]))) + # sample + exc = torch.multinomial(probs.squeeze(), 1) + + return exc diff --git a/dnn/torch/lpcnet/utils/sparsification/__init__.py b/dnn/torch/lpcnet/utils/sparsification/__init__.py new file mode 100644 index 00000000..ebfa9d9a --- /dev/null +++ b/dnn/torch/lpcnet/utils/sparsification/__init__.py @@ -0,0 +1,2 @@ +from .gru_sparsifier import GRUSparsifier +from .common import sparsify_matrix, calculate_gru_flops_per_step \ No newline at end of file diff --git 
a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py new file mode 100644 index 00000000..34989d4b --- /dev/null +++ b/dnn/torch/lpcnet/utils/sparsification/common.py @@ -0,0 +1,92 @@ +import torch + +def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False): + """ sparsifies matrix with specified block size + + Parameters: + ----------- + matrix : torch.tensor + matrix to sparsify + density : float + target density in [0, 1] + block_size : [int, int] + block size dimensions + keep_diagonal : bool + If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False + return_mask : bool + If true, the binary mask is returned as a second return value + """ + + m, n = matrix.shape + m1, n1 = block_size + + if m % m1 or n % n1: + raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}") + + # extract diagonal if keep_diagonal = True + if keep_diagonal: + if m != n: + raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True") + + to_spare = torch.diag(torch.diag(matrix)) + matrix = matrix - to_spare + else: + to_spare = torch.zeros_like(matrix) + + # calculate energy in sub-blocks + x = torch.reshape(matrix, (m // m1, m1, n // n1, n1)) + x = x ** 2 + block_energies = torch.sum(torch.sum(x, dim=3), dim=1) + + number_of_blocks = (m * n) // (m1 * n1) + number_of_survivors = round(number_of_blocks * density) + + # masking threshold + if number_of_survivors == 0: + threshold = 0 + else: + threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors] + + # create mask + mask = torch.ones_like(block_energies) + mask[block_energies < threshold] = 0 + mask = torch.repeat_interleave(mask, m1, dim=0) + mask = torch.repeat_interleave(mask, n1, dim=1) + + # perform masking + masked_matrix = mask * matrix + to_spare + + if return_mask: + return masked_matrix, mask + else: + return masked_matrix + +def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False): + input_size = gru.input_size + hidden_size = gru.hidden_size + flops = 0 + + input_density = ( + sparsification_dict.get('W_ir', [1])[0] + + sparsification_dict.get('W_in', [1])[0] + + sparsification_dict.get('W_iz', [1])[0] + ) / 3 + + recurrent_density = ( + sparsification_dict.get('W_hr', [1])[0] + + sparsification_dict.get('W_hn', [1])[0] + + sparsification_dict.get('W_hz', [1])[0] + ) / 3 + + # input matrix vector multiplications + if not drop_input: + flops += 2 * 3 * input_size * hidden_size * input_density + + # recurrent matrix vector multiplications + flops += 2 * 3 * hidden_size * hidden_size * recurrent_density + + # biases + flops += 6 * hidden_size + + # activations estimated by 10 flops per activation + flops += 30 * hidden_size + + return flops \ No newline at end of file diff --git a/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py new file mode 100644 index 00000000..865f3a7d --- /dev/null +++ b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py @@ -0,0 +1,158 @@ +import torch + +from .common import sparsify_matrix + + +class GRUSparsifier: + def __init__(self, task_list, start, stop, interval, exponent=3): + """ Sparsifier for torch.nn.GRUs + + Parameters: + ----------- + task_list : list + task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance + of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in', + 'W_hr', 'W_hz', 'W_hn'} corresponding 
to the input and recurrent weights for the reset, + update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal), + where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks to which + sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal + should be kept. + + start : int + training step after which sparsification will be started. + + stop : int + training step after which sparsification will be completed. + + interval : int + sparsification interval for steps between start and stop. After stop sparsification will be + carried out after every call to GRUSparsifier.step() + + exponent : float + interpolation exponent for the sparsification schedule. In step i sparsification will be carried out + with density (alpha + target_density * (1 - alpha)), where + alpha = ((stop - i) / (stop - start)) ** exponent + + Example: + -------- + >>> import torch + >>> gru = torch.nn.GRU(10, 20) + >>> sparsify_dict = { + ... 'W_ir' : (0.5, [2, 2], False), + ... 'W_iz' : (0.6, [2, 2], False), + ... 'W_in' : (0.7, [2, 2], False), + ... 'W_hr' : (0.1, [4, 4], True), + ... 'W_hz' : (0.2, [4, 4], True), + ... 'W_hn' : (0.3, [4, 4], True), + ... } + >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50) + >>> for i in range(100): + ... sparsifier.step() + """ + # just copying parameters... + self.start = start + self.stop = stop + self.interval = interval + self.exponent = exponent + self.task_list = task_list + + # ... and setting counter to 0 + self.step_counter = 0 + + self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']} + + def step(self, verbose=False): + """ carries out sparsification step + + Call this function after optimizer.step in your + training loop. 
+ + Parameters: + ---------- + verbose : bool + if true, densities are printed out + + Returns: + -------- + None + + """ + # compute current interpolation factor + self.step_counter += 1 + + if self.step_counter < self.start: + return + elif self.step_counter < self.stop: + # sparsify only every self.interval-th step + if self.step_counter % self.interval: + return + + alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent + else: + alpha = 0 + + + with torch.no_grad(): + for gru, params in self.task_list: + hidden_size = gru.hidden_size + + # input weights + for i, key in enumerate(['W_ir', 'W_iz', 'W_in']): + if key in params: + density = alpha + (1 - alpha) * params[key][0] + if verbose: + print(f"[{self.step_counter}]: {key} density: {density}") + + gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( + gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ], + density, # density + params[key][1], # block_size + params[key][2], # keep_diagonal (might want to set this to False) + return_mask=True + ) + + if self.last_masks[key] is not None: + if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop: + print(f"sparsification mask {key} changed for gru {gru}") + + self.last_masks[key] = new_mask + + # recurrent weights + for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']): + if key in params: + density = alpha + (1 - alpha) * params[key][0] + if verbose: + print(f"[{self.step_counter}]: {key} density: {density}") + gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( + gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ], + density, + params[key][1], # block_size + params[key][2], # keep_diagonal (might want to set this to False) + return_mask=True + ) + + if self.last_masks[key] is not None: + if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop: + print(f"sparsification mask {key} changed for gru {gru}") + + self.last_masks[key] = new_mask + + + +if __name__ == "__main__": + print("Testing sparsifier") + + gru = torch.nn.GRU(10, 20) + sparsify_dict = { + 'W_ir' : (0.5, [2, 2], False), + 'W_iz' : (0.6, [2, 2], False), + 'W_in' : (0.7, [2, 2], False), + 'W_hr' : (0.1, [4, 4], True), + 'W_hz' : (0.2, [4, 4], True), + 'W_hn' : (0.3, [4, 4], True), + } + + sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10) + + for i in range(100): + sparsifier.step(verbose=True) diff --git a/dnn/torch/lpcnet/utils/templates.py b/dnn/torch/lpcnet/utils/templates.py new file mode 100644 index 00000000..d399f57c --- /dev/null +++ b/dnn/torch/lpcnet/utils/templates.py @@ -0,0 +1,128 @@ +from models import multi_rate_lpcnet +import copy + +setup_dict = dict() + +dataset_template_v2 = { + 'version' : 2, + 'feature_file' : 'features.f32', + 'signal_file' : 'data.s16', + 'frame_length' : 160, + 'feature_frame_length' : 36, + 'signal_frame_length' : 2, + 'feature_dtype' : 'float32', + 'signal_dtype' : 'int16', + 'feature_frame_layout' : {'cepstrum': [0,18], 'periods': [18, 19], 'pitch_corr': [19, 20], 'lpc': [20, 36]}, + 'signal_frame_layout' : {'last_signal' : 0, 'signal': 1} # signal, last_signal, error, prediction +} + +dataset_template_v1 = { + 'version' : 1, + 'feature_file' : 'features.f32', + 'signal_file' : 'data.u8', + 'frame_length' : 160, + 'feature_frame_length' : 55, + 'signal_frame_length' : 4, + 'feature_dtype' : 'float32', + 'signal_dtype' : 'uint8', + 'feature_frame_layout' : {'cepstrum': [0,18], 
'periods': [36, 37], 'pitch_corr': [37, 38], 'lpc': [39, 55]}, + 'signal_frame_layout' : {'last_signal' : 0, 'prediction' : 1, 'last_error': 2, 'error': 3} # signal, last_signal, error, prediction +} + +# lpcnet + +lpcnet_config = { + 'frame_size' : 160, + 'gru_a_units' : 384, + 'gru_b_units' : 64, + 'feature_conditioning_dim' : 128, + 'feature_conv_kernel_size' : 3, + 'period_levels' : 257, + 'period_embedding_dim' : 64, + 'signal_embedding_dim' : 128, + 'signal_levels' : 256, + 'feature_dimension' : 19, + 'output_levels' : 256, + 'lpc_gamma' : 0.9, + 'features' : ['cepstrum', 'periods', 'pitch_corr'], + 'signals' : ['last_signal', 'prediction', 'last_error'], + 'input_layout' : { 'signals' : {'last_signal' : 0, 'prediction' : 1, 'last_error' : 2}, + 'features' : {'cepstrum' : [0, 18], 'pitch_corr' : [18, 19]} }, + 'target' : 'error', + 'feature_history' : 2, + 'feature_lookahead' : 2, + 'sparsification' : { + 'gru_a' : { + 'start' : 10000, + 'stop' : 30000, + 'interval' : 100, + 'exponent' : 3, + 'params' : { + 'W_hr' : (0.05, [4, 8], True), + 'W_hz' : (0.05, [4, 8], True), + 'W_hn' : (0.2, [4, 8], True) + }, + }, + 'gru_b' : { + 'start' : 10000, + 'stop' : 30000, + 'interval' : 100, + 'exponent' : 3, + 'params' : { + 'W_ir' : (0.5, [4, 8], False), + 'W_iz' : (0.5, [4, 8], False), + 'W_in' : (0.5, [4, 8], False) + }, + } + }, + 'add_reference_phase' : False, + 'reference_phase_dim' : 0 +} + + + +# multi rate +subconditioning = { + 'subconditioning_a' : { + 'number_of_subsamples' : 2, + 'method' : 'modulative', + 'signals' : ['last_signal', 'prediction', 'last_error'], + 'pcm_embedding_size' : 64, + 'kwargs' : dict() + + }, + 'subconditioning_b' : { + 'number_of_subsamples' : 2, + 'method' : 'modulative', + 'signals' : ['last_signal', 'prediction', 'last_error'], + 'pcm_embedding_size' : 64, + 'kwargs' : dict() + } +} + +multi_rate_lpcnet_config = lpcnet_config.copy() +multi_rate_lpcnet_config['subconditioning'] = subconditioning + +training_default = { + 'batch_size' : 256, + 'epochs' : 20, + 'lr' : 1e-3, + 'lr_decay_factor' : 2.5e-5, + 'adam_betas' : [0.9, 0.99], + 'frames_per_sample' : 15 +} + +lpcnet_setup = { + 'dataset' : '/local/datasets/lpcnet_training', + 'lpcnet' : {'config' : lpcnet_config, 'model': 'lpcnet'}, + 'training' : training_default +} + +multi_rate_lpcnet_setup = copy.deepcopy(lpcnet_setup) +multi_rate_lpcnet_setup['lpcnet']['config'] = multi_rate_lpcnet_config +multi_rate_lpcnet_setup['lpcnet']['model'] = 'multi_rate' + +setup_dict = { + 'lpcnet' : lpcnet_setup, + 'multi_rate' : multi_rate_lpcnet_setup +} diff --git a/dnn/torch/lpcnet/utils/ulaw.py b/dnn/torch/lpcnet/utils/ulaw.py new file mode 100644 index 00000000..1a9f9e47 --- /dev/null +++ b/dnn/torch/lpcnet/utils/ulaw.py @@ -0,0 +1,29 @@ +import math as m + +import torch + + + +def ulaw2lin(u): + scale_1 = 32768.0 / 255.0 + u = u - 128 + s = torch.sign(u) + u = torch.abs(u) + return s * scale_1 * (torch.exp(u / 128. 
* m.log(256)) - 1) + + +def lin2ulawq(x): + scale = 255.0 / 32768.0 + s = torch.sign(x) + x = torch.abs(x) + u = s * (128 * torch.log(1 + scale * x) / m.log(256)) + u = torch.clip(128 + torch.round(u), 0, 255) + return u + +def lin2ulaw(x): + scale = 255.0 / 32768.0 + s = torch.sign(x) + x = torch.abs(x) + u = s * (128 * torch.log(1 + scale * x) / m.log(256)) + u = torch.clip(128 + u, 0, 255) + return u \ No newline at end of file diff --git a/dnn/torch/lpcnet/utils/wav.py b/dnn/torch/lpcnet/utils/wav.py new file mode 100644 index 00000000..3ed811f5 --- /dev/null +++ b/dnn/torch/lpcnet/utils/wav.py @@ -0,0 +1,14 @@ +import wave + +def wavwrite16(filename, x, fs): + """ writes x as int16 to file with name filename + + If x.dtype is int16 x is written as is. Otherwise, + it is scaled by 2**15 - 1 and converted to int16. + """ + if x.dtype != 'int16': + x = ((2**15 - 1) * x).astype('int16') + + with wave.open(filename, 'wb') as f: + f.setparams((1, 2, fs, len(x), 'NONE', "")) + f.writeframes(x.tobytes()) \ No newline at end of file -- cgit v1.2.3
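
The sketches below illustrate how the utilities in this patch fit together. They assume the package root dnn/torch/lpcnet is the working directory (so that utils is importable); all file names, sizes, and dummy inputs are illustrative, not part of the commit.

First, the data path: load_features() slices a raw float32 feature dump into cepstrum/pitch-correlation features, integer periods, and LPC coefficients, and sample_excitation() draws one excitation sample from an output distribution, lowering the temperature and cutting off tails as pitch correlation grows. The file name features.f32 and the random distribution are placeholders.

import torch

from utils.data import load_features
from utils.sample import sample_excitation

# version-2 layout: 36 floats per frame (18 cepstral coeffs, period,
# pitch correlation, 16 LPCs); 'features.f32' is a hypothetical dump file
data = load_features('features.f32', version=2)
features = data['features']   # (num_frames, 19): cepstrum + pitch_corr
periods = data['periods']     # (num_frames, 1): integer pitch periods
lpcs = data['lpcs']           # (num_frames, 16): LPC coefficients

# stand-in for the network's 256-level output distribution
probs = torch.softmax(torch.randn(256), dim=0)

# high pitch correlation -> sharper distribution before sampling
exc = sample_excitation(probs, pitch_corr=0.9)   # tensor with one sampled level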
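
A sketch of the endoscopy workflow: init() opens an output folder, write_data() appends tensors under a key (creating a .bin/.yml pair per key), and after close() the recordings can be reloaded and rendered. The toy GRU, folder name, and step count are assumptions; rendering the mp4 additionally requires ffmpeg for matplotlib's animation writer.

import torch

from utils import endoscopy

endoscopy.init('endoscopy_demo')

gru = torch.nn.GRU(16, 32)
state = torch.zeros(32)
for _ in range(300):
    x = torch.randn(16)
    # reconstruct reset/update/new gate activations for this step
    gates = endoscopy.get_gru_gates(gru, x, state)
    endoscopy.write_data('reset_gate', gates['reset_gate'], fs=16000)
    endoscopy.write_data('update_gate', gates['update_gate'], fs=16000)
    _, h = gru(x.view(1, 1, -1), state.view(1, 1, -1))
    state = h.squeeze()

endoscopy.close()

# reload all recorded keys and render them side by side
data = endoscopy.read_data('endoscopy_demo')
endoscopy.make_animation(data, 'gates.mp4')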
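
A sketch of how a subconditioner upsamples GRU states, following the shape conventions in the forward() docstrings: with s subsample positions, forward() maps states of shape (batch, seq_length // s, state_size) plus integer signals of shape (batch, seq_length, number_of_signals) to a list of s conditioned state sequences. All sizes below are made up.

import torch

from utils.layers.subconditioner import get_subconditioner

s = 2  # number_of_subsamples
sub = get_subconditioner('modulative',
                         number_of_subsamples=s,
                         pcm_embedding_size=64,
                         state_size=128,
                         pcm_levels=256,
                         number_of_signals=2)

batch, seq_length = 3, 32
states = torch.randn(batch, seq_length // s, 128)         # coarse GRU states
signals = torch.randint(0, 256, (batch, seq_length, 2))   # e.g. last_signal, prediction
c_states = sub(states, signals)
print(len(c_states), c_states[1].shape)  # 2 tensors of shape (3, 16, 128)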
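
interleave_tensors() from utils/misc.py is the counterpart for reassembling such per-position outputs at the full rate; a small shape check with arbitrary dimensions:

import torch

from utils.misc import interleave_tensors

a = torch.randn(3, 16, 128)    # outputs for even positions
b = torch.randn(3, 16, 128)    # outputs for odd positions
x = interleave_tensors([a, b])  # default dim=-2 interleaves along time
print(x.shape)                  # (3, 32, 128): a[:,0], b[:,0], a[:,1], ...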
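
A sketch of direct use of the sparsification helpers: sparsify_matrix() keeps the highest-energy sub-blocks at the requested density (restoring the diagonal when asked), and calculate_gru_flops_per_step() estimates the per-step cost under the resulting densities. The 64x64 size, 50% density, and 4x4 blocks are arbitrary.

import torch

from utils.sparsification import sparsify_matrix, calculate_gru_flops_per_step

w = torch.randn(64, 64)
w_sparse, mask = sparsify_matrix(w, density=0.5, block_size=[4, 4],
                                 keep_diagonal=True, return_mask=True)
print(mask.mean())   # ~0.5: about half of the 4x4 blocks survive

gru = torch.nn.GRU(128, 256)
# per-matrix densities in the same (density, ...) tuple format GRUSparsifier uses
densities = {'W_hr': (0.1,), 'W_hz': (0.1,), 'W_hn': (0.2,)}
print(calculate_gru_flops_per_step(gru, densities))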
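
A round-trip sketch for the mu-law helpers: lin2ulawq() compands a linear signal in the +-2^15 range to integer levels 0..255, ulaw2lin() expands back, and the reconstruction error stays within half of the (amplitude-dependent) mu-law step size.

import torch

from utils.ulaw import lin2ulawq, ulaw2lin

x = 32767.0 * (2 * torch.rand(10) - 1)   # random linear samples
u = lin2ulawq(x)                         # integer mu-law levels in [0, 255]
y = ulaw2lin(u)                          # back to the linear range
print(torch.max(torch.abs(x - y)))       # largest at high amplitudes (coarse steps)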
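
Finally, a sketch for wavwrite16(), which accepts either int16 samples or floats in [-1, 1] (scaled internally); the 440 Hz tone stands in for synthesizer output.

import numpy as np

from utils.wav import wavwrite16

fs = 16000
t = np.arange(fs) / fs                  # one second of samples
x = 0.5 * np.sin(2 * np.pi * 440 * t)   # float64 in [-1, 1]
wavwrite16('tone.wav', x, fs)           # mono, 16 bit, 16 kHz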