author    | Jan Buethe <jbuethe@amazon.de> | 2023-08-01 19:18:28 +0300
committer | Jan Buethe <jbuethe@amazon.de> | 2023-08-01 19:18:28 +0300
commit    | 902d763622fd6f7665c614bf66daaf8b8ba9fc48 (patch)
tree      | 4c1725fcccc6897b0002c3bff9d4e25ce030d192
parent    | 9691440a5f9ce8cbbd33b119bb9af881e4dee1a2 (diff)
added FWGAN weight dumping code
-rw-r--r-- | dnn/torch/fwgan/dump_model_weights.py |  89
-rw-r--r-- | dnn/torch/fwgan/inference.py          | 141
-rw-r--r-- | dnn/torch/fwgan/models/__init__.py    |   7
-rw-r--r-- | dnn/torch/fwgan/models/fwgan400.py    | 308
-rw-r--r-- | dnn/torch/fwgan/models/fwgan500.py    | 260

5 files changed, 805 insertions(+), 0 deletions(-)
diff --git a/dnn/torch/fwgan/dump_model_weights.py b/dnn/torch/fwgan/dump_model_weights.py
new file mode 100644
index 00000000..f4e38c15
--- /dev/null
+++ b/dnn/torch/fwgan/dump_model_weights.py
@@ -0,0 +1,89 @@
+import os
+import sys
+import argparse
+
+import torch
+from torch import nn
+
+
+sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
+import wexchange.torch
+
+from models import model_dict
+
+unquantized = [
+    'feat_in_conv1.conv',
+    'bfcc_with_corr_upsampler.fc',
+    'cont_net.0',
+    'fwc6.cont_fc.0',
+    'fwc6.fc.0',
+    'fwc6.fc.1.gate',
+    'fwc7.cont_fc.0',
+    'fwc7.fc.0',
+    'fwc7.fc.1.gate'
+]
+
+description=f"""
+This is an unsafe dumping script for FWGAN models. It assumes that all weights are contained in Linear, Conv1d or GRU layers
+and will fail to export any other weights.
+
+Furthermore, the quantize option relies on the following explicit list of layers being excluded:
+{unquantized}.
+
+Modify this script manually if adjustments are needed.
+"""
+
+parser = argparse.ArgumentParser(description=description)
+parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
+parser.add_argument('weightfile', type=str, help='weight file path')
+parser.add_argument('export_folder', type=str)
+parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
+parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
+parser.add_argument('--quantize', action='store_true', help='apply quantization')
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    model = model_dict[args.model]()
+
+    print(f"loading weights from {args.weightfile}...")
+    saved_gen = torch.load(args.weightfile, map_location='cpu')
+    model.load_state_dict(saved_gen)
+
+    # weight norm is only needed during training; fold it into the plain weights before dumping
+    def _remove_weight_norm(m):
+        try:
+            torch.nn.utils.remove_weight_norm(m)
+        except ValueError:  # this module didn't have weight norm
+            return
+    model.apply(_remove_weight_norm)
+
+    print("dumping model...")
+    quantize_model = args.quantize
+
+    output_folder = args.export_folder
+    os.makedirs(output_folder, exist_ok=True)
+
+    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)
+
+    for name, module in model.named_modules():
+
+        if quantize_model:
+            quantize = name not in unquantized
+            scale = None if quantize else 1/128
+        else:
+            quantize = False
+            scale = 1/128
+
+        if isinstance(module, nn.Linear):
+            print(f"dumping linear layer {name}...")
+            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
+
+        if isinstance(module, nn.Conv1d):
+            print(f"dumping conv1d layer {name}...")
+            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
+
+        if isinstance(module, nn.GRU):
+            print(f"dumping GRU layer {name}...")
+            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)
+
+    writer.close()
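A typical invocation of the dump script (the checkpoint path is hypothetical) would be:

    python dump_model_weights.py fwgan400 fwgan400_checkpoint.pth c_export --quantize

which, per the --export-filename default, should write fwgan_data.c and fwgan_data.h into c_export. Note that the script only imports wexchange.torch while accessing wexchange.c_export.c_writer.CWriter; whether that attribute resolves depends on what the wexchange package's __init__ itself imports.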
diff --git a/dnn/torch/fwgan/inference.py b/dnn/torch/fwgan/inference.py
new file mode 100644
index 00000000..c06b68b1
--- /dev/null
+++ b/dnn/torch/fwgan/inference.py
@@ -0,0 +1,141 @@
+import os
+import time
+import torch
+import numpy as np
+from scipy import signal as si
+from scipy.io import wavfile
+import argparse
+
+from models import model_dict
+
+parser = argparse.ArgumentParser()
+parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
+parser.add_argument('weightfile', type=str, help='weight file')
+parser.add_argument('input', type=str, help='input: feature file or folder with feature files')
+parser.add_argument('output', type=str, help='output: wav file name or folder name, depending on input')
+
+
+########################### Signal Processing Layers ###########################
+
+def preemphasis(x, coef=-0.85):
+
+    return si.lfilter(np.array([1.0, coef]), np.array([1.0]), x).astype('float32')
+
+def deemphasis(x, coef=-0.85):
+
+    return si.lfilter(np.array([1.0]), np.array([1.0, coef]), x).astype('float32')
+
+gamma = 0.92
+weighting_vector = np.array([gamma**i for i in range(16, 0, -1)])
+
+
+def lpc_synthesis_one_frame(frame, filt, buffer, weighting_vector=np.ones(16)):
+
+    out = np.zeros_like(frame)
+
+    filt = np.flip(filt)
+
+    inp = frame[:]
+
+    for i in range(0, inp.shape[0]):
+
+        s = inp[i] - np.dot(buffer*weighting_vector, filt)
+
+        buffer[0] = s
+
+        buffer = np.roll(buffer, -1)
+
+        out[i] = s
+
+    return out
+
+def inverse_perceptual_weighting(pw_signal, filters, weighting_vector):
+
+    # inverse perceptual weighting = H_preemph / W(z/gamma)
+
+    pw_signal = preemphasis(pw_signal)
+
+    signal = np.zeros_like(pw_signal)
+    buffer = np.zeros(16)
+    num_frames = pw_signal.shape[0] // 160
+    assert num_frames == filters.shape[0]
+
+    for frame_idx in range(0, num_frames):
+
+        in_frame = pw_signal[frame_idx*160: (frame_idx+1)*160][:]
+        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer, weighting_vector)
+        signal[frame_idx*160: (frame_idx+1)*160] = out_sig_frame[:]
+        buffer[:] = out_sig_frame[-16:]
+
+    return signal
+
+
+def process_item(generator, feature_filename, output_filename, verbose=False):
+
+    feat = np.memmap(feature_filename, dtype='float32', mode='r')
+
+    num_feat_frames = len(feat) // 36
+    feat = np.reshape(feat, (num_feat_frames, 36))
+
+    bfcc = np.copy(feat[:, :18])
+    corr = np.copy(feat[:, 19:20]) + 0.5
+    bfcc_with_corr = torch.from_numpy(np.hstack((bfcc, corr))).type(torch.FloatTensor).unsqueeze(0)  # .to(device)
+
+    period = torch.from_numpy((0.1 + 50 * np.copy(feat[:, 18:19]) + 100)\
+        .astype('int32')).type(torch.long).view(1, -1)  # .to(device)
+
+    lpc_filters = np.copy(feat[:, -16:])
+
+    start_time = time.time()
+    # zero history audio frames: the vocoder runs in complete synthesis mode
+    x1 = generator(period, bfcc_with_corr, torch.zeros(1, 320))
+    end_time = time.time()
+    total_time = end_time - start_time
+    x1 = x1.squeeze(1).squeeze(0).detach().cpu().numpy()
+    gen_seconds = len(x1)/16000
+    out = deemphasis(inverse_perceptual_weighting(x1, lpc_filters, weighting_vector))
+    if verbose:
+        print(f"Took {total_time:.3f}s to generate {len(x1)} samples ({gen_seconds}s) -> {gen_seconds/total_time:.2f}x real time")
+
+    out = np.clip(np.round(2**15 * out), -2**15, 2**15 - 1).astype(np.int16)
+    wavfile.write(output_filename, 16000, out)
+
+
+########################### The inference loop over a folder containing LPCNet feature files #################################
+if __name__ == "__main__":
+
+    args = parser.parse_args()
+
+    generator = model_dict[args.model]()
+
+    # load the checkpoint (fwgan400 or fwgan500)
+    saved_gen = torch.load(args.weightfile, map_location='cpu')
+    generator.load_state_dict(saved_gen)
+
+    # remove the weight_norm from the model layers, as it's no longer needed
+    def _remove_weight_norm(m):
+        try:
+            torch.nn.utils.remove_weight_norm(m)
+        except ValueError:  # this module didn't have weight norm
+            return
+    generator.apply(_remove_weight_norm)
+
+    # enable inference mode
+    generator = generator.eval()
+
+    print('Successfully loaded the generator model ... start generation:')
+
+    if os.path.isdir(args.input):
+
+        os.makedirs(args.output, exist_ok=True)
+
+        for fn in os.listdir(args.input):
+            print(f"processing input {fn}...")
+            feature_filename = os.path.join(args.input, fn)
+            output_filename = os.path.join(args.output, os.path.splitext(fn)[0] + f"_{args.model}.wav")
+            process_item(generator, feature_filename, output_filename)
+    else:
+        process_item(generator, args.input, args.output)
+
+    print("Finished!")
\ No newline at end of file
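Assuming a checkpoint fwgan500_checkpoint.pth and a folder features of raw 36-dimensional float32 feature files (both names hypothetical), a batch invocation would be:

    python inference.py fwgan500 fwgan500_checkpoint.pth features wav_out

Each input yields one 16 kHz wav file suffixed with the model name; passing a single feature file and a wav path instead converts just that file.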
diff --git a/dnn/torch/fwgan/models/__init__.py b/dnn/torch/fwgan/models/__init__.py
new file mode 100644
index 00000000..d52a6eb0
--- /dev/null
+++ b/dnn/torch/fwgan/models/__init__.py
@@ -0,0 +1,7 @@
+from .fwgan400 import FWGAN400ContLarge
+from .fwgan500 import FWGAN500Cont
+
+model_dict = {
+    'fwgan400': FWGAN400ContLarge,
+    'fwgan500': FWGAN500Cont
+}
\ No newline at end of file
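Both classes share the same forward signature, so callers can stay model-agnostic via model_dict. A minimal sketch with dummy inputs (run from dnn/torch/fwgan; shapes inferred from inference.py):

    import torch
    from models import model_dict

    model = model_dict['fwgan400']().eval()
    period = torch.full((1, 10), 120, dtype=torch.long)  # 10 frames of pitch period
    bfcc_with_corr = torch.zeros(1, 10, 19)              # 18 BFCCs + 1 pitch correlation
    x0 = torch.zeros(1, 320)                             # zero history -> pure synthesis mode
    with torch.no_grad():
        waveform = model(period, bfcc_with_corr, x0)     # (1, 1600): 160 samples per frame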
diff --git a/dnn/torch/fwgan/models/fwgan400.py b/dnn/torch/fwgan/models/fwgan400.py
new file mode 100644
index 00000000..84d9849e
--- /dev/null
+++ b/dnn/torch/fwgan/models/fwgan400.py
@@ -0,0 +1,308 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+import numpy as np
+
+which_norm = weight_norm
+
+#################### Definition of basic model components ####################
+
+# Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
+class ConvLookahead(nn.Module):
+    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
+        super(ConvLookahead, self).__init__()
+        torch.manual_seed(5)
+
+        self.padding_left = (kernel_size - 2) * dilation
+        self.padding_right = 1 * dilation
+
+        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+
+        x = F.pad(x, (self.padding_left, self.padding_right))
+        conv_out = self.conv(x)
+        return conv_out
+
+# (modified) GLU activation layer definition
+class GLU(nn.Module):
+    def __init__(self, feat_size):
+        super(GLU, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+
+        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
+
+        return out
+
+# GRU layer definition
+class ContForwardGRU(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers=1):
+        super(ContForwardGRU, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.hidden_size = hidden_size
+
+        # initializes the GRU state from the 64-dimensional continuation latent
+        self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.hidden_size, bias=False)),
+                                     nn.Tanh())
+
+        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,\
+                          bias=False)
+
+        self.nl = GLU(self.hidden_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, x0):
+
+        self.gru.flatten_parameters()
+
+        h0 = self.cont_fc(x0).unsqueeze(0)
+
+        output, h0 = self.gru(x, h0)
+
+        return self.nl(output)
+
+# Frame-wise convolution layer definition
+class ContFramewiseConv(torch.nn.Module):
+
+    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
+
+        super(ContFramewiseConv, self).__init__()
+        torch.manual_seed(5)
+
+        self.frame_kernel_size = frame_kernel_size
+        self.frame_len = frame_len
+
+        if (causal == True) or (self.frame_kernel_size == 2):
+
+            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
+            self.required_pad_right = 0
+
+            # the missing left context is synthesized from the continuation latent
+            self.cont_fc = nn.Sequential(which_norm(nn.Linear(64, self.required_pad_left, bias=False)),
+                                         nn.Tanh()
+                                         )
+
+        else:
+            # non-causal variant, currently unused (note: no cont_fc is defined on this path)
+            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
+            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len
+
+        self.fc_input_dim = self.frame_kernel_size * self.frame_len
+        self.fc_out_dim = out_dim
+
+        if act == 'glu':
+            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
+                                    GLU(self.fc_out_dim)
+                                    )
+        if act == 'tanh':
+            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
+                                    nn.Tanh()
+                                    )
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
+                    isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, x0):
+
+        if self.frame_kernel_size == 1:
+            return self.fc(x)
+
+        x_flat = x.reshape(x.size(0), 1, -1)
+        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
+        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)
+
+        x_flat_padded_unfolded = F.unfold(x_flat_padded,\
+            kernel_size=(1, self.fc_input_dim), stride=self.frame_len).permute(0, 2, 1).contiguous()
+
+        out = self.fc(x_flat_padded_unfolded)
+        return out
+
+# A fully-connected based upsampling layer definition
+class UpsampleFC(nn.Module):
+    def __init__(self, in_ch, out_ch, upsample_factor):
+        super(UpsampleFC, self).__init__()
+        torch.manual_seed(5)
+
+        self.in_ch = in_ch
+        self.out_ch = out_ch
+        self.upsample_factor = upsample_factor
+        self.fc = nn.Linear(in_ch, out_ch * upsample_factor, bias=False)
+        self.nl = nn.Tanh()
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or\
+                    isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+
+        batch_size = x.size(0)
+        x = x.permute(0, 2, 1)
+        x = self.nl(self.fc(x))
+        x = x.reshape((batch_size, -1, self.out_ch))
+        x = x.permute(0, 2, 1)
+        return x
+
+########################### The complete model definition #################################
+
+class FWGAN400ContLarge(nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(5)
+
+        self.bfcc_with_corr_upsampler = UpsampleFC(19, 80, 4)
+
+        self.feat_in_conv1 = ConvLookahead(160, 256, kernel_size=5)
+        self.feat_in_nl1 = GLU(256)
+
+        # maps the 321-dimensional history encoding to the 64-dimensional continuation latent
+        self.cont_net = nn.Sequential(which_norm(nn.Linear(321, 160, bias=False)),
+                                      nn.Tanh(),
+                                      which_norm(nn.Linear(160, 160, bias=False)),
+                                      nn.Tanh(),
+                                      which_norm(nn.Linear(160, 80, bias=False)),
+                                      nn.Tanh(),
+                                      which_norm(nn.Linear(80, 80, bias=False)),
+                                      nn.Tanh(),
+                                      which_norm(nn.Linear(80, 64, bias=False)),
+                                      nn.Tanh(),
+                                      which_norm(nn.Linear(64, 64, bias=False)),
+                                      nn.Tanh())
+
+        self.rnn = ContForwardGRU(256, 256)
+
+        self.fwc1 = ContFramewiseConv(256, 256)
+        self.fwc2 = ContFramewiseConv(256, 128)
+        self.fwc3 = ContFramewiseConv(128, 128)
+        self.fwc4 = ContFramewiseConv(128, 64)
+        self.fwc5 = ContFramewiseConv(64, 64)
+        self.fwc6 = ContFramewiseConv(64, 40)
+        self.fwc7 = ContFramewiseConv(40, 40)
+
+        self.init_weights()
+        self.count_parameters()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
+                    isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def count_parameters(self):
+        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")
+
+    def create_phase_signals(self, periods):
+
+        batch_size = periods.size(0)
+        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
+        progression = torch.repeat_interleave(progression, batch_size, 0)
+
+        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
+        chunks = []
+        for sframe in range(periods.size(1)):
+            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
+
+            chunk_sin = torch.sin(f * progression + phase0)
+            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 40)
+
+            chunk_cos = torch.cos(f * progression + phase0)
+            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 40)
+
+            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)
+
+            # carry the phase over to the next frame so the harmonic signal stays continuous
+            phase0 = phase0 + 160 * f
+
+            chunks.append(chunk)
+
+        phase_signals = torch.cat(chunks, dim=1)
+
+        return phase_signals
+
+    def gain_multiply(self, x, c0):
+
+        gain = 10**(0.5*c0/np.sqrt(18.0))
+        gain = torch.repeat_interleave(gain, 160, dim=-1)
+        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)
+
+        return x * gain
+
+    def forward(self, pitch_period, bfcc_with_corr, x0):
+
+        # encode the 320 history samples as log energy + unit-norm shape vector (321 dims)
+        norm_x0 = torch.norm(x0, 2, dim=-1, keepdim=True)
+        x0 = x0 / torch.sqrt((1e-8) + norm_x0**2)
+        x0 = torch.cat((torch.log(norm_x0 + 1e-7), x0), dim=-1)
+
+        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
+
+        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
+
+        feat_in = torch.cat((p_embed, envelope), dim=1)
+
+        wav_latent1 = self.feat_in_nl1(self.feat_in_conv1(feat_in).permute(0, 2, 1).contiguous())
+
+        cont_latent = self.cont_net(x0)
+
+        rnn_out = self.rnn(wav_latent1, cont_latent)
+
+        fwc1_out = self.fwc1(rnn_out, cont_latent)
+        fwc2_out = self.fwc2(fwc1_out, cont_latent)
+        fwc3_out = self.fwc3(fwc2_out, cont_latent)
+        fwc4_out = self.fwc4(fwc3_out, cont_latent)
+        fwc5_out = self.fwc5(fwc4_out, cont_latent)
+        fwc6_out = self.fwc6(fwc5_out, cont_latent)
+        fwc7_out = self.fwc7(fwc6_out, cont_latent)
+
+        waveform = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
+
+        waveform = self.gain_multiply(waveform, bfcc_with_corr[:, :, :1])
+
+        return waveform
\ No newline at end of file
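ContFramewiseConv is a causal convolution over latent frames implemented as F.unfold plus a Linear: output frame t sees input frames t-2..t, with the two missing left-context frames synthesized from the continuation latent by cont_fc. A small shape check (sizes hypothetical; in this file the continuation latent is 64-dimensional):

    import torch
    from models.fwgan400 import ContFramewiseConv

    fwc = ContFramewiseConv(frame_len=8, out_dim=4)  # kernel size 3, causal
    x  = torch.randn(1, 5, 8)                        # 5 frames of length 8
    x0 = torch.randn(1, 64)                          # continuation latent
    y  = fwc(x, x0)                                  # -> (1, 5, 4), one output frame per input frame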
diff --git a/dnn/torch/fwgan/models/fwgan500.py b/dnn/torch/fwgan/models/fwgan500.py
new file mode 100644
index 00000000..2c6dea5f
--- /dev/null
+++ b/dnn/torch/fwgan/models/fwgan500.py
@@ -0,0 +1,260 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+import numpy as np
+
+
+which_norm = weight_norm
+
+#################### Definition of basic model components ####################
+
+# Convolutional layer with 1 frame look-ahead (used for feature PreCondNet)
+class ConvLookahead(nn.Module):
+    def __init__(self, in_ch, out_ch, kernel_size, dilation=1, groups=1, bias=False):
+        super(ConvLookahead, self).__init__()
+        torch.manual_seed(5)
+
+        self.padding_left = (kernel_size - 2) * dilation
+        self.padding_right = 1 * dilation
+
+        self.conv = which_norm(nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation, groups=groups, bias=bias))
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+
+        x = F.pad(x, (self.padding_left, self.padding_right))
+        conv_out = self.conv(x)
+        return conv_out
+
+# (modified) GLU activation layer definition
+class GLU(nn.Module):
+    def __init__(self, feat_size):
+        super(GLU, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.gate = which_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+
+        out = torch.tanh(x) * torch.sigmoid(self.gate(x))
+
+        return out
+
+# GRU layer definition
+class ContForwardGRU(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers=1):
+        super(ContForwardGRU, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.hidden_size = hidden_size
+
+        # This is to initialize the layer with history audio samples for continuation.
+        self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.hidden_size, bias=False)),
+                                     nn.Tanh())
+
+        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,\
+                          bias=False)
+
+        self.nl = GLU(self.hidden_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, x0):
+
+        self.gru.flatten_parameters()
+
+        h0 = self.cont_fc(x0).unsqueeze(0)
+
+        output, h0 = self.gru(x, h0)
+
+        return self.nl(output)
+
+# Frame-wise convolution layer definition
+class ContFramewiseConv(torch.nn.Module):
+
+    def __init__(self, frame_len, out_dim, frame_kernel_size=3, act='glu', causal=True):
+
+        super(ContFramewiseConv, self).__init__()
+        torch.manual_seed(5)
+
+        self.frame_kernel_size = frame_kernel_size
+        self.frame_len = frame_len
+
+        if (causal == True) or (self.frame_kernel_size == 2):
+
+            self.required_pad_left = (self.frame_kernel_size - 1) * self.frame_len
+            self.required_pad_right = 0
+
+            # This is to initialize the layer with history audio samples for continuation.
+            self.cont_fc = nn.Sequential(which_norm(nn.Linear(320, self.required_pad_left, bias=False)),
+                                         nn.Tanh()
+                                         )
+
+        else:
+            # This means non-causal frame-wise convolution. We don't use it at the moment.
+            self.required_pad_left = (self.frame_kernel_size - 1)//2 * self.frame_len
+            self.required_pad_right = (self.frame_kernel_size - 1)//2 * self.frame_len
+
+        self.fc_input_dim = self.frame_kernel_size * self.frame_len
+        self.fc_out_dim = out_dim
+
+        if act == 'glu':
+            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
+                                    GLU(self.fc_out_dim)
+                                    )
+        if act == 'tanh':
+            self.fc = nn.Sequential(which_norm(nn.Linear(self.fc_input_dim, self.fc_out_dim, bias=False)),
+                                    nn.Tanh()
+                                    )
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
+                    isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, x0):
+
+        if self.frame_kernel_size == 1:
+            return self.fc(x)
+
+        x_flat = x.reshape(x.size(0), 1, -1)
+        pad = self.cont_fc(x0).view(x0.size(0), 1, -1)
+        x_flat_padded = torch.cat((pad, x_flat), dim=-1).unsqueeze(2)
+
+        x_flat_padded_unfolded = F.unfold(x_flat_padded,\
+            kernel_size=(1, self.fc_input_dim), stride=self.frame_len).permute(0, 2, 1).contiguous()
+
+        out = self.fc(x_flat_padded_unfolded)
+        return out
+
+########################### The complete model definition #################################
+
+class FWGAN500Cont(nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(5)
+
+        # PrecondNet:
+        self.bfcc_with_corr_upsampler = nn.Sequential(nn.ConvTranspose1d(19, 64, kernel_size=5, stride=5, padding=0,\
+                                                                         bias=False),
+                                                      nn.Tanh())
+
+        self.feat_in_conv = ConvLookahead(128, 256, kernel_size=5)
+        self.feat_in_nl = GLU(256)
+
+        # GRU:
+        self.rnn = ContForwardGRU(256, 256)
+
+        # Frame-wise convolution stack:
+        self.fwc1 = ContFramewiseConv(256, 256)
+        self.fwc2 = ContFramewiseConv(256, 128)
+        self.fwc3 = ContFramewiseConv(128, 128)
+        self.fwc4 = ContFramewiseConv(128, 64)
+        self.fwc5 = ContFramewiseConv(64, 64)
+        self.fwc6 = ContFramewiseConv(64, 32)
+        self.fwc7 = ContFramewiseConv(32, 32, act='tanh')
+
+        self.init_weights()
+        self.count_parameters()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or\
+                    isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def count_parameters(self):
+        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print(f"Total number of {self.__class__.__name__} network parameters = {num_params}\n")
+
+    def create_phase_signals(self, periods):
+
+        batch_size = periods.size(0)
+        progression = torch.arange(1, 160 + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
+        progression = torch.repeat_interleave(progression, batch_size, 0)
+
+        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
+        chunks = []
+        for sframe in range(periods.size(1)):
+            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
+
+            chunk_sin = torch.sin(f * progression + phase0)
+            chunk_sin = chunk_sin.reshape(chunk_sin.size(0), -1, 32)
+
+            chunk_cos = torch.cos(f * progression + phase0)
+            chunk_cos = chunk_cos.reshape(chunk_cos.size(0), -1, 32)
+
+            chunk = torch.cat((chunk_sin, chunk_cos), dim=-1)
+
+            # carry the phase over to the next frame so the harmonic signal stays continuous
+            phase0 = phase0 + 160 * f
+
+            chunks.append(chunk)
+
+        phase_signals = torch.cat(chunks, dim=1)
+
+        return phase_signals
+
+    def gain_multiply(self, x, c0):
+
+        gain = 10**(0.5*c0/np.sqrt(18.0))
+        gain = torch.repeat_interleave(gain, 160, dim=-1)
+        gain = gain.reshape(gain.size(0), 1, -1).squeeze(1)
+
+        return x * gain
+
+    def forward(self, pitch_period, bfcc_with_corr, x0):
+
+        # This should create a latent representation of shape [Batch_dim, 500 frames, 256 elements per frame]
+        p_embed = self.create_phase_signals(pitch_period).permute(0, 2, 1).contiguous()
+        envelope = self.bfcc_with_corr_upsampler(bfcc_with_corr.permute(0, 2, 1).contiguous())
+        feat_in = torch.cat((p_embed, envelope), dim=1)
+        wav_latent = self.feat_in_nl(self.feat_in_conv(feat_in).permute(0, 2, 1).contiguous())
+
+        # Generation with continuation using history samples x0 starts from here:
+
+        rnn_out = self.rnn(wav_latent, x0)
+
+        fwc1_out = self.fwc1(rnn_out, x0)
+        fwc2_out = self.fwc2(fwc1_out, x0)
+        fwc3_out = self.fwc3(fwc2_out, x0)
+        fwc4_out = self.fwc4(fwc3_out, x0)
+        fwc5_out = self.fwc5(fwc4_out, x0)
+        fwc6_out = self.fwc6(fwc5_out, x0)
+        fwc7_out = self.fwc7(fwc6_out, x0)
+
+        waveform_unscaled = fwc7_out.reshape(fwc7_out.size(0), 1, -1).squeeze(1)
+        waveform = self.gain_multiply(waveform_unscaled, bfcc_with_corr[:, :, :1])
+
+        return waveform
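The 500 Hz variant differs mainly in how continuation is wired: the GRU and every frame-wise layer condition directly on the 320 history samples (20 ms at 16 kHz, presumably in the same pre-emphasized, perceptually weighted domain the generator outputs) rather than on a separate cont_net latent. A minimal sketch of warm-started generation (dummy tensors; a real caller would pass the last 320 generated samples):

    import torch
    from models.fwgan500 import FWGAN500Cont

    model = FWGAN500Cont().eval()
    period = torch.full((1, 10), 120, dtype=torch.long)
    bfcc_with_corr = torch.zeros(1, 10, 19)
    history = torch.randn(1, 320)   # stand-in for the last 320 samples; zeros = cold start
    with torch.no_grad():
        waveform = model(period, bfcc_with_corr, history)   # (1, 1600)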