diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-09-05 13:29:38 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-09-05 13:29:38 +0300 |
commit | 35ee397e060283d30c098ae5e17836316bbec08b (patch) | |
tree | 4a81b86f8c0738bbdc7147214c53fda54cd0f3f3 /dnn/torch/lpcnet/models | |
parent | 90a171c1c2c9839b561f8446ad2bbfe48eacf255 (diff) |
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Diffstat (limited to 'dnn/torch/lpcnet/models')
-rw-r--r-- | dnn/torch/lpcnet/models/__init__.py | 8 | ||||
-rw-r--r-- | dnn/torch/lpcnet/models/lpcnet.py | 274 | ||||
-rw-r--r-- | dnn/torch/lpcnet/models/multi_rate_lpcnet.py | 408 |
3 files changed, 690 insertions, 0 deletions
diff --git a/dnn/torch/lpcnet/models/__init__.py b/dnn/torch/lpcnet/models/__init__.py new file mode 100644 index 00000000..a26bc1cd --- /dev/null +++ b/dnn/torch/lpcnet/models/__init__.py @@ -0,0 +1,8 @@ +from .lpcnet import LPCNet +from .multi_rate_lpcnet import MultiRateLPCNet + + +model_dict = { + 'lpcnet' : LPCNet, + 'multi_rate' : MultiRateLPCNet +}
\ No newline at end of file diff --git a/dnn/torch/lpcnet/models/lpcnet.py b/dnn/torch/lpcnet/models/lpcnet.py new file mode 100644 index 00000000..e20ae68d --- /dev/null +++ b/dnn/torch/lpcnet/models/lpcnet.py @@ -0,0 +1,274 @@ +import torch +from torch import nn +import numpy as np + +from utils.ulaw import lin2ulawq, ulaw2lin +from utils.sample import sample_excitation +from utils.pcm import clip_to_int16 +from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step +from utils.layers import DualFC +from utils.misc import get_pdf_from_tree + + +class LPCNet(nn.Module): + def __init__(self, config): + super(LPCNet, self).__init__() + + # + self.input_layout = config['input_layout'] + self.feature_history = config['feature_history'] + self.feature_lookahead = config['feature_lookahead'] + + # frame rate network parameters + self.feature_dimension = config['feature_dimension'] + self.period_embedding_dim = config['period_embedding_dim'] + self.period_levels = config['period_levels'] + self.feature_channels = self.feature_dimension + self.period_embedding_dim + self.feature_conditioning_dim = config['feature_conditioning_dim'] + self.feature_conv_kernel_size = config['feature_conv_kernel_size'] + + + # frame rate network layers + self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim) + self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid') + self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid') + self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim) + self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim])) + + # sample rate network parameters + self.frame_size = config['frame_size'] + self.signal_levels = config['signal_levels'] + self.signal_embedding_dim = config['signal_embedding_dim'] + self.gru_a_units = config['gru_a_units'] + self.gru_b_units = config['gru_b_units'] + self.output_levels = config['output_levels'] + self.hsampling = config.get('hsampling', False) + + self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim + self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim + + # sample rate network layers + self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim) + self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True) + self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True) + self.dual_fc = DualFC(self.gru_b_units, self.output_levels) + + # sparsification + self.sparsifier = [] + + # GRU A + if 'gru_a' in config['sparsification']: + gru_config = config['sparsification']['gru_a'] + task_list = [(self.gru_a, gru_config['params'])] + self.sparsifier.append(GRUSparsifier(task_list, + gru_config['start'], + gru_config['stop'], + gru_config['interval'], + gru_config['exponent']) + ) + self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, + gru_config['params'], drop_input=True) + else: + self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True) + + # GRU B + if 'gru_b' in config['sparsification']: + gru_config = config['sparsification']['gru_b'] + task_list = [(self.gru_b, gru_config['params'])] + self.sparsifier.append(GRUSparsifier(task_list, + gru_config['start'], + gru_config['stop'], + gru_config['interval'], + gru_config['exponent']) + ) + self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b, + gru_config['params']) + else: + self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b) + + # inference parameters + self.lpc_gamma = config.get('lpc_gamma', 1) + + def sparsify(self): + for sparsifier in self.sparsifier: + sparsifier.step() + + def get_gflops(self, fs, verbose=False): + gflops = 0 + + # frame rate network + conditioning_dim = self.feature_conditioning_dim + feature_channels = self.feature_channels + frame_rate = fs / self.frame_size + frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate + if verbose: + print(f"frame rate network: {frame_rate_network_complexity} GFLOPS") + gflops += frame_rate_network_complexity + + # gru a + gru_a_rate = fs + gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step + if verbose: + print(f"gru A: {gru_a_complexity} GFLOPS") + gflops += gru_a_complexity + + # gru b + gru_b_rate = fs + gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step + if verbose: + print(f"gru B: {gru_b_complexity} GFLOPS") + gflops += gru_b_complexity + + + # dual fcs + fc = self.dual_fc + rate = fs + input_size = fc.dense1.in_features + output_size = fc.dense1.out_features + dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate + if self.hsampling: + dual_fc_complexity /= 8 + if verbose: + print(f"dual_fc: {dual_fc_complexity} GFLOPS") + gflops += dual_fc_complexity + + if verbose: + print(f'total: {gflops} GFLOPS') + + return gflops + + def frame_rate_network(self, features, periods): + + embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3) + features = torch.concat((features, embedded_periods), dim=-1) + + # convert to channels first and calculate conditioning vector + c = torch.permute(features, [0, 2, 1]) + + c = torch.tanh(self.feature_conv1(c)) + c = torch.tanh(self.feature_conv2(c)) + # back to channels last + c = torch.permute(c, [0, 2, 1]) + c = torch.tanh(self.feature_dense1(c)) + c = torch.tanh(self.feature_dense2(c)) + + return c + + def sample_rate_network(self, signals, c, gru_states): + embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3) + c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1) + + y = torch.concat((embedded_signals, c_upsampled), dim=-1) + y, gru_a_state = self.gru_a(y, gru_states[0]) + y = torch.concat((y, c_upsampled), dim=-1) + y, gru_b_state = self.gru_b(y, gru_states[1]) + + y = self.dual_fc(y) + + if self.hsampling: + y = torch.sigmoid(y) + log_probs = torch.log(get_pdf_from_tree(y) + 1e-6) + else: + log_probs = torch.log_softmax(y, dim=-1) + + return log_probs, (gru_a_state, gru_b_state) + + def decoder(self, signals, c, gru_states): + embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3) + + y = torch.concat((embedded_signals, c), dim=-1) + y, gru_a_state = self.gru_a(y, gru_states[0]) + y = torch.concat((y, c), dim=-1) + y, gru_b_state = self.gru_b(y, gru_states[1]) + + y = self.dual_fc(y) + + if self.hsampling: + y = torch.sigmoid(y) + probs = get_pdf_from_tree(y) + else: + probs = torch.softmax(y, dim=-1) + + return probs, (gru_a_state, gru_b_state) + + def forward(self, features, periods, signals, gru_states): + + c = self.frame_rate_network(features, periods) + log_probs, _ = self.sample_rate_network(signals, c, gru_states) + + return log_probs + + def generate(self, features, periods, lpcs): + + with torch.no_grad(): + device = self.parameters().__next__().device + + num_frames = features.shape[0] - self.feature_history - self.feature_lookahead + lpc_order = lpcs.shape[-1] + num_input_signals = len(self.input_layout['signals']) + pitch_corr_position = self.input_layout['features']['pitch_corr'][0] + + # signal buffers + pcm = torch.zeros((num_frames * self.frame_size + lpc_order)) + output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16) + mem = 0 + + # state buffers + gru_a_state = torch.zeros((1, 1, self.gru_a_units)) + gru_b_state = torch.zeros((1, 1, self.gru_b_units)) + gru_states = [gru_a_state, gru_b_state] + + input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128 + + # push data to device + features = features.to(device) + periods = periods.to(device) + lpcs = lpcs.to(device) + + # lpc weighting + weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device) + lpcs = lpcs * weights + + # run feature encoding + c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0)) + + for frame_index in range(num_frames): + frame_start = frame_index * self.frame_size + pitch_corr = features[frame_index + self.feature_history, pitch_corr_position] + a = - torch.flip(lpcs[frame_index + self.feature_history], [0]) + current_c = c[:, frame_index : frame_index + 1, :] + + for i in range(self.frame_size): + pcm_position = frame_start + i + lpc_order + output_position = frame_start + i + + # prepare input + pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a) + if 'prediction' in self.input_layout['signals']: + input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred) + + # run single step of sample rate network + probs, gru_states = self.decoder( + input_signals, + current_c, + gru_states + ) + + # sample from output + exc_ulaw = sample_excitation(probs, pitch_corr) + + # signal generation + exc = ulaw2lin(exc_ulaw) + sig = exc + pred + pcm[pcm_position] = sig + mem = 0.85 * mem + float(sig) + output[output_position] = clip_to_int16(round(mem)) + + # buffer update + if 'last_signal' in self.input_layout['signals']: + input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig) + + if 'last_error' in self.input_layout['signals']: + input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc) + + return output diff --git a/dnn/torch/lpcnet/models/multi_rate_lpcnet.py b/dnn/torch/lpcnet/models/multi_rate_lpcnet.py new file mode 100644 index 00000000..c6850101 --- /dev/null +++ b/dnn/torch/lpcnet/models/multi_rate_lpcnet.py @@ -0,0 +1,408 @@ +import torch +from torch import nn +from utils.layers.subconditioner import get_subconditioner +from utils.layers import DualFC + +from utils.ulaw import lin2ulawq, ulaw2lin +from utils.sample import sample_excitation +from utils.pcm import clip_to_int16 +from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step + +from utils.misc import interleave_tensors + + + + +# MultiRateLPCNet +class MultiRateLPCNet(nn.Module): + def __init__(self, config): + super(MultiRateLPCNet, self).__init__() + + # general parameters + self.input_layout = config['input_layout'] + self.feature_history = config['feature_history'] + self.feature_lookahead = config['feature_lookahead'] + self.signals = config['signals'] + + # frame rate network parameters + self.feature_dimension = config['feature_dimension'] + self.period_embedding_dim = config['period_embedding_dim'] + self.period_levels = config['period_levels'] + self.feature_channels = self.feature_dimension + self.period_embedding_dim + self.feature_conditioning_dim = config['feature_conditioning_dim'] + self.feature_conv_kernel_size = config['feature_conv_kernel_size'] + + # frame rate network layers + self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim) + self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid') + self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid') + self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim) + self.feature_dense2 = nn.Linear(*(2*[self.feature_conditioning_dim])) + + # sample rate network parameters + self.frame_size = config['frame_size'] + self.signal_levels = config['signal_levels'] + self.signal_embedding_dim = config['signal_embedding_dim'] + self.gru_a_units = config['gru_a_units'] + self.gru_b_units = config['gru_b_units'] + self.output_levels = config['output_levels'] + + # subconditioning B + sub_config = config['subconditioning']['subconditioning_b'] + self.substeps_b = sub_config['number_of_subsamples'] + self.subcondition_signals_b = sub_config['signals'] + self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']] + method = sub_config['method'] + kwargs = sub_config['kwargs'] + if type(kwargs) == type(None): + kwargs = dict() + + state_size = self.gru_b_units + self.subconditioner_b = get_subconditioner(method, + sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'], + state_size, self.signal_levels, len(sub_config['signals']), + **sub_config['kwargs']) + + # subconditioning A + sub_config = config['subconditioning']['subconditioning_a'] + self.substeps_a = sub_config['number_of_subsamples'] + self.subcondition_signals_a = sub_config['signals'] + self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']] + method = sub_config['method'] + kwargs = sub_config['kwargs'] + if type(kwargs) == type(None): + kwargs = dict() + + state_size = self.gru_a_units + self.subconditioner_a = get_subconditioner(method, + sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'], + state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']), + **sub_config['kwargs']) + + + # wrap up subconditioning, group_size_gru_a holds the number + # of timesteps that are grouped as sample input for GRU A + # input and group_size_subcondition_a holds the number of samples that are + # grouped as input to pre-GRU B subconditioning + self.group_size_gru_a = self.substeps_a * self.substeps_b + self.group_size_subcondition_a = self.substeps_b + self.gru_a_rate_divider = self.group_size_gru_a + self.gru_b_rate_divider = self.substeps_b + + # gru sizes + self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim + self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim + self.signals_idx = [self.input_layout['signals'][key] for key in self.signals] + + # sample rate network layers + self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim) + self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True) + self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True) + + # sparsification + self.sparsifier = [] + + # GRU A + if 'gru_a' in config['sparsification']: + gru_config = config['sparsification']['gru_a'] + task_list = [(self.gru_a, gru_config['params'])] + self.sparsifier.append(GRUSparsifier(task_list, + gru_config['start'], + gru_config['stop'], + gru_config['interval'], + gru_config['exponent']) + ) + self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, + gru_config['params'], drop_input=True) + else: + self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True) + + # GRU B + if 'gru_b' in config['sparsification']: + gru_config = config['sparsification']['gru_b'] + task_list = [(self.gru_b, gru_config['params'])] + self.sparsifier.append(GRUSparsifier(task_list, + gru_config['start'], + gru_config['stop'], + gru_config['interval'], + gru_config['exponent']) + ) + self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b, + gru_config['params']) + else: + self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b) + + + + # dual FCs + self.dual_fc = [] + for i in range(self.substeps_b): + dim = self.subconditioner_b.get_output_dim(i) + self.dual_fc.append(DualFC(dim, self.output_levels)) + self.add_module(f"dual_fc_{i}", self.dual_fc[-1]) + + def get_gflops(self, fs, verbose=False, hierarchical_sampling=False): + gflops = 0 + + # frame rate network + conditioning_dim = self.feature_conditioning_dim + feature_channels = self.feature_channels + frame_rate = fs / self.frame_size + frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate + if verbose: + print(f"frame rate network: {frame_rate_network_complexity} GFLOPS") + gflops += frame_rate_network_complexity + + # gru a + gru_a_rate = fs / self.group_size_gru_a + gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step + if verbose: + print(f"gru A: {gru_a_complexity} GFLOPS") + gflops += gru_a_complexity + + # subconditioning a + subcond_a_rate = fs / self.substeps_b + subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate + if verbose: + print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS") + gflops += subconditioning_a_complexity + + # gru b + gru_b_rate = fs / self.substeps_b + gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step + if verbose: + print(f"gru B: {gru_b_complexity} GFLOPS") + gflops += gru_b_complexity + + # subconditioning b + subcond_b_rate = fs + subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate + if verbose: + print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS") + gflops += subconditioning_b_complexity + + # dual fcs + for i, fc in enumerate(self.dual_fc): + rate = fs / len(self.dual_fc) + input_size = fc.dense1.in_features + output_size = fc.dense1.out_features + dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate + if hierarchical_sampling: + dual_fc_complexity /= 8 + if verbose: + print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS") + gflops += dual_fc_complexity + + if verbose: + print(f'total: {gflops} GFLOPS') + + return gflops + + + + def sparsify(self): + for sparsifier in self.sparsifier: + sparsifier.step() + + def frame_rate_network(self, features, periods): + + embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3) + features = torch.concat((features, embedded_periods), dim=-1) + + # convert to channels first and calculate conditioning vector + c = torch.permute(features, [0, 2, 1]) + + c = torch.tanh(self.feature_conv1(c)) + c = torch.tanh(self.feature_conv2(c)) + # back to channels last + c = torch.permute(c, [0, 2, 1]) + c = torch.tanh(self.feature_dense1(c)) + c = torch.tanh(self.feature_dense2(c)) + + return c + + def prepare_signals(self, signals, group_size, signal_idx): + """ extracts, delays and groups signals """ + + batch_size, sequence_length, num_signals = signals.shape + + # extract signals according to position + signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx], + dim=-1) + + # roll back pcm to account for grouping + signals = torch.roll(signals, group_size - 1, -2) + + # reshape + signals = torch.reshape(signals, + (batch_size, sequence_length // group_size, group_size * len(signal_idx))) + + return signals + + + def sample_rate_network(self, signals, c, gru_states): + + signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx) + embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3) + # features at GRU A rate + c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1) + # features at GRU B rate + c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1) + + y = torch.concat((embedded_signals, c_upsampled_a), dim=-1) + y, gru_a_state = self.gru_a(y, gru_states[0]) + # first round of upsampling and subconditioning + c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a) + y = self.subconditioner_a(y, c_signals_a) + y = interleave_tensors(y) + + y = torch.concat((y, c_upsampled_b), dim=-1) + y, gru_b_state = self.gru_b(y, gru_states[1]) + c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b) + y = self.subconditioner_b(y, c_signals_b) + + y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)] + y = interleave_tensors(y) + + return y, (gru_a_state, gru_b_state) + + def decoder(self, signals, c, gru_states): + embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3) + + y = torch.concat((embedded_signals, c), dim=-1) + y, gru_a_state = self.gru_a(y, gru_states[0]) + y = torch.concat((y, c), dim=-1) + y, gru_b_state = self.gru_b(y, gru_states[1]) + + y = self.dual_fc(y) + + return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state) + + def forward(self, features, periods, signals, gru_states): + + c = self.frame_rate_network(features, periods) + y, _ = self.sample_rate_network(signals, c, gru_states) + log_probs = torch.log_softmax(y, dim=-1) + + return log_probs + + def generate(self, features, periods, lpcs): + + with torch.no_grad(): + device = self.parameters().__next__().device + + num_frames = features.shape[0] - self.feature_history - self.feature_lookahead + lpc_order = lpcs.shape[-1] + num_input_signals = len(self.signals) + pitch_corr_position = self.input_layout['features']['pitch_corr'][0] + + # signal buffers + last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1)) + prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1)) + last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1)) + output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16) + mem = 0 + + # state buffers + gru_a_state = torch.zeros((1, 1, self.gru_a_units)) + gru_b_state = torch.zeros((1, 1, self.gru_b_units)) + + input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long) + # conditioning signals for subconditioner a + c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long) + # conditioning signals for subconditioner b + c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long) + + # signal dict + signal_dict = { + 'prediction' : prediction, + 'last_error' : last_error, + 'last_signal' : last_signal + } + + # push data to device + features = features.to(device) + periods = periods.to(device) + lpcs = lpcs.to(device) + + # run feature encoding + c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0)) + + for frame_index in range(num_frames): + frame_start = frame_index * self.frame_size + pitch_corr = features[frame_index + self.feature_history, pitch_corr_position] + a = - torch.flip(lpcs[frame_index + self.feature_history], [0]) + current_c = c[:, frame_index : frame_index + 1, :] + + for i in range(0, self.frame_size, self.group_size_gru_a): + pcm_position = frame_start + i + lpc_order + output_position = frame_start + i + + # calculate newest prediction + prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a) + + # prepare input + for slot in range(self.group_size_gru_a): + k = slot - self.group_size_gru_a + 1 + for idx, name in enumerate(self.signals): + input_signals[idx + slot * num_input_signals] = lin2ulawq( + signal_dict[name][pcm_position + k] + ) + + + # run GRU A + embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1))) + embed_signals = torch.flatten(embed_signals, 2) + y = torch.cat((embed_signals, current_c), dim=-1) + h_a, gru_a_state = self.gru_a(y, gru_a_state) + + # loop over substeps_a + for step_a in range(self.substeps_a): + # prepare conditioning input + for slot in range(self.group_size_subcondition_a): + k = slot - self.group_size_subcondition_a + 1 + for idx, name in enumerate(self.subcondition_signals_a): + c_signals_a[idx + slot * num_input_signals] = lin2ulawq( + signal_dict[name][pcm_position + k] + ) + + # subconditioning + h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1))) + + # run GRU B + y = torch.cat((h_a, current_c), dim=-1) + h_b, gru_b_state = self.gru_b(y, gru_b_state) + + # loop over substeps b + for step_b in range(self.substeps_b): + # prepare subconditioning input + for idx, name in enumerate(self.subcondition_signals_b): + c_signals_b[idx] = lin2ulawq( + signal_dict[name][pcm_position] + ) + + # subcondition + h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1))) + + # run dual FC + probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1) + + # sample + new_exc = ulaw2lin(sample_excitation(probs, pitch_corr)) + + # update signals + sig = new_exc + prediction[pcm_position] + last_error[pcm_position + 1] = new_exc + last_signal[pcm_position + 1] = sig + + mem = 0.85 * mem + float(sig) + output[output_position] = clip_to_int16(round(mem)) + + # increase positions + pcm_position += 1 + output_position += 1 + + # calculate next prediction + prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a) + + return output |