gitlab.xiph.org/xiph/opus.git
author    Jan Buethe <jbuethe@amazon.de>    2023-09-05 13:29:38 +0300
committer Jan Buethe <jbuethe@amazon.de>    2023-09-05 13:29:38 +0300
commit    35ee397e060283d30c098ae5e17836316bbec08b (patch)
tree      4a81b86f8c0738bbdc7147214c53fda54cd0f3f3 /dnn/torch/lpcnet/models
parent    90a171c1c2c9839b561f8446ad2bbfe48eacf255 (diff)
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Diffstat (limited to 'dnn/torch/lpcnet/models')
-rw-r--r--    dnn/torch/lpcnet/models/__init__.py             8
-rw-r--r--    dnn/torch/lpcnet/models/lpcnet.py             274
-rw-r--r--    dnn/torch/lpcnet/models/multi_rate_lpcnet.py  408
3 files changed, 690 insertions, 0 deletions
diff --git a/dnn/torch/lpcnet/models/__init__.py b/dnn/torch/lpcnet/models/__init__.py
new file mode 100644
index 00000000..a26bc1cd
--- /dev/null
+++ b/dnn/torch/lpcnet/models/__init__.py
@@ -0,0 +1,8 @@
+from .lpcnet import LPCNet
+from .multi_rate_lpcnet import MultiRateLPCNet
+
+
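+# maps the model name from a training config to its class,
+# e.g. model = model_dict['lpcnet'](config)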
+model_dict = {
+ 'lpcnet' : LPCNet,
+ 'multi_rate' : MultiRateLPCNet
+}
\ No newline at end of file
diff --git a/dnn/torch/lpcnet/models/lpcnet.py b/dnn/torch/lpcnet/models/lpcnet.py
new file mode 100644
index 00000000..e20ae68d
--- /dev/null
+++ b/dnn/torch/lpcnet/models/lpcnet.py
@@ -0,0 +1,274 @@
+import torch
+from torch import nn
+import numpy as np
+
+from utils.ulaw import lin2ulawq, ulaw2lin
+from utils.sample import sample_excitation
+from utils.pcm import clip_to_int16
+from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
+from utils.layers import DualFC
+from utils.misc import get_pdf_from_tree
+
+
+class LPCNet(nn.Module):
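+    """ LPCNet vocoder: a frame rate network computes a per-frame
+        conditioning vector, and a sample rate network (GRU A -> GRU B ->
+        DualFC) autoregressively predicts mu-law excitation samples that
+        are added to an LPC prediction of the signal. """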
+ def __init__(self, config):
+ super(LPCNet, self).__init__()
+
+        # general parameters
+ self.input_layout = config['input_layout']
+ self.feature_history = config['feature_history']
+ self.feature_lookahead = config['feature_lookahead']
+
+ # frame rate network parameters
+ self.feature_dimension = config['feature_dimension']
+ self.period_embedding_dim = config['period_embedding_dim']
+ self.period_levels = config['period_levels']
+ self.feature_channels = self.feature_dimension + self.period_embedding_dim
+ self.feature_conditioning_dim = config['feature_conditioning_dim']
+ self.feature_conv_kernel_size = config['feature_conv_kernel_size']
+
+
+ # frame rate network layers
+ self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
+ self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
+        self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
+
+ # sample rate network parameters
+ self.frame_size = config['frame_size']
+ self.signal_levels = config['signal_levels']
+ self.signal_embedding_dim = config['signal_embedding_dim']
+ self.gru_a_units = config['gru_a_units']
+ self.gru_b_units = config['gru_b_units']
+ self.output_levels = config['output_levels']
+ self.hsampling = config.get('hsampling', False)
+
+ self.gru_a_input_dim = len(self.input_layout['signals']) * self.signal_embedding_dim + self.feature_conditioning_dim
+ self.gru_b_input_dim = self.gru_a_units + self.feature_conditioning_dim
+
+ # sample rate network layers
+ self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
+ self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
+ self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
+ self.dual_fc = DualFC(self.gru_b_units, self.output_levels)
+
+ # sparsification
+ self.sparsifier = []
+
+ # GRU A
+ if 'gru_a' in config['sparsification']:
+ gru_config = config['sparsification']['gru_a']
+ task_list = [(self.gru_a, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
+ gru_config['params'], drop_input=True)
+ else:
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
+
+ # GRU B
+ if 'gru_b' in config['sparsification']:
+ gru_config = config['sparsification']['gru_b']
+ task_list = [(self.gru_b, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
+ gru_config['params'])
+ else:
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
+
+ # inference parameters
+ self.lpc_gamma = config.get('lpc_gamma', 1)
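+        # lpc_gamma < 1 applies bandwidth expansion to the LPCs during
+        # generation (coefficient i is scaled by lpc_gamma**(i+1))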
+
+ def sparsify(self):
+ for sparsifier in self.sparsifier:
+ sparsifier.step()
+
+ def get_gflops(self, fs, verbose=False):
+ gflops = 0
+
+ # frame rate network
+ conditioning_dim = self.feature_conditioning_dim
+ feature_channels = self.feature_channels
+ frame_rate = fs / self.frame_size
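+        # per frame: conv1 (k=3) costs 3*feature_channels*conditioning_dim MACs,
+        # conv2 (k=3) 3*conditioning_dim**2, dense1/dense2 conditioning_dim**2 each;
+        # 2 FLOPS per MAC gives the factor below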
+ frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
+ if verbose:
+ print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
+ gflops += frame_rate_network_complexity
+
+ # gru a
+ gru_a_rate = fs
+ gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
+ if verbose:
+ print(f"gru A: {gru_a_complexity} GFLOPS")
+ gflops += gru_a_complexity
+
+ # gru b
+ gru_b_rate = fs
+ gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
+ if verbose:
+ print(f"gru B: {gru_b_complexity} GFLOPS")
+ gflops += gru_b_complexity
+
+
+ # dual fcs
+ fc = self.dual_fc
+ rate = fs
+ input_size = fc.dense1.in_features
+ output_size = fc.dense1.out_features
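+        # assuming DualFC evaluates two dense branches (4*in*out FLOPS) plus
+        # roughly 22 FLOPS per output for the elementwise gating/activations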
+ dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
+ if self.hsampling:
+ dual_fc_complexity /= 8
+ if verbose:
+ print(f"dual_fc: {dual_fc_complexity} GFLOPS")
+ gflops += dual_fc_complexity
+
+ if verbose:
+ print(f'total: {gflops} GFLOPS')
+
+ return gflops
+
+ def frame_rate_network(self, features, periods):
+
+ embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
+ features = torch.concat((features, embedded_periods), dim=-1)
+
+ # convert to channels first and calculate conditioning vector
+ c = torch.permute(features, [0, 2, 1])
+
+ c = torch.tanh(self.feature_conv1(c))
+ c = torch.tanh(self.feature_conv2(c))
+ # back to channels last
+ c = torch.permute(c, [0, 2, 1])
+ c = torch.tanh(self.feature_dense1(c))
+ c = torch.tanh(self.feature_dense2(c))
+
+ return c
+
+ def sample_rate_network(self, signals, c, gru_states):
+ embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
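+        # repeat each frame's conditioning vector frame_size times to reach the sample rate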
+ c_upsampled = torch.repeat_interleave(c, self.frame_size, dim=1)
+
+ y = torch.concat((embedded_signals, c_upsampled), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ y = torch.concat((y, c_upsampled), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+
+ y = self.dual_fc(y)
+
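+        # hierarchical sampling: interpret the outputs as binary-tree node
+        # probabilities and flatten them to a pdf over the output levels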
+ if self.hsampling:
+ y = torch.sigmoid(y)
+ log_probs = torch.log(get_pdf_from_tree(y) + 1e-6)
+ else:
+ log_probs = torch.log_softmax(y, dim=-1)
+
+ return log_probs, (gru_a_state, gru_b_state)
+
+ def decoder(self, signals, c, gru_states):
+ embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
+
+ y = torch.concat((embedded_signals, c), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ y = torch.concat((y, c), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+
+ y = self.dual_fc(y)
+
+ if self.hsampling:
+ y = torch.sigmoid(y)
+ probs = get_pdf_from_tree(y)
+ else:
+ probs = torch.softmax(y, dim=-1)
+
+ return probs, (gru_a_state, gru_b_state)
+
+ def forward(self, features, periods, signals, gru_states):
+
+ c = self.frame_rate_network(features, periods)
+ log_probs, _ = self.sample_rate_network(signals, c, gru_states)
+
+ return log_probs
+
+ def generate(self, features, periods, lpcs):
+
+ with torch.no_grad():
+            device = next(self.parameters()).device
+
+ num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
+ lpc_order = lpcs.shape[-1]
+ num_input_signals = len(self.input_layout['signals'])
+ pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
+
+ # signal buffers
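+            # pcm reserves lpc_order leading samples as prediction history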
+ pcm = torch.zeros((num_frames * self.frame_size + lpc_order))
+ output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
+ mem = 0
+
+ # state buffers
+ gru_a_state = torch.zeros((1, 1, self.gru_a_units))
+ gru_b_state = torch.zeros((1, 1, self.gru_b_units))
+ gru_states = [gru_a_state, gru_b_state]
+
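+            # 128 is the (biased) mu-law code for zero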
+ input_signals = torch.zeros((1, 1, num_input_signals), dtype=torch.long) + 128
+
+ # push data to device
+ features = features.to(device)
+ periods = periods.to(device)
+ lpcs = lpcs.to(device)
+
+ # lpc weighting
+ weights = torch.FloatTensor([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]).to(device)
+ lpcs = lpcs * weights
+
+ # run feature encoding
+ c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
+
+ for frame_index in range(num_frames):
+ frame_start = frame_index * self.frame_size
+ pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
+ a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
+ current_c = c[:, frame_index : frame_index + 1, :]
+
+ for i in range(self.frame_size):
+ pcm_position = frame_start + i + lpc_order
+ output_position = frame_start + i
+
+ # prepare input
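+                    # LPC prediction from the previous lpc_order samples
+                    # (a holds the negated, time-reversed coefficients)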
+ pred = torch.sum(pcm[pcm_position - lpc_order : pcm_position] * a)
+ if 'prediction' in self.input_layout['signals']:
+ input_signals[0, 0, self.input_layout['signals']['prediction']] = lin2ulawq(pred)
+
+ # run single step of sample rate network
+ probs, gru_states = self.decoder(
+ input_signals,
+ current_c,
+ gru_states
+ )
+
+ # sample from output
+ exc_ulaw = sample_excitation(probs, pitch_corr)
+
+ # signal generation
+ exc = ulaw2lin(exc_ulaw)
+ sig = exc + pred
+ pcm[pcm_position] = sig
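+                    # de-emphasis filter y[n] = x[n] + 0.85*y[n-1], quantized to int16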
+ mem = 0.85 * mem + float(sig)
+ output[output_position] = clip_to_int16(round(mem))
+
+ # buffer update
+ if 'last_signal' in self.input_layout['signals']:
+ input_signals[0, 0, self.input_layout['signals']['last_signal']] = lin2ulawq(sig)
+
+ if 'last_error' in self.input_layout['signals']:
+ input_signals[0, 0, self.input_layout['signals']['last_error']] = lin2ulawq(exc)
+
+ return output
diff --git a/dnn/torch/lpcnet/models/multi_rate_lpcnet.py b/dnn/torch/lpcnet/models/multi_rate_lpcnet.py
new file mode 100644
index 00000000..c6850101
--- /dev/null
+++ b/dnn/torch/lpcnet/models/multi_rate_lpcnet.py
@@ -0,0 +1,408 @@
+import torch
+from torch import nn
+from utils.layers.subconditioner import get_subconditioner
+from utils.layers import DualFC
+
+from utils.ulaw import lin2ulawq, ulaw2lin
+from utils.sample import sample_excitation
+from utils.pcm import clip_to_int16
+from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step
+
+from utils.misc import interleave_tensors
+
+
+# MultiRateLPCNet
+class MultiRateLPCNet(nn.Module):
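+    """ Multi-rate variant of LPCNet: GRU A runs once per group of
+        group_size_gru_a samples and GRU B once per substeps_b samples;
+        per-substep subconditioners refine the GRU states back down to
+        the sample rate. """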
+ def __init__(self, config):
+ super(MultiRateLPCNet, self).__init__()
+
+ # general parameters
+ self.input_layout = config['input_layout']
+ self.feature_history = config['feature_history']
+ self.feature_lookahead = config['feature_lookahead']
+ self.signals = config['signals']
+
+ # frame rate network parameters
+ self.feature_dimension = config['feature_dimension']
+ self.period_embedding_dim = config['period_embedding_dim']
+ self.period_levels = config['period_levels']
+ self.feature_channels = self.feature_dimension + self.period_embedding_dim
+ self.feature_conditioning_dim = config['feature_conditioning_dim']
+ self.feature_conv_kernel_size = config['feature_conv_kernel_size']
+
+ # frame rate network layers
+ self.period_embedding = nn.Embedding(self.period_levels, self.period_embedding_dim)
+ self.feature_conv1 = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_conv2 = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
+ self.feature_dense1 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
+        self.feature_dense2 = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
+
+ # sample rate network parameters
+ self.frame_size = config['frame_size']
+ self.signal_levels = config['signal_levels']
+ self.signal_embedding_dim = config['signal_embedding_dim']
+ self.gru_a_units = config['gru_a_units']
+ self.gru_b_units = config['gru_b_units']
+ self.output_levels = config['output_levels']
+
+ # subconditioning B
+ sub_config = config['subconditioning']['subconditioning_b']
+ self.substeps_b = sub_config['number_of_subsamples']
+ self.subcondition_signals_b = sub_config['signals']
+ self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
+ method = sub_config['method']
+ kwargs = sub_config['kwargs']
+        if kwargs is None:
+            kwargs = dict()
+
+ state_size = self.gru_b_units
+ self.subconditioner_b = get_subconditioner(method,
+ sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
+ state_size, self.signal_levels, len(sub_config['signals']),
+            **kwargs)
+
+ # subconditioning A
+ sub_config = config['subconditioning']['subconditioning_a']
+ self.substeps_a = sub_config['number_of_subsamples']
+ self.subcondition_signals_a = sub_config['signals']
+ self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
+ method = sub_config['method']
+ kwargs = sub_config['kwargs']
+        if kwargs is None:
+            kwargs = dict()
+
+ state_size = self.gru_a_units
+ self.subconditioner_a = get_subconditioner(method,
+ sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
+ state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
+            **kwargs)
+
+
+        # wrap up subconditioning: group_size_gru_a is the number of timesteps
+        # grouped into a single GRU A input, and group_size_subcondition_a is
+        # the number of samples grouped as input to pre-GRU B subconditioning
+ self.group_size_gru_a = self.substeps_a * self.substeps_b
+ self.group_size_subcondition_a = self.substeps_b
+ self.gru_a_rate_divider = self.group_size_gru_a
+ self.gru_b_rate_divider = self.substeps_b
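+        # so GRU A consumes group_size_gru_a new samples per step, GRU B runs
+        # once per substeps_b output samples, and the subconditioners bridge
+        # the intermediate rates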
+
+ # gru sizes
+ self.gru_a_input_dim = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
+ self.gru_b_input_dim = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
+ self.signals_idx = [self.input_layout['signals'][key] for key in self.signals]
+
+ # sample rate network layers
+ self.signal_embedding = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
+ self.gru_a = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
+ self.gru_b = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)
+
+ # sparsification
+ self.sparsifier = []
+
+ # GRU A
+ if 'gru_a' in config['sparsification']:
+ gru_config = config['sparsification']['gru_a']
+ task_list = [(self.gru_a, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
+ gru_config['params'], drop_input=True)
+ else:
+ self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)
+
+ # GRU B
+ if 'gru_b' in config['sparsification']:
+ gru_config = config['sparsification']['gru_b']
+ task_list = [(self.gru_b, gru_config['params'])]
+ self.sparsifier.append(GRUSparsifier(task_list,
+ gru_config['start'],
+ gru_config['stop'],
+ gru_config['interval'],
+ gru_config['exponent'])
+ )
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
+ gru_config['params'])
+ else:
+ self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)
+
+ # dual FCs
+ self.dual_fc = []
+ for i in range(self.substeps_b):
+ dim = self.subconditioner_b.get_output_dim(i)
+ self.dual_fc.append(DualFC(dim, self.output_levels))
+ self.add_module(f"dual_fc_{i}", self.dual_fc[-1])
+
+ def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
+ gflops = 0
+
+ # frame rate network
+ conditioning_dim = self.feature_conditioning_dim
+ feature_channels = self.feature_channels
+ frame_rate = fs / self.frame_size
+ frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
+ if verbose:
+ print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
+ gflops += frame_rate_network_complexity
+
+ # gru a
+ gru_a_rate = fs / self.group_size_gru_a
+ gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
+ if verbose:
+ print(f"gru A: {gru_a_complexity} GFLOPS")
+ gflops += gru_a_complexity
+
+ # subconditioning a
+ subcond_a_rate = fs / self.substeps_b
+ subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
+ if verbose:
+ print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
+ gflops += subconditioning_a_complexity
+
+ # gru b
+ gru_b_rate = fs / self.substeps_b
+ gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
+ if verbose:
+ print(f"gru B: {gru_b_complexity} GFLOPS")
+ gflops += gru_b_complexity
+
+ # subconditioning b
+ subcond_b_rate = fs
+ subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
+ if verbose:
+ print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
+ gflops += subconditioning_b_complexity
+
+ # dual fcs
+ for i, fc in enumerate(self.dual_fc):
+ rate = fs / len(self.dual_fc)
+ input_size = fc.dense1.in_features
+ output_size = fc.dense1.out_features
+ dual_fc_complexity = 1e-9 * (4 * input_size * output_size + 22 * output_size) * rate
+ if hierarchical_sampling:
+ dual_fc_complexity /= 8
+ if verbose:
+ print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
+ gflops += dual_fc_complexity
+
+ if verbose:
+ print(f'total: {gflops} GFLOPS')
+
+ return gflops
+
+ def sparsify(self):
+ for sparsifier in self.sparsifier:
+ sparsifier.step()
+
+ def frame_rate_network(self, features, periods):
+
+ embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
+ features = torch.concat((features, embedded_periods), dim=-1)
+
+ # convert to channels first and calculate conditioning vector
+ c = torch.permute(features, [0, 2, 1])
+
+ c = torch.tanh(self.feature_conv1(c))
+ c = torch.tanh(self.feature_conv2(c))
+ # back to channels last
+ c = torch.permute(c, [0, 2, 1])
+ c = torch.tanh(self.feature_dense1(c))
+ c = torch.tanh(self.feature_dense2(c))
+
+ return c
+
+ def prepare_signals(self, signals, group_size, signal_idx):
+ """ extracts, delays and groups signals """
+
+ batch_size, sequence_length, num_signals = signals.shape
+
+ # extract signals according to position
+ signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
+ dim=-1)
+
+ # roll back pcm to account for grouping
+ signals = torch.roll(signals, group_size - 1, -2)
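+        # after the roll, group g contains the group_size most recent samples
+        # up to original timestep g*group_size, matching the slot layout used
+        # in generate()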
+
+ # reshape
+ signals = torch.reshape(signals,
+ (batch_size, sequence_length // group_size, group_size * len(signal_idx)))
+
+ return signals
+
+
+ def sample_rate_network(self, signals, c, gru_states):
+
+ signals_a = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
+ embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
+ # features at GRU A rate
+ c_upsampled_a = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
+ # features at GRU B rate
+ c_upsampled_b = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)
+
+ y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ # first round of upsampling and subconditioning
+ c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
+ y = self.subconditioner_a(y, c_signals_a)
+ y = interleave_tensors(y)
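+        # subconditioner_a yields one sequence per substep; interleaving them
+        # restores a single time-ordered sequence at the GRU B rate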
+
+ y = torch.concat((y, c_upsampled_b), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+ c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
+ y = self.subconditioner_b(y, c_signals_b)
+
+ y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
+ y = interleave_tensors(y)
+
+ return y, (gru_a_state, gru_b_state)
+
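+    # NOTE: this single-step decoder skips the subconditioners and calls the
+    # dual_fc list directly; generate() below inlines the full multi-rate loop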
+ def decoder(self, signals, c, gru_states):
+ embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)
+
+ y = torch.concat((embedded_signals, c), dim=-1)
+ y, gru_a_state = self.gru_a(y, gru_states[0])
+ y = torch.concat((y, c), dim=-1)
+ y, gru_b_state = self.gru_b(y, gru_states[1])
+
+ y = self.dual_fc(y)
+
+ return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)
+
+ def forward(self, features, periods, signals, gru_states):
+
+ c = self.frame_rate_network(features, periods)
+ y, _ = self.sample_rate_network(signals, c, gru_states)
+ log_probs = torch.log_softmax(y, dim=-1)
+
+ return log_probs
+
+ def generate(self, features, periods, lpcs):
+
+ with torch.no_grad():
+            device = next(self.parameters()).device
+
+ num_frames = features.shape[0] - self.feature_history - self.feature_lookahead
+ lpc_order = lpcs.shape[-1]
+ num_input_signals = len(self.signals)
+ pitch_corr_position = self.input_layout['features']['pitch_corr'][0]
+
+ # signal buffers
+ last_signal = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
+ prediction = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
+ last_error = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
+ output = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
+ mem = 0
+
+ # state buffers
+ gru_a_state = torch.zeros((1, 1, self.gru_a_units))
+ gru_b_state = torch.zeros((1, 1, self.gru_b_units))
+
+ input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
+ # conditioning signals for subconditioner a
+ c_signals_a = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
+ # conditioning signals for subconditioner b
+ c_signals_b = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)
+
+ # signal dict
+ signal_dict = {
+ 'prediction' : prediction,
+ 'last_error' : last_error,
+ 'last_signal' : last_signal
+ }
+
+ # push data to device
+ features = features.to(device)
+ periods = periods.to(device)
+ lpcs = lpcs.to(device)
+
+ # run feature encoding
+ c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))
+
+ for frame_index in range(num_frames):
+ frame_start = frame_index * self.frame_size
+ pitch_corr = features[frame_index + self.feature_history, pitch_corr_position]
+ a = - torch.flip(lpcs[frame_index + self.feature_history], [0])
+ current_c = c[:, frame_index : frame_index + 1, :]
+
+ for i in range(0, self.frame_size, self.group_size_gru_a):
+ pcm_position = frame_start + i + lpc_order
+ output_position = frame_start + i
+
+ # calculate newest prediction
+ prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
+
+ # prepare input
+ for slot in range(self.group_size_gru_a):
+ k = slot - self.group_size_gru_a + 1
+ for idx, name in enumerate(self.signals):
+ input_signals[idx + slot * num_input_signals] = lin2ulawq(
+ signal_dict[name][pcm_position + k]
+ )
+
+
+ # run GRU A
+ embed_signals = self.signal_embedding(input_signals.reshape((1, 1, -1)))
+ embed_signals = torch.flatten(embed_signals, 2)
+ y = torch.cat((embed_signals, current_c), dim=-1)
+ h_a, gru_a_state = self.gru_a(y, gru_a_state)
+
+ # loop over substeps_a
+ for step_a in range(self.substeps_a):
+ # prepare conditioning input
+ for slot in range(self.group_size_subcondition_a):
+ k = slot - self.group_size_subcondition_a + 1
+ for idx, name in enumerate(self.subcondition_signals_a):
+                                c_signals_a[idx + slot * len(self.subcondition_signals_a)] = lin2ulawq(
+ signal_dict[name][pcm_position + k]
+ )
+
+ # subconditioning
+ h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))
+
+ # run GRU B
+ y = torch.cat((h_a, current_c), dim=-1)
+ h_b, gru_b_state = self.gru_b(y, gru_b_state)
+
+ # loop over substeps b
+ for step_b in range(self.substeps_b):
+ # prepare subconditioning input
+ for idx, name in enumerate(self.subcondition_signals_b):
+ c_signals_b[idx] = lin2ulawq(
+ signal_dict[name][pcm_position]
+ )
+
+ # subcondition
+ h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))
+
+ # run dual FC
+ probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)
+
+ # sample
+ new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))
+
+ # update signals
+ sig = new_exc + prediction[pcm_position]
+ last_error[pcm_position + 1] = new_exc
+ last_signal[pcm_position + 1] = sig
+
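+                            # de-emphasis filter y[n] = x[n] + 0.85*y[n-1], quantized to int16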
+ mem = 0.85 * mem + float(sig)
+ output[output_position] = clip_to_int16(round(mem))
+
+ # increase positions
+ pcm_position += 1
+ output_position += 1
+
+ # calculate next prediction
+ prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)
+
+ return output