author    | Jan Buethe <jbuethe@amazon.de> | 2023-09-12 15:50:24 +0300
committer | Jan Buethe <jbuethe@amazon.de> | 2023-09-12 15:50:24 +0300
commit    | 2f290d32ed79ad172b5981498711a6291b1f88a2 (patch)
tree      | 59d94c2ad6bd50ddb8be872c06cf4a82f4f0e173
parent    | 7b8ba143f1a1688d4a2527ae3124c9cf65ead55a (diff)
added more enhancement stuff
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
24 files changed, 3511 insertions, 108 deletions
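For readers skimming the patch below: the two new adversarial training scripts (adv_train_model.py for SILK enhancement, adv_train_vocoder.py for LPCNet-style vocoding) share one LSGAN update pattern. The discriminator is trained to push scores on real audio toward 1 and scores on generated audio toward 0; the generator is then trained against fresh discriminator scores; and running Gaussian statistics of the two score populations feed an erfc-based "winning chance" diagnostic shown in the progress bar. The following is a minimal, self-contained sketch of that pattern distilled from the patch, not the committed code itself: the two Conv1d stacks are hypothetical stand-ins for the real generator and discriminator (which come from models.model_dict and return per-scale feature maps), and the random tensors stand in for coded/clean signal pairs.

import math

import torch

# Hypothetical toy modules; the real models are model_dict['lace'/'nolace'/
# 'lavoce'] and model_dict['fdmresdisc'] from the patch below.
gen = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 9, padding=4))
disc = torch.nn.Sequential(
    torch.nn.Conv1d(1, 16, 15, stride=4, padding=7),
    torch.nn.LeakyReLU(0.2),
    torch.nn.Conv1d(16, 1, 15, stride=4, padding=7),
)

opt_g = torch.optim.Adam(gen.parameters(), lr=5e-5)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4, betas=(0.5, 0.9))

m_r = m_f = 0.0  # running means of real/fake scores
s_r = s_f = 1.0  # running stds of real/fake scores

for _ in range(10):
    noisy = torch.randn(4, 1, 1024)  # stand-in for the coded/noisy input
    clean = torch.randn(4, 1, 1024)  # stand-in for the clean target

    # discriminator step: least-squares GAN loss, real -> 1, fake -> 0
    fake = gen(noisy)
    score_f = disc(fake.detach())
    score_r = disc(clean)
    d_loss = 0.5 * (score_f.pow(2).mean() + (1 - score_r).pow(2).mean())
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # exponentially smoothed score statistics, as in the scripts
    m_f = 0.9 * m_f + 0.1 * score_f.mean().item()
    s_f = 0.9 * s_f + 0.1 * score_f.std().item()
    m_r = 0.9 * m_r + 0.1 * score_r.mean().item()
    s_r = 0.9 * s_r + 0.1 * score_r.std().item()
    # probability that a fake score exceeds a real one, assuming Gaussian scores
    wc = 0.5 * math.erfc((m_r - m_f) / math.sqrt(2 * (s_f**2 + s_r**2)))

    # generator step: push fake scores toward 1
    g_loss = (1 - disc(fake)).pow(2).mean()
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

In the full scripts the discriminator returns a list of per-scale outputs whose last entry is the score; the generator objective additionally averages the adversarial term over scales and adds an L1 feature-matching loss over the remaining entries (weighted by lambda_feat) plus the multi-resolution STFT/time-domain regression criterion (weighted by lambda_reg).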
diff --git a/dnn/torch/osce/adv_train_model.py b/dnn/torch/osce/adv_train_model.py new file mode 100644 index 00000000..9cd32000 --- /dev/null +++ b/dnn/torch/osce/adv_train_model.py @@ -0,0 +1,458 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse +import sys +import math as m +import random + +import yaml + +from tqdm import tqdm + +try: + import git + has_git = True +except: + has_git = False + +import torch +from torch.optim.lr_scheduler import LambdaLR +import torch.nn.functional as F + +from scipy.io import wavfile +import numpy as np +import pesq + +from data import SilkEnhancementSet +from models import model_dict + + +from utils.silk_features import load_inference_data +from utils.misc import count_parameters, retain_grads, get_grad_norm, create_weights + +from losses.stft_loss import MRSTFTLoss, MRLogMelLoss + + +parser = argparse.ArgumentParser() + +parser.add_argument('setup', type=str, help='setup yaml file') +parser.add_argument('output', type=str, help='output path') +parser.add_argument('--device', type=str, help='compute device', default=None) +parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None) +parser.add_argument('--testdata', type=str, help='path to features and signal for testing', default=None) +parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout') + +args = parser.parse_args() + + +torch.set_num_threads(4) + +with open(args.setup, 'r') as f: + setup = yaml.load(f.read(), yaml.FullLoader) + +checkpoint_prefix = 'checkpoint' +output_prefix = 'output' +setup_name = 'setup.yml' +output_file='out.txt' + + +# check model +if not 'name' in setup['model']: + print(f'warning: did not find model entry in setup, using default PitchPostFilter') + model_name = 'pitchpostfilter' +else: + model_name = setup['model']['name'] + +# prepare output folder +if os.path.exists(args.output): + print("warning: output folder exists") + + reply = input('continue? (y/n): ') + while reply not in {'y', 'n'}: + reply = input('continue? 
(y/n): ') + + if reply == 'n': + os._exit() +else: + os.makedirs(args.output, exist_ok=True) + +checkpoint_dir = os.path.join(args.output, 'checkpoints') +os.makedirs(checkpoint_dir, exist_ok=True) + +# add repo info to setup +if has_git: + working_dir = os.path.split(__file__)[0] + try: + repo = git.Repo(working_dir) + setup['repo'] = dict() + hash = repo.head.object.hexsha + urls = list(repo.remote().urls) + is_dirty = repo.is_dirty() + + if is_dirty: + print("warning: repo is dirty") + + setup['repo']['hash'] = hash + setup['repo']['urls'] = urls + setup['repo']['dirty'] = is_dirty + except: + has_git = False + +# dump setup +with open(os.path.join(args.output, setup_name), 'w') as f: + yaml.dump(setup, f) + + +ref = None +if args.testdata is not None: + + testsignal, features, periods, numbits = load_inference_data(args.testdata, **setup['data']) + + inference_test = True + inference_folder = os.path.join(args.output, 'inference_test') + os.makedirs(os.path.join(args.output, 'inference_test'), exist_ok=True) + + try: + ref = np.fromfile(os.path.join(args.testdata, 'clean.s16'), dtype=np.int16) + except: + pass +else: + inference_test = False + +# training parameters +batch_size = setup['training']['batch_size'] +epochs = setup['training']['epochs'] +lr = setup['training']['lr'] +lr_decay_factor = setup['training']['lr_decay_factor'] +lr_gen = lr * setup['training']['gen_lr_reduction'] +lambda_feat = setup['training']['lambda_feat'] +lambda_reg = setup['training']['lambda_reg'] +adv_target = setup['training'].get('adv_target', 'target') + +# load training dataset +data_config = setup['data'] +data = SilkEnhancementSet(setup['dataset'], **data_config) + +# load validation dataset if given +if 'validation_dataset' in setup: + validation_data = SilkEnhancementSet(setup['validation_dataset'], **data_config) + + validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4) + + run_validation = True +else: + run_validation = False + +# create model +model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs']) + +# create discriminator +disc_name = setup['discriminator']['name'] +disc = model_dict[disc_name]( + *setup['discriminator']['args'], **setup['discriminator']['kwargs'] +) + +# set compute device +if type(args.device) == type(None): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +else: + device = torch.device(args.device) + +# dataloader +dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4) + +# optimizer is introduced to trainable parameters +parameters = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.Adam(parameters, lr=lr_gen) + +# disc optimizer +parameters = [p for p in disc.parameters() if p.requires_grad] +optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9]) + +# learning rate scheduler +scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) + +if args.initial_checkpoint is not None: + print(f"loading state dict from {args.initial_checkpoint}...") + chkpt = torch.load(args.initial_checkpoint, map_location=device) + model.load_state_dict(chkpt['state_dict']) + + if 'disc_state_dict' in chkpt: + print(f"loading discriminator state dict from {args.initial_checkpoint}...") + disc.load_state_dict(chkpt['disc_state_dict']) + + if 'optimizer_state_dict' in chkpt: + print(f"loading optimizer state dict from 
{args.initial_checkpoint}...") + optimizer.load_state_dict(chkpt['optimizer_state_dict']) + + if 'disc_optimizer_state_dict' in chkpt: + print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...") + optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict']) + + if 'scheduler_state_disc' in chkpt: + print(f"loading scheduler state dict from {args.initial_checkpoint}...") + scheduler.load_state_dict(chkpt['scheduler_state_dict']) + + # if 'torch_rng_state' in chkpt: + # print(f"setting torch RNG state from {args.initial_checkpoint}...") + # torch.set_rng_state(chkpt['torch_rng_state']) + + if 'numpy_rng_state' in chkpt: + print(f"setting numpy RNG state from {args.initial_checkpoint}...") + np.random.set_state(chkpt['numpy_rng_state']) + + if 'python_rng_state' in chkpt: + print(f"setting Python RNG state from {args.initial_checkpoint}...") + random.setstate(chkpt['python_rng_state']) + +# loss +w_l1 = setup['training']['loss']['w_l1'] +w_lm = setup['training']['loss']['w_lm'] +w_slm = setup['training']['loss']['w_slm'] +w_sc = setup['training']['loss']['w_sc'] +w_logmel = setup['training']['loss']['w_logmel'] +w_wsc = setup['training']['loss']['w_wsc'] +w_xcorr = setup['training']['loss']['w_xcorr'] +w_sxcorr = setup['training']['loss']['w_sxcorr'] +w_l2 = setup['training']['loss']['w_l2'] + +w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + +stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device) +logmelloss = MRLogMelLoss().to(device) + +def xcorr_loss(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9) + + return torch.mean(loss) + +def td_l2_norm(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6) + + return loss.mean() + +def td_l1(y_true, y_pred, pow=0): + dims = list(range(1, len(y_true.shape))) + tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow) + + return torch.mean(tmp) + +def criterion(x, y): + + return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum + + +# model checkpoint +checkpoint = { + 'setup' : setup, + 'state_dict' : model.state_dict(), + 'loss' : -1 +} + + +if not args.no_redirect: + print(f"re-directing output to {os.path.join(args.output, output_file)}") + sys.stdout = open(os.path.join(args.output, output_file), "w") + + +print("summary:") + +print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters") +if hasattr(model, 'flop_count'): + print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS") +print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters") + +if ref is not None: + noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16) + initial_mos = pesq.pesq(16000, ref, noisy, mode='wb') + print(f"initial MOS (PESQ): {initial_mos}") + +best_loss = 1e9 +log_interval = 10 + + +m_r = 0 +m_f = 0 +s_r = 1 +s_f = 1 + +def optimizer_to(optim, device): + for param in optim.state.values(): + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + if param._grad is not None: + param._grad.data = param._grad.data.to(device) + elif 
isinstance(param, dict): + for subparam in param.values(): + if isinstance(subparam, torch.Tensor): + subparam.data = subparam.data.to(device) + if subparam._grad is not None: + subparam._grad.data = subparam._grad.data.to(device) + +optimizer_to(optimizer, device) +optimizer_to(optimizer_disc, device) + +retain_grads(model) +retain_grads(disc) + +for ep in range(1, epochs + 1): + print(f"training epoch {ep}...") + + model.to(device) + disc.to(device) + model.train() + disc.train() + + running_disc_loss = 0 + running_adv_loss = 0 + running_feature_loss = 0 + running_reg_loss = 0 + running_disc_grad_norm = 0 + running_model_grad_norm = 0 + + with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch: + for i, batch in enumerate(tepoch): + + # set gradients to zero + optimizer.zero_grad() + + # push batch to device + for key in batch: + batch[key] = batch[key].to(device) + + target = batch['target'].to(device) + disc_target = batch[adv_target].to(device) + + # calculate model output + output = model(batch['signals'].permute(0, 2, 1), batch['features'], batch['periods'], batch['numbits']) + + # discriminator update + scores_gen = disc(output.detach()) + scores_real = disc(disc_target.unsqueeze(1)) + + disc_loss = 0 + for score in scores_gen: + disc_loss += (((score[-1]) ** 2)).mean() + m_f = 0.9 * m_f + 0.1 * score[-1].detach().mean().cpu().item() + s_f = 0.9 * s_f + 0.1 * score[-1].detach().std().cpu().item() + + for score in scores_real: + disc_loss += (((1 - score[-1]) ** 2)).mean() + m_r = 0.9 * m_r + 0.1 * score[-1].detach().mean().cpu().item() + s_r = 0.9 * s_r + 0.1 * score[-1].detach().std().cpu().item() + + disc_loss = 0.5 * disc_loss / len(scores_gen) + winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) ) + + disc.zero_grad() + disc_loss.backward() + + running_disc_grad_norm += get_grad_norm(disc).detach().cpu().item() + + optimizer_disc.step() + + # generator update + scores_gen = disc(output) + + # calculate loss + loss_reg = criterion(output.squeeze(1), target) + + num_discs = len(scores_gen) + gen_loss = 0 + for score in scores_gen: + gen_loss += (((1 - score[-1]) ** 2)).mean() / num_discs + + loss_feat = 0 + for k in range(num_discs): + num_layers = len(scores_gen[k]) - 1 + f = 4 / num_discs / num_layers + for l in range(num_layers): + loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach()) + + model.zero_grad() + + (gen_loss + lambda_feat * loss_feat + lambda_reg * loss_reg).backward() + + optimizer.step() + + running_model_grad_norm += get_grad_norm(model).detach().cpu().item() + running_adv_loss += gen_loss.detach().cpu().item() + running_disc_loss += disc_loss.detach().cpu().item() + running_feature_loss += lambda_feat * loss_feat.detach().cpu().item() + running_reg_loss += lambda_reg * loss_reg.detach().cpu().item() + + # update status bar + if i % log_interval == 0: + tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}", + disc_loss=f"{running_disc_loss/(i + 1):8.7f}", + feat_loss=f"{running_feature_loss/(i + 1):8.7f}", + reg_loss=f"{running_reg_loss/(i + 1):8.7f}", + model_gradnorm=f"{running_model_grad_norm/(i+1):8.7f}", + disc_gradnorm=f"{running_disc_grad_norm/(i+1):8.7f}", + wc=f"{100*winning_chance:5.2f}%") + + + # save checkpoint + checkpoint['state_dict'] = model.state_dict() + checkpoint['disc_state_dict'] = disc.state_dict() + checkpoint['optimizer_state_dict'] = optimizer.state_dict() + checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict() + checkpoint['scheduler_state_dict'] = 
scheduler.state_dict() + checkpoint['torch_rng_state'] = torch.get_rng_state() + checkpoint['numpy_rng_state'] = np.random.get_state() + checkpoint['python_rng_state'] = random.getstate() + checkpoint['adv_loss'] = running_adv_loss/(i + 1) + checkpoint['disc_loss'] = running_disc_loss/(i + 1) + checkpoint['feature_loss'] = running_feature_loss/(i + 1) + checkpoint['reg_loss'] = running_reg_loss/(i + 1) + + + if inference_test: + print("running inference test...") + out = model.process(testsignal, features, periods, numbits).cpu().numpy() + wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out) + if ref is not None: + mos = pesq.pesq(16000, ref, out, mode='wb') + print(f"MOS (PESQ): {mos}") + + + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth')) + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth')) + + + print() + +print('Done') diff --git a/dnn/torch/osce/adv_train_vocoder.py b/dnn/torch/osce/adv_train_vocoder.py new file mode 100644 index 00000000..754a1529 --- /dev/null +++ b/dnn/torch/osce/adv_train_vocoder.py @@ -0,0 +1,451 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import os +import argparse +import sys +import math as m +import random + +import yaml + +from tqdm import tqdm + +try: + import git + has_git = True +except: + has_git = False + +import torch +from torch.optim.lr_scheduler import LambdaLR +import torch.nn.functional as F + +from scipy.io import wavfile +import numpy as np +import pesq + +from data import LPCNetVocodingDataset +from models import model_dict + + +from utils.lpcnet_features import load_lpcnet_features +from utils.misc import count_parameters + +from losses.stft_loss import MRSTFTLoss, MRLogMelLoss + + +parser = argparse.ArgumentParser() + +parser.add_argument('setup', type=str, help='setup yaml file') +parser.add_argument('output', type=str, help='output path') +parser.add_argument('--device', type=str, help='compute device', default=None) +parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None) +parser.add_argument('--test-features', type=str, help='path to features for testing', default=None) +parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout') + +args = parser.parse_args() + + +torch.set_num_threads(4) + +with open(args.setup, 'r') as f: + setup = yaml.load(f.read(), yaml.FullLoader) + +checkpoint_prefix = 'checkpoint' +output_prefix = 'output' +setup_name = 'setup.yml' +output_file='out.txt' + + +# check model +if not 'name' in setup['model']: + print(f'warning: did not find model entry in setup, using default PitchPostFilter') + model_name = 'pitchpostfilter' +else: + model_name = setup['model']['name'] + +# prepare output folder +if os.path.exists(args.output): + print("warning: output folder exists") + + reply = input('continue? (y/n): ') + while reply not in {'y', 'n'}: + reply = input('continue? 
(y/n): ') + + if reply == 'n': + os._exit() +else: + os.makedirs(args.output, exist_ok=True) + +checkpoint_dir = os.path.join(args.output, 'checkpoints') +os.makedirs(checkpoint_dir, exist_ok=True) + +# add repo info to setup +if has_git: + working_dir = os.path.split(__file__)[0] + try: + repo = git.Repo(working_dir) + setup['repo'] = dict() + hash = repo.head.object.hexsha + urls = list(repo.remote().urls) + is_dirty = repo.is_dirty() + + if is_dirty: + print("warning: repo is dirty") + + setup['repo']['hash'] = hash + setup['repo']['urls'] = urls + setup['repo']['dirty'] = is_dirty + except: + has_git = False + +# dump setup +with open(os.path.join(args.output, setup_name), 'w') as f: + yaml.dump(setup, f) + + +ref = None +# prepare inference test if wanted +inference_test = False +if type(args.test_features) != type(None): + test_features = load_lpcnet_features(args.test_features) + features = test_features['features'] + periods = test_features['periods'] + inference_folder = os.path.join(args.output, 'inference_test') + os.makedirs(inference_folder, exist_ok=True) + inference_test = True + + +# training parameters +batch_size = setup['training']['batch_size'] +epochs = setup['training']['epochs'] +lr = setup['training']['lr'] +lr_decay_factor = setup['training']['lr_decay_factor'] +lr_gen = lr * setup['training']['gen_lr_reduction'] +lambda_feat = setup['training']['lambda_feat'] +lambda_reg = setup['training']['lambda_reg'] +adv_target = setup['training'].get('adv_target', 'target') + + +# load training dataset +data_config = setup['data'] +data = LPCNetVocodingDataset(setup['dataset'], **data_config) + +# load validation dataset if given +if 'validation_dataset' in setup: + validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config) + + validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4) + + run_validation = True +else: + run_validation = False + +# create model +model = model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs']) + + +# create discriminator +disc_name = setup['discriminator']['name'] +disc = model_dict[disc_name]( + *setup['discriminator']['args'], **setup['discriminator']['kwargs'] +) + + + +# set compute device +if type(args.device) == type(None): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +else: + device = torch.device(args.device) + + + +# dataloader +dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4) + +# optimizer is introduced to trainable parameters +parameters = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.Adam(parameters, lr=lr_gen) + +# disc optimizer +parameters = [p for p in disc.parameters() if p.requires_grad] +optimizer_disc = torch.optim.Adam(parameters, lr=lr, betas=[0.5, 0.9]) + +# learning rate scheduler +scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) + +if args.initial_checkpoint is not None: + print(f"loading state dict from {args.initial_checkpoint}...") + chkpt = torch.load(args.initial_checkpoint, map_location=device) + model.load_state_dict(chkpt['state_dict']) + + if 'disc_state_dict' in chkpt: + print(f"loading discriminator state dict from {args.initial_checkpoint}...") + disc.load_state_dict(chkpt['disc_state_dict']) + + if 'optimizer_state_dict' in chkpt: + print(f"loading optimizer state dict from {args.initial_checkpoint}...") + 
optimizer.load_state_dict(chkpt['optimizer_state_dict']) + + if 'disc_optimizer_state_dict' in chkpt: + print(f"loading discriminator optimizer state dict from {args.initial_checkpoint}...") + optimizer_disc.load_state_dict(chkpt['disc_optimizer_state_dict']) + + if 'scheduler_state_disc' in chkpt: + print(f"loading scheduler state dict from {args.initial_checkpoint}...") + scheduler.load_state_dict(chkpt['scheduler_state_dict']) + + # if 'torch_rng_state' in chkpt: + # print(f"setting torch RNG state from {args.initial_checkpoint}...") + # torch.set_rng_state(chkpt['torch_rng_state']) + + if 'numpy_rng_state' in chkpt: + print(f"setting numpy RNG state from {args.initial_checkpoint}...") + np.random.set_state(chkpt['numpy_rng_state']) + + if 'python_rng_state' in chkpt: + print(f"setting Python RNG state from {args.initial_checkpoint}...") + random.setstate(chkpt['python_rng_state']) + +# loss +w_l1 = setup['training']['loss']['w_l1'] +w_lm = setup['training']['loss']['w_lm'] +w_slm = setup['training']['loss']['w_slm'] +w_sc = setup['training']['loss']['w_sc'] +w_logmel = setup['training']['loss']['w_logmel'] +w_wsc = setup['training']['loss']['w_wsc'] +w_xcorr = setup['training']['loss']['w_xcorr'] +w_sxcorr = setup['training']['loss']['w_sxcorr'] +w_l2 = setup['training']['loss']['w_l2'] + +w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + +stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device) +logmelloss = MRLogMelLoss().to(device) + +def xcorr_loss(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9) + + return torch.mean(loss) + +def td_l2_norm(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6) + + return loss.mean() + +def td_l1(y_true, y_pred, pow=0): + dims = list(range(1, len(y_true.shape))) + tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow) + + return torch.mean(tmp) + +def criterion(x, y): + + return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum + + +# model checkpoint +checkpoint = { + 'setup' : setup, + 'state_dict' : model.state_dict(), + 'loss' : -1 +} + + +if not args.no_redirect: + print(f"re-directing output to {os.path.join(args.output, output_file)}") + sys.stdout = open(os.path.join(args.output, output_file), "w") + + +print("summary:") + +print(f"generator: {count_parameters(model.cpu()) / 1e6:5.3f} M parameters") +if hasattr(model, 'flop_count'): + print(f"generator: {model.flop_count(16000) / 1e6:5.3f} MFLOPS") +print(f"discriminator: {count_parameters(disc.cpu()) / 1e6:5.3f} M parameters") + +if ref is not None: + noisy = np.fromfile(os.path.join(args.testdata, 'noisy.s16'), dtype=np.int16) + initial_mos = pesq.pesq(16000, ref, noisy, mode='wb') + print(f"initial MOS (PESQ): {initial_mos}") + +best_loss = 1e9 +log_interval = 10 + + +m_r = 0 +m_f = 0 +s_r = 1 +s_f = 1 + +def optimizer_to(optim, device): + for param in optim.state.values(): + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + if param._grad is not None: + param._grad.data = param._grad.data.to(device) + elif isinstance(param, dict): + for 
subparam in param.values(): + if isinstance(subparam, torch.Tensor): + subparam.data = subparam.data.to(device) + if subparam._grad is not None: + subparam._grad.data = subparam._grad.data.to(device) + +optimizer_to(optimizer, device) +optimizer_to(optimizer_disc, device) + + +for ep in range(1, epochs + 1): + print(f"training epoch {ep}...") + + model.to(device) + disc.to(device) + model.train() + disc.train() + + running_disc_loss = 0 + running_adv_loss = 0 + running_feature_loss = 0 + running_reg_loss = 0 + + with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch: + for i, batch in enumerate(tepoch): + + # set gradients to zero + optimizer.zero_grad() + + # push batch to device + for key in batch: + batch[key] = batch[key].to(device) + + target = batch['target'].to(device) + disc_target = batch[adv_target].to(device) + + # calculate model output + output = model(batch['features'], batch['periods']) + + # discriminator update + scores_gen = disc(output.detach()) + scores_real = disc(disc_target.unsqueeze(1)) + + disc_loss = 0 + for scale in scores_gen: + disc_loss += ((scale[-1]) ** 2).mean() + m_f = 0.9 * m_f + 0.1 * scale[-1].detach().mean().cpu().item() + s_f = 0.9 * s_f + 0.1 * scale[-1].detach().std().cpu().item() + + for scale in scores_real: + disc_loss += ((1 - scale[-1]) ** 2).mean() + m_r = 0.9 * m_r + 0.1 * scale[-1].detach().mean().cpu().item() + s_r = 0.9 * s_r + 0.1 * scale[-1].detach().std().cpu().item() + + disc_loss = 0.5 * disc_loss / len(scores_gen) + winning_chance = 0.5 * m.erfc( (m_r - m_f) / m.sqrt(2 * (s_f**2 + s_r**2)) ) + + disc.zero_grad() + disc_loss.backward() + optimizer_disc.step() + + # generator update + scores_gen = disc(output) + + + # calculate loss + loss_reg = criterion(output.squeeze(1), target) + + num_discs = len(scores_gen) + loss_gen = 0 + for scale in scores_gen: + loss_gen += ((1 - scale[-1]) ** 2).mean() / num_discs + + loss_feat = 0 + for k in range(num_discs): + num_layers = len(scores_gen[k]) - 1 + f = 4 / num_discs / num_layers + for l in range(num_layers): + loss_feat += f * F.l1_loss(scores_gen[k][l], scores_real[k][l].detach()) + + model.zero_grad() + + (loss_gen + lambda_feat * loss_feat + lambda_reg * loss_reg).backward() + + optimizer.step() + + running_adv_loss += loss_gen.detach().cpu().item() + running_disc_loss += disc_loss.detach().cpu().item() + running_feature_loss += lambda_feat * loss_feat.detach().cpu().item() + running_reg_loss += lambda_reg * loss_reg.detach().cpu().item() + + # update status bar + if i % log_interval == 0: + tepoch.set_postfix(adv_loss=f"{running_adv_loss/(i + 1):8.7f}", + disc_loss=f"{running_disc_loss/(i + 1):8.7f}", + feat_loss=f"{running_feature_loss/(i + 1):8.7f}", + reg_loss=f"{running_reg_loss/(i + 1):8.7f}", + wc=f"{100*winning_chance:5.2f}%") + + + # save checkpoint + checkpoint['state_dict'] = model.state_dict() + checkpoint['disc_state_dict'] = disc.state_dict() + checkpoint['optimizer_state_dict'] = optimizer.state_dict() + checkpoint['disc_optimizer_state_dict'] = optimizer_disc.state_dict() + checkpoint['scheduler_state_dict'] = scheduler.state_dict() + checkpoint['torch_rng_state'] = torch.get_rng_state() + checkpoint['numpy_rng_state'] = np.random.get_state() + checkpoint['python_rng_state'] = random.getstate() + checkpoint['adv_loss'] = running_adv_loss/(i + 1) + checkpoint['disc_loss'] = running_disc_loss/(i + 1) + checkpoint['feature_loss'] = running_feature_loss/(i + 1) + checkpoint['reg_loss'] = running_reg_loss/(i + 1) + + + if inference_test: + print("running inference 
test...") + out = model.process(features, periods).cpu().numpy() + wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out) + if ref is not None: + mos = pesq.pesq(16000, ref, out, mode='wb') + print(f"MOS (PESQ): {mos}") + + + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth')) + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth')) + + + print() + +print('Done') diff --git a/dnn/torch/osce/data/__init__.py b/dnn/torch/osce/data/__init__.py index 9f7ea183..8df4d56a 100644 --- a/dnn/torch/osce/data/__init__.py +++ b/dnn/torch/osce/data/__init__.py @@ -1,30 +1,2 @@ -""" -/* Copyright (c) 2023 Amazon - Written by Jan Buethe */ -/* - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER - OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ -""" - -from .silk_enhancement_set import SilkEnhancementSet
\ No newline at end of file +from .silk_enhancement_set import SilkEnhancementSet +from .lpcnet_vocoding_dataset import LPCNetVocodingDataset
\ No newline at end of file diff --git a/dnn/torch/osce/data/lpcnet_vocoding_dataset.py b/dnn/torch/osce/data/lpcnet_vocoding_dataset.py new file mode 100644 index 00000000..36c8c724 --- /dev/null +++ b/dnn/torch/osce/data/lpcnet_vocoding_dataset.py @@ -0,0 +1,225 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +""" Dataset for LPCNet training """ +import os + +import yaml +import torch +import numpy as np +from torch.utils.data import Dataset + + +scale = 255.0/32768.0 +scale_1 = 32768.0/255.0 +def ulaw2lin(u): + u = u - 128 + s = np.sign(u) + u = np.abs(u) + return s*scale_1*(np.exp(u/128.*np.log(256))-1) + + +def lin2ulaw(x): + s = np.sign(x) + x = np.abs(x) + u = (s*(128*np.log(1+scale*x)/np.log(256))) + u = np.clip(128 + np.round(u), 0, 255) + return u + + +def run_lpc(signal, lpcs, frame_length=160): + num_frames, lpc_order = lpcs.shape + + prediction = np.concatenate( + [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)] + ) + error = signal[lpc_order :] - prediction + + return prediction, error + +class LPCNetVocodingDataset(Dataset): + def __init__(self, + path_to_dataset, + features=['cepstrum', 'periods', 'pitch_corr'], + target='signal', + frames_per_sample=100, + feature_history=0, + feature_lookahead=0, + lpc_gamma=1): + + super().__init__() + + # load dataset info + self.path_to_dataset = path_to_dataset + with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f: + dataset = yaml.load(f, yaml.FullLoader) + + # dataset version + self.version = dataset['version'] + if self.version == 1: + self.getitem = self.getitem_v1 + elif self.version == 2: + self.getitem = self.getitem_v2 + else: + raise ValueError(f"dataset version {self.version} unknown") + + # features + self.feature_history = feature_history + self.feature_lookahead = feature_lookahead + self.frame_offset = 2 + self.feature_history + self.frames_per_sample = frames_per_sample + self.input_features = features + self.feature_frame_layout = dataset['feature_frame_layout'] + self.lpc_gamma = lpc_gamma + + # load feature file + self.feature_file = os.path.join(path_to_dataset, dataset['feature_file']) + self.features = np.memmap(self.feature_file, 
dtype=dataset['feature_dtype']) + self.feature_frame_length = dataset['feature_frame_length'] + + assert len(self.features) % self.feature_frame_length == 0 + self.features = self.features.reshape((-1, self.feature_frame_length)) + + # derive number of samples is dataset + self.dataset_length = (len(self.features) - self.frame_offset - self.feature_lookahead - 1 - 2) // self.frames_per_sample + + # signals + self.frame_length = dataset['frame_length'] + self.signal_frame_layout = dataset['signal_frame_layout'] + self.target = target + + # load signals + self.signal_file = os.path.join(path_to_dataset, dataset['signal_file']) + self.signals = np.memmap(self.signal_file, dtype=dataset['signal_dtype']) + self.signal_frame_length = dataset['signal_frame_length'] + self.signals = self.signals.reshape((-1, self.signal_frame_length)) + assert len(self.signals) == len(self.features) * self.frame_length + + + def __getitem__(self, index): + return self.getitem(index) + + def getitem_v2(self, index): + sample = dict() + + # extract features + frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history + frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead + + for feature in self.input_features: + feature_start, feature_stop = self.feature_frame_layout[feature] + sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop] + + # convert periods + if 'periods' in self.input_features: + sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16') + + signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length + signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length + + # last_signal and signal are always expected to be there + sample['last_signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']] + sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']] + + # calculate prediction and error if lpc coefficients present and prediction not given + if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout: + # lpc coefficients with one frame lookahead + # frame positions (start one frame early for past excitation) + frame_start = self.frame_offset + self.frames_per_sample * index - 1 + frame_stop = self.frame_offset + self.frames_per_sample * (index + 1) + + # feature positions + lpc_start, lpc_stop = self.feature_frame_layout['lpc'] + lpc_order = lpc_stop - lpc_start + lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop] + + # LPC weighting + lpc_order = lpc_stop - lpc_start + weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)]) + lpcs = lpcs * weights + + # signal position (lpc_order samples as history) + signal_start = frame_start * self.frame_length - lpc_order + 1 + signal_stop = frame_stop * self.frame_length + 1 + noisy_signal = self.signals[signal_start : signal_stop, self.signal_frame_layout['last_signal']] + clean_signal = self.signals[signal_start - 1 : signal_stop - 1, self.signal_frame_layout['signal']] + + noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length) + + # extract signals + offset = self.frame_length + sample['prediction'] = noisy_prediction[offset : offset + self.frame_length * self.frames_per_sample] + sample['last_error'] = noisy_error[offset - 1 : offset - 1 + self.frame_length * self.frames_per_sample] + # calculate error between 
real signal and noisy prediction + + + sample['error'] = sample['signal'] - sample['prediction'] + + + # concatenate features + feature_keys = [key for key in self.input_features if not key.startswith("periods")] + features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1) + target = torch.FloatTensor(sample[self.target]) / 2**15 + periods = torch.LongTensor(sample['periods']) + + return {'features' : features, 'periods' : periods, 'target' : target} + + def getitem_v1(self, index): + sample = dict() + + # extract features + frame_start = self.frame_offset + index * self.frames_per_sample - self.feature_history + frame_stop = self.frame_offset + (index + 1) * self.frames_per_sample + self.feature_lookahead + + for feature in self.input_features: + feature_start, feature_stop = self.feature_frame_layout[feature] + sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop] + + # convert periods + if 'periods' in self.input_features: + sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16') + + signal_start = (self.frame_offset + index * self.frames_per_sample) * self.frame_length + signal_stop = (self.frame_offset + (index + 1) * self.frames_per_sample) * self.frame_length + + # last_signal and signal are always expected to be there + for signal_name, index in self.signal_frame_layout.items(): + sample[signal_name] = self.signals[signal_start : signal_stop, index] + + # concatenate features + feature_keys = [key for key in self.input_features if not key.startswith("periods")] + features = torch.concat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1) + signals = torch.cat([torch.LongTensor(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1) + target = torch.LongTensor(sample[self.target]) + periods = torch.LongTensor(sample['periods']) + + return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target} + + def __len__(self): + return self.dataset_length diff --git a/dnn/torch/osce/data/silk_enhancement_set.py b/dnn/torch/osce/data/silk_enhancement_set.py index 186333e9..65e97508 100644 --- a/dnn/torch/osce/data/silk_enhancement_set.py +++ b/dnn/torch/osce/data/silk_enhancement_set.py @@ -50,7 +50,7 @@ class SilkEnhancementSet(Dataset): noisy_spec_scale='opus', noisy_apply_dct=True, add_offset=False, - add_double_lag_acorr=False + add_double_lag_acorr=False, ): assert frames_per_sample % 4 == 0 @@ -75,8 +75,9 @@ class SilkEnhancementSet(Dataset): self.num_bits_smooth = np.fromfile(os.path.join(path, 'features_num_bits_smooth.f32'), dtype=np.float32) self.offsets = np.fromfile(os.path.join(path, 'features_offset.f32'), dtype=np.float32) - self.clean_signal = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16) - self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16) + self.clean_signal_hp = np.fromfile(os.path.join(path, 'clean_hp.s16'), dtype=np.int16) + self.clean_signal = np.fromfile(os.path.join(path, 'clean.s16'), dtype=np.int16) + self.coded_signal = np.fromfile(os.path.join(path, 'coded.s16'), dtype=np.int16) self.create_features = silk_feature_factory(no_pitch_value, acorr_radius, @@ -92,7 +93,7 @@ class SilkEnhancementSet(Dataset): # discard some frames to have enough signal history self.skip_frames = 4 * ((skip + self.history_len + 319) // 320 + 2) - num_frames = self.clean_signal.shape[0] // 80 - self.skip_frames + num_frames = self.clean_signal_hp.shape[0] // 80 - self.skip_frames self.len = num_frames // 
frames_per_sample @@ -107,8 +108,9 @@ class SilkEnhancementSet(Dataset): signal_start = frame_start * self.frame_size - self.skip signal_stop = frame_stop * self.frame_size - self.skip - clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15 - coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15 + clean_signal_hp = self.clean_signal_hp[signal_start : signal_stop].astype(np.float32) / 2**15 + clean_signal = self.clean_signal[signal_start : signal_stop].astype(np.float32) / 2**15 + coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15 coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15 @@ -124,6 +126,7 @@ class SilkEnhancementSet(Dataset): if self.preemph > 0: clean_signal[1:] -= self.preemph * clean_signal[: -1] + clean_signal_hp[1:] -= self.preemph * clean_signal_hp[: -1] coded_signal[1:] -= self.preemph * coded_signal[: -1] num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1) @@ -132,9 +135,10 @@ class SilkEnhancementSet(Dataset): numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1) return { - 'features' : features, - 'periods' : periods.astype(np.int64), - 'target' : clean_signal.astype(np.float32), - 'signals' : coded_signal.reshape(-1, 1).astype(np.float32), - 'numbits' : numbits.astype(np.float32) + 'features' : features, + 'periods' : periods.astype(np.int64), + 'target_orig' : clean_signal.astype(np.float32), + 'target' : clean_signal_hp.astype(np.float32), + 'signals' : coded_signal.reshape(-1, 1).astype(np.float32), + 'numbits' : numbits.astype(np.float32) } diff --git a/dnn/torch/osce/engine/vocoder_engine.py b/dnn/torch/osce/engine/vocoder_engine.py new file mode 100644 index 00000000..9eee49e4 --- /dev/null +++ b/dnn/torch/osce/engine/vocoder_engine.py @@ -0,0 +1,101 @@ +import torch +from tqdm import tqdm +import sys + +def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10): + + model.to(device) + model.train() + + running_loss = 0 + previous_running_loss = 0 + + + with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch: + + for i, batch in enumerate(tepoch): + + # set gradients to zero + optimizer.zero_grad() + + + # push batch to device + for key in batch: + batch[key] = batch[key].to(device) + + target = batch['target'] + + # calculate model output + output = model(batch['features'], batch['periods']) + + # calculate loss + if isinstance(output, list): + loss = torch.zeros(1, device=device) + for y in output: + loss = loss + criterion(target, y.squeeze(1)) + loss = loss / len(output) + else: + loss = criterion(target, output.squeeze(1)) + + # calculate gradients + loss.backward() + + # update weights + optimizer.step() + + # update learning rate + scheduler.step() + + # update running loss + running_loss += float(loss.cpu()) + + # update status bar + if i % log_interval == 0: + tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}") + previous_running_loss = running_loss + + + running_loss /= len(dataloader) + + return running_loss + +def evaluate(model, criterion, dataloader, device, log_interval=10): + + model.to(device) + model.eval() + + running_loss = 0 + previous_running_loss = 0 + + + with torch.no_grad(): + with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch: + + for i, batch in enumerate(tepoch): 
+ + + + # push batch to device + for key in batch: + batch[key] = batch[key].to(device) + + target = batch['target'] + + # calculate model output + output = model(batch['features'], batch['periods']) + + # calculate loss + loss = criterion(target, output.squeeze(1)) + + # update running loss + running_loss += float(loss.cpu()) + + # update status bar + if i % log_interval == 0: + tepoch.set_postfix(running_loss=f"{running_loss/(i + 1):8.7f}", current_loss=f"{(running_loss - previous_running_loss)/log_interval:8.7f}") + previous_running_loss = running_loss + + + running_loss /= len(dataloader) + + return running_loss
\ No newline at end of file diff --git a/dnn/torch/osce/make_default_setup.py b/dnn/torch/osce/make_default_setup.py index 2b295662..d7365fff 100644 --- a/dnn/torch/osce/make_default_setup.py +++ b/dnn/torch/osce/make_default_setup.py @@ -27,6 +27,36 @@ */ """ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import sys import argparse import yaml @@ -36,12 +66,19 @@ from utils.templates import setup_dict parser = argparse.ArgumentParser() parser.add_argument('name', type=str, help='name of default setup file') -parser.add_argument('--model', choices=['lace', 'nolace'], help='model name', default='lace') +parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace') +parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training') parser.add_argument('--path2dataset', type=str, help='dataset path', default=None) args = parser.parse_args() -setup = setup_dict[args.model] +key = args.model + "_adv" if args.adversarial else args.model + +try: + setup = setup_dict[key] +except KeyError: + print("setup not found, adversarial training possibly not specified for model") + sys.exit(1) # update dataset if given if type(args.path2dataset) != type(None): diff --git a/dnn/torch/osce/models/__init__.py b/dnn/torch/osce/models/__init__.py index 49a88ae2..c7857349 100644 --- a/dnn/torch/osce/models/__init__.py +++ b/dnn/torch/osce/models/__init__.py @@ -29,10 +29,12 @@ from .lace import LACE from .no_lace import NoLACE - - +from .lavoce import LaVoce +from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc model_dict = { 'lace': LACE, - 'nolace': NoLACE + 'nolace': NoLACE, + 'lavoce': LaVoce, + 'fdmresdisc': FDMResDisc, } diff --git a/dnn/torch/osce/models/fd_discriminator.py b/dnn/torch/osce/models/fd_discriminator.py new file mode 100644 index 00000000..22948624 --- /dev/null +++ b/dnn/torch/osce/models/fd_discriminator.py @@ -0,0 +1,974 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, 
this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import math as m +import copy + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import weight_norm, spectral_norm +import torchaudio + +from utils.spec import gen_filterbank + +# auxiliary functions + +def remove_all_weight_norms(module): + for m in module.modules(): + if hasattr(m, 'weight_v'): + nn.utils.remove_weight_norm(m) + + +def create_smoothing_kernel(h, w, gamma=1.5): + + ch = h / 2 - 0.5 + cw = w / 2 - 0.5 + + sh = gamma * ch + sw = gamma * cw + + vx = ((torch.arange(h) - ch) / sh) ** 2 + vy = ((torch.arange(w) - cw) / sw) ** 2 + vals = vx.view(-1, 1) + vy.view(1, -1) + kernel = torch.exp(- vals) + kernel = kernel / kernel.sum() + + return kernel + + +def create_kernel(h, w, sh, sw): + # proto kernel gives disjoint partition of 1 + proto_kernel = torch.ones((sh, sw)) + + # create smoothing kernel eta + h_eta, w_eta = h - sh + 1, w - sw + 1 + assert h_eta > 0 and w_eta > 0 + eta = create_smoothing_kernel(h_eta, w_eta).view(1, 1, h_eta, w_eta) + + kernel0 = F.pad(proto_kernel, [w_eta - 1, w_eta - 1, h_eta - 1, h_eta - 1]).unsqueeze(0).unsqueeze(0) + kernel = F.conv2d(kernel0, eta) + + return kernel + +# positional embeddings +class FrequencyPositionalEmbedding(nn.Module): + def __init__(self): + + super().__init__() + + def forward(self, x): + + N = x.size(2) + args = torch.arange(0, N, dtype=x.dtype, device=x.device) * torch.pi * 2 / N + cos = torch.cos(args).reshape(1, 1, -1, 1) + sin = torch.sin(args).reshape(1, 1, -1, 1) + zeros = torch.zeros_like(x[:, 0:1, :, :]) + + y = torch.cat((x, zeros + sin, zeros + cos), dim=1) + + return y + + +class PositionalEmbedding2D(nn.Module): + def __init__(self, d=5): + + super().__init__() + + self.d = d + + def forward(self, x): + + N = x.size(2) + M = x.size(3) + + h_args = torch.arange(0, N, dtype=x.dtype, device=x.device).reshape(1, 1, -1, 1) + w_args = torch.arange(0, M, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1) + coeffs = (10000 ** (-2 * torch.arange(0, self.d, dtype=x.dtype, device=x.device) / self.d)).reshape(1, -1, 1, 1) + + h_sin = torch.sin(coeffs * h_args) + h_cos = torch.sin(coeffs * h_args) + w_sin = torch.sin(coeffs * w_args) + w_cos = torch.sin(coeffs * w_args) + + zeros = torch.zeros_like(x[:, 0:1, :, :]) + + y = torch.cat((x, zeros + h_sin, zeros + h_cos, zeros + w_sin, zeros + w_cos), dim=1) + + return y + + +# spectral discriminator base class +class SpecDiscriminatorBase(nn.Module): + RECEPTIVE_FIELD_MAX_WIDTH=10000 + def 
__init__(self, + layers, + resolution, + fs=16000, + freq_roi=[50, 7000], + noise_gain=1e-3, + fmap_start_index=0 + ): + super().__init__() + + + self.layers = nn.ModuleList(layers) + self.resolution = resolution + self.fs = fs + self.noise_gain = noise_gain + self.fmap_start_index = fmap_start_index + + if fmap_start_index >= len(layers): + raise ValueError(f'fmap_start_index is larger than number of layers') + + # filter bank for noise shaping + n_fft = resolution[0] + + self.filterbank = nn.Parameter( + gen_filterbank(n_fft // 2, fs, keep_size=True), + requires_grad=False + ) + + # roi bins + f_step = fs / n_fft + self.start_bin = int(m.ceil(freq_roi[0] / f_step - 0.01)) + self.stop_bin = min(int(m.floor(freq_roi[1] / f_step + 0.01)), n_fft//2 + 1) + + self.init_weights() + + # determine receptive field size, offsets and strides + + hw = 1000 + while True: + x = torch.zeros((1, hw, hw)) + with torch.no_grad(): + y = self.run_layer_stack(x)[-1] + + pos0 = [y.size(-2) // 2, y.size(-1) // 2] + pos1 = [t + 1 for t in pos0] + + hs0, ws0 = self._receptive_field((hw, hw), pos0) + hs1, ws1 = self._receptive_field((hw, hw), pos1) + + h0 = hs0[1] - hs0[0] + 1 + h1 = hs1[1] - hs1[0] + 1 + w0 = ws0[1] - ws0[0] + 1 + w1 = ws1[1] - ws1[0] + 1 + + if h0 != h1 or w0 != w1: + hw = 2 * hw + else: + + # strides + sh = hs1[0] - hs0[0] + sw = ws1[0] - ws0[0] + + if sh == 0 or sw == 0: continue + + # offsets + oh = hs0[0] - sh * pos0[0] + ow = ws0[0] - sw * pos0[1] + + # overlap factor + overlap = w0 / sw + h0 / sh + + #print(f"{w0=} {h0=} {sw=} {sh=} {overlap=}") + self.receptive_field_params = {'width': [sw, ow, w0], 'height': [sh, oh, h0], 'overlap': overlap} + + break + + if hw > self.RECEPTIVE_FIELD_MAX_WIDTH: + print("warning: exceeded max size while trying to determine receptive field") + + # create transposed convolutional kernel + #self.tconv_kernel = nn.Parameter(create_kernel(h0, w0, sw, sw), requires_grad=False) + + def run_layer_stack(self, spec): + + output = [] + + x = spec.unsqueeze(1) + + for layer in self.layers: + x = layer(x) + output.append(x) + + return output + + def forward(self, x): + """ returns array with feature maps and final score at index -1 """ + + output = [] + + x = self.spectrogram(x) + + output = self.run_layer_stack(x) + + return output[self.fmap_start_index:] + + def receptive_field(self, output_pos): + + if self.receptive_field_params is not None: + s, o, h = self.receptive_field_params['height'] + h_min = output_pos[0] * s + o + self.start_bin + h_max = h_min + h + h_min = max(h_min, self.start_bin) + h_max = min(h_max, self.stop_bin) + + s, o, w = self.receptive_field_params['width'] + w_min = output_pos[1] * s + o + w_max = w_min + w + + return (h_min, h_max), (w_min, w_max) + + else: + return None, None + + + def _receptive_field(self, input_dims, output_pos): + """ determines receptive field probabilistically via autograd (slow) """ + + x = torch.randn((1,) + input_dims, requires_grad=True) + + # run input through layers + y = self.run_layer_stack(x)[-1] + b, c, h, w = y.shape + + if output_pos[0] >= h or output_pos[1] >= w: + raise ValueError("position out of range") + + mask = torch.zeros((b, c, h, w)) + mask[0, 0, output_pos[0], output_pos[1]] = 1 + + (mask * y).sum().backward() + + hs, ws = torch.nonzero(x.grad[0], as_tuple=True) + + h_min, h_max = hs.min().item(), hs.max().item() + w_min, w_max = ws.min().item(), ws.max().item() + + return [h_min, h_max], [w_min, w_max] + + + + def init_weights(self): + + for m in self.modules(): + if isinstance(m, nn.Conv1d) 
or isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding): + nn.init.orthogonal_(m.weight.data) + + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = x.squeeze(1) + window = getattr(torch, 'hann_window')(win_length).to(x.device) + + x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\ + window=window, return_complex=True) #[B, F, T] + x = torch.abs(x) + + # noise floor following spectral envelope + smoothed_x = torch.matmul(self.filterbank, x) + noise = torch.randn_like(x) * smoothed_x * self.noise_gain + x = x + noise + + # frequency ROI + x = x[:, self.start_bin : self.stop_bin + 1, ...] + + return torchaudio.functional.amplitude_to_DB(x,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80)#torch.sqrt(x) + + def grad_map(self, x): + self.zero_grad() + + n_fft, hop_length, win_length = self.resolution + + window = getattr(torch, 'hann_window')(win_length).to(x.device) + + y = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length, win_length=win_length, + window=window, return_complex=True) #[B, F, T] + y = torch.abs(y) + + specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80) + + specgram.requires_grad = True + specgram.retain_grad() + + if specgram.grad is not None: + specgram.grad.zero_() + + y = specgram[:, self.start_bin : self.stop_bin + 1, ...] + + scores = self.run_layer_stack(y)[-1] + + loss = torch.mean((1 - scores) ** 2) + loss.backward() + + return specgram.data[0], torch.abs(specgram.grad)[0] + + def relevance_map(self, x): + + n_fft, hop_length, win_length = self.resolution + y = x.view(-1) + window = getattr(torch, 'hann_window')(win_length).to(x.device) + + y = torch.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length,\ + window=window, return_complex=True) #[B, F, T] + y = torch.abs(y) + + specgram = torchaudio.functional.amplitude_to_DB(y,db_multiplier=0.0, multiplier=20,amin=1e-05,top_db=80) + + + scores = self.forward(x)[-1] + + sh, _, h = self.receptive_field_params['height'] + sw, _, w = self.receptive_field_params['width'] + kernel = create_kernel(h, w, sh, sw).float().to(scores.device) + with torch.no_grad(): + pad_w = (w + sw - 1) // sw + pad_h = (h + sh - 1) // sh + padded_scores = F.pad(scores, (pad_w, pad_w, pad_h, pad_h), mode='replicate') + # CAVE: padding should be derived from offsets + rv = F.conv_transpose2d(padded_scores, kernel, bias=None, stride=(sh, sw), padding=(h//2, w//2)) + rv = rv[..., pad_h * sh : - pad_h * sh, pad_w * sw : -pad_w * sw] + + relevance = torch.zeros_like(specgram) + relevance[..., self.start_bin : self.start_bin + rv.size(-2), : rv.size(-1)] = rv + + + return specgram, relevance + + + def lrp(self, x, eps=1e-9, label='both', threshold=0.5, low=None, high=None, verbose=False): + """ layer-wise relevance propagation (https://git.tu-berlin.de/gmontavon/lrp-tutorial) """ + + # ToDo: this code is highly unsafe as it assumes that layers are nn.Sequential with suitable activations + + def newconv2d(layer,g): + + new_layer = nn.Conv2d(layer.in_channels, + layer.out_channels, + layer.kernel_size, + stride=layer.stride, + padding=layer.padding, + dilation=layer.dilation, + groups=layer.groups) + + try: new_layer.weight = nn.Parameter(g(layer.weight.data.clone())) + except AttributeError: pass + + try: new_layer.bias = nn.Parameter(g(layer.bias.data.clone())) + except AttributeError: pass + + return new_layer + + bounds = { + 64: [-85.82449722290039, 2.1755014657974243], + 
128: [-84.49211349487305, 3.5078893899917607], + 256: [-80.33127822875977, 7.6687201976776125], + 512: [-73.79328079223633, 14.20672025680542], + 1024: [-67.59239501953125, 20.40760498046875], + 2048: [-62.31902580261231, 25.680974197387698], + } + + nfft = self.resolution[0] + if low is None: low = bounds[nfft][0] + if high is None: high = bounds[nfft][1] + + remove_all_weight_norms(self) + + for p in self.parameters(): + if p.grad is not None: + p.grad.zero_() + + num_layers = len(self.layers) + X = self.spectrogram(x). detach() + + + # forward pass + A = [X.unsqueeze(1)] + [None] * len(self.layers) + + for i in range(num_layers - 1): + A[i + 1] = self.layers[i](A[i]) + + # initial relevance is last layer without activation + r = A[-2] + last_layer_rs = [r] + layer = self.layers[-1] + for sublayer in list(layer)[:-1]: + r = sublayer(r) + last_layer_rs.append(r) + + + mask = torch.zeros_like(r) + mask.requires_grad_(False) + if verbose: + print(r.min(), r.max()) + if label in {'both', 'fake'}: + mask[r < -threshold] = 1 + if label in {'both', 'real'}: + mask[r > threshold] = 1 + r = r * mask + + # backward pass + R = [None] * num_layers + [r] + + for l in range(1, num_layers)[::-1]: + A[l] = (A[l]).data.requires_grad_(True) + + layer = nn.Sequential(*(list(self.layers[l])[:-1])) + z = layer(A[l]) + eps + s = (R[l+1] / z).data + (z*s).sum().backward() + c = A[l].grad + R[l] = (A[l] * c).data + + # first layer + A[0] = (A[0].data).requires_grad_(True) + + Xl = (torch.zeros_like(A[0].data) + low).requires_grad_(True) + Xh = (torch.zeros_like(A[0].data) + high).requires_grad_(True) + + if len(list(self.layers)) > 2: + # unsafe way to check for embedding layer + embed = list(self.layers[0])[0] + conv = list(self.layers[0])[1] + + layer = nn.Sequential(embed, conv) + layerl = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(min=0))) + layerh = nn.Sequential(embed, newconv2d(conv, lambda p: p.clamp(max=0))) + + else: + layer = list(self.layers[0])[0] + layerl = newconv2d(layer, lambda p: p.clamp(min=0)) + layerh = newconv2d(layer, lambda p: p.clamp(max=0)) + + + z = layer(A[0]) + z -= layerl(Xl) + layerh(Xh) + s = (R[1] / z).data + (z * s).sum().backward() + c, cp, cm = A[0].grad, Xl.grad, Xh.grad + + R[0] = (A[0] * c + Xl * cp + Xh * cm) + #R[0] = (A[0] * c).data + + return X, R[0].mean(dim=1) + + + + + + + + + + +def create_3x3_conv_plan(num_layers : int, + f_stretch : int, + f_down : int, + t_stretch : int, + t_down : int + ): + + + """ creates a stride, dilation, padding plan for a 2d conv network + + Args: + num_layers (int): number of layers + f_stretch (int): log_2 of stretching factor along frequency axis + f_down (int): log_2 of downsampling factor along frequency axis + t_stretch (int): log_2 of stretching factor along time axis + t_down (int): log_2 of downsampling factor along time axis + + Returns: + list(list(tuple)): list containing entries [(stride_t, stride_f), (dilation_t, dilation_f), (padding_t, padding_f)] + """ + + assert num_layers > 0 and t_stretch >= 0 and t_down >= 0 and f_stretch >= 0 and f_down >= 0 + assert f_stretch < num_layers and t_stretch < num_layers + + def process_dimension(n_layers, stretch, down): + + stack_layers = n_layers - 1 + + stride_layers = min(min(down, stretch) , stack_layers) + dilation_layers = max(min(stack_layers - stride_layers - 1, stretch - stride_layers), 0) + final_stride = 2 ** (max(down - stride_layers, 0)) + + final_dilation = 1 + if stride_layers < stack_layers and stretch - stride_layers - dilation_layers > 0: + final_dilation 
= 2
+
+        strides, dilations, paddings = [], [], []
+        processed_layers = 0
+        current_dilation = 1
+
+        for _ in range(stride_layers):
+            # increase receptive field and downsample via stride = 2
+            strides.append(2)
+            dilations.append(1)
+            paddings.append(1)
+            processed_layers += 1
+
+        if processed_layers < stack_layers:
+            strides.append(1)
+            dilations.append(1)
+            paddings.append(1)
+            processed_layers += 1
+
+        for _ in range(dilation_layers):
+            # increase receptive field via dilation = 2
+            strides.append(1)
+            current_dilation *= 2
+            dilations.append(current_dilation)
+            paddings.append(current_dilation)
+            processed_layers += 1
+
+        while processed_layers < n_layers - 1:
+            # fill up with standard layers
+            strides.append(1)
+            dilations.append(current_dilation)
+            paddings.append(current_dilation)
+            processed_layers += 1
+
+        # final layer
+        strides.append(final_stride)
+        current_dilation *= final_dilation
+        dilations.append(current_dilation)
+        paddings.append(current_dilation)
+        processed_layers += 1
+
+        assert processed_layers == n_layers
+
+        return strides, dilations, paddings
+
+    t_strides, t_dilations, t_paddings = process_dimension(num_layers, t_stretch, t_down)
+    f_strides, f_dilations, f_paddings = process_dimension(num_layers, f_stretch, f_down)
+
+    plan = []
+
+    for i in range(num_layers):
+        plan.append([
+            (f_strides[i], t_strides[i]),
+            (f_dilations[i], t_dilations[i]),
+            (f_paddings[i], t_paddings[i]),
+        ])
+
+    return plan
+
+
+class DiscriminatorExperimental(SpecDiscriminatorBase):
+
+    def __init__(self,
+                 resolution,
+                 fs=16000,
+                 freq_roi=[50, 7400],
+                 noise_gain=0,
+                 num_channels=16,
+                 max_channels=512,
+                 num_layers=5,
+                 use_spectral_norm=False):
+
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        self.num_channels = num_channels
+        self.num_channels_max = max_channels
+        self.num_layers = num_layers
+
+        layers = []
+        stride = (2, 1)
+        padding = (1, 1)
+        in_channels = 1 + 2
+        out_channels = self.num_channels
+        for _ in range(self.num_layers):
+            layers.append(
+                nn.Sequential(
+                    FrequencyPositionalEmbedding(),
+                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
+                    nn.ReLU(inplace=True)
+                )
+            )
+            in_channels = out_channels + 2
+            out_channels = min(2 * out_channels, self.num_channels_max)
+
+        layers.append(
+            nn.Sequential(
+                FrequencyPositionalEmbedding(),
+                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)),
+                nn.Sigmoid()
+            )
+        )
+
+        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
+
+        # bias the biases towards positive values (in-place update)
+        bias_val = 0.1
+        with torch.no_grad():
+            for name, weight in self.named_parameters():
+                if 'bias' in name:
+                    weight += bias_val
+
+
+configs = {
+    'f_down': {
+        'stretch' : {
+            64 : (0, 0),
+            128: (1, 0),
+            256: (2, 0),
+            512: (3, 0),
+            1024: (4, 0),
+            2048: (5, 0)
+        },
+        'down' : {
+            64 : (0, 0),
+            128: (1, 0),
+            256: (2, 0),
+            512: (3, 0),
+            1024: (4, 0),
+            2048: (5, 0)
+        }
+    },
+    'ft_down': {
+        'stretch' : {
+            64 : (0, 4),
+            128: (1, 3),
+            256: (2, 2),
+            512: (3, 1),
+            1024: (4, 0),
+            2048: (5, 0)
+        },
+        'down' : {
+            64 : (0, 4),
+            128: (1, 3),
+            256: (2, 2),
+            512: (3, 1),
+            1024: (4, 0),
+            2048: (5, 0)
+        }
+    },
+    'dilated': {
+        'stretch' : {
+            64 : (0, 4),
+            128: (1, 3),
+            256: (2, 2),
+            512: (3, 1),
+            1024: (4, 0),
+            2048: (5, 0)
+        },
+        'down' : {
+            64 : (0, 0),
+            128: (0, 0),
+            256: (0, 0),
+            512: (0, 0),
+            1024: (0, 0),
+            2048: (0, 0)
+        }
+    },
+    'mixed': {
+        'stretch' : {
+            64 : (0, 4),
+            128: (1, 3),
+            256: (2, 2),
+            512: (3, 1),
+            1024: (4, 0),
+            2048: (5, 0)
+        },
+        'down' : {
+            64 : (0, 0),
+            128: (1, 0),
+            256: (2, 0),
+            512: (3, 0),
+            1024: (4, 0),
+            2048: (5, 0)
+        }
+    },
+}
+
+
+class DiscriminatorMagFree(SpecDiscriminatorBase):
+
+    def __init__(self,
+                 resolution,
+                 fs=16000,
+                 freq_roi=[50, 7400],
+                 noise_gain=0,
+                 num_channels=16,
+                 max_channels=256,
+                 num_layers=5,
+                 use_spectral_norm=False,
+                 design=None):
+
+        if design is None:
+            raise ValueError('error: design required in DiscriminatorMagFree')
+
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        stretch = configs[design]['stretch'][resolution[0]]
+        down = configs[design]['down'][resolution[0]]
+
+        self.num_channels = num_channels
+        self.num_channels_max = max_channels
+        self.num_layers = num_layers
+        self.stretch = stretch
+        self.down = down
+
+        layers = []
+        plan = create_3x3_conv_plan(num_layers + 1, stretch[0], down[0], stretch[1], down[1])
+        in_channels = 1 + 2
+        out_channels = self.num_channels
+        for i in range(self.num_layers):
+            layers.append(
+                nn.Sequential(
+                    FrequencyPositionalEmbedding(),
+                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=plan[i][0], dilation=plan[i][1], padding=plan[i][2])),
+                    nn.ReLU(inplace=True)
+                )
+            )
+            in_channels = out_channels + 2
+            # product over strides
+            channel_factor = plan[i][0][0] * plan[i][0][1]
+            out_channels = min(channel_factor * out_channels, self.num_channels_max)
+
+        layers.append(
+            nn.Sequential(
+                FrequencyPositionalEmbedding(),
+                norm_f(nn.Conv2d(in_channels, 1, (3, 3), stride=plan[-1][0], dilation=plan[-1][1], padding=plan[-1][2])),
+                nn.Sigmoid()
+            )
+        )
+
+        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
+
+        # bias the biases towards positive values (in-place update)
+        bias_val = 0.1
+        with torch.no_grad():
+            for name, weight in self.named_parameters():
+                if 'bias' in name:
+                    weight += bias_val
+
+class DiscriminatorMagFreqPosition(SpecDiscriminatorBase):
+
+    def __init__(self,
+                 resolution,
+                 fs=16000,
+                 freq_roi=[50, 7400],
+                 noise_gain=0,
+                 num_channels=16,
+                 max_channels=512,
+                 num_layers=5,
+                 use_spectral_norm=False):
+
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        self.num_channels = num_channels
+        self.num_channels_max = max_channels
+        self.num_layers = num_layers
+
+        layers = []
+        stride = (2, 1)
+        padding = (1, 1)
+        in_channels = 1 + 2
+        out_channels = self.num_channels
+        for _ in range(self.num_layers):
+            layers.append(
+                nn.Sequential(
+                    FrequencyPositionalEmbedding(),
+                    norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)),
+                    nn.LeakyReLU(0.2, inplace=True)
+                )
+            )
+            in_channels = out_channels + 2
+            out_channels = min(2 * out_channels, self.num_channels_max)
+
+        layers.append(
+            nn.Sequential(
+                FrequencyPositionalEmbedding(),
+                norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))
+            )
+        )
+
+        super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain)
+
+
+
+class DiscriminatorMag2dPositional(SpecDiscriminatorBase):
+
+    def __init__(self,
+                 resolution,
+                 fs=16000,
+                 freq_roi=[50, 7400],
+                 noise_gain=0,
+                 num_channels=16,
+                 max_channels=512,
+                 num_layers=5,
+                 d=5,
+                 use_spectral_norm=False):
+
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        self.resolution = resolution
+        self.num_channels = num_channels
+        self.num_channels_max = max_channels
+        self.num_layers = num_layers
+        self.d = d
+        embedding_dim = 4 * d
+
+
+        layers = []
+        stride = (2, 2)
+        padding = (1,
1) + in_channels = 1 + embedding_dim + out_channels = self.num_channels + for _ in range(self.num_layers): + layers.append( + nn.Sequential( + PositionalEmbedding2D(d), + norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)), + nn.LeakyReLU(0.2, inplace=True) + ) + ) + in_channels = out_channels + embedding_dim + out_channels = min(2 * out_channels, self.num_channels_max) + + + layers.append( + nn.Sequential( + PositionalEmbedding2D(), + norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding)) + ) + ) + + super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain) + + + +class DiscriminatorMag(SpecDiscriminatorBase): + def __init__(self, + resolution, + fs=16000, + freq_roi=[50, 7400], + noise_gain=0, + num_channels=32, + num_layers=5, + use_spectral_norm=False): + + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + + self.num_channels = num_channels + self.num_layers = num_layers + + layers = [] + stride = (1, 1) + padding= (1, 1) + in_channels = 1 + out_channels = self.num_channels + for _ in range(self.num_layers): + layers.append( + nn.Sequential( + norm_f(nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=padding)), + nn.LeakyReLU(0.2, inplace=True) + ) + ) + in_channels = out_channels + + layers.append(norm_f(nn.Conv2d(in_channels, 1, (3, 3), padding=padding))) + + super().__init__(layers=layers, resolution=resolution, fs=fs, freq_roi=freq_roi, noise_gain=noise_gain) + + +discriminators = { + 'mag': DiscriminatorMag, + 'freqpos': DiscriminatorMagFreqPosition, + '2dpos': DiscriminatorMag2dPositional, + 'experimental': DiscriminatorExperimental, + 'free': DiscriminatorMagFree +} + +class TFDMultiResolutionDiscriminator(torch.nn.Module): + def __init__(self, + fft_sizes_16k=[64, 128, 256, 512, 1024, 2048], + architecture='mag', + fs=16000, + freq_roi=[50, 7400], + noise_gain=0, + use_spectral_norm=False, + **kwargs): + + super().__init__() + + + fft_sizes = [int(round(fft_size_16k * fs / 16000)) for fft_size_16k in fft_sizes_16k] + + resolutions = [[n_fft, n_fft // 4, n_fft] for n_fft in fft_sizes] + + + Disc = discriminators[architecture] + + discs = [Disc(resolutions[i], fs=fs, freq_roi=freq_roi, noise_gain=noise_gain, use_spectral_norm=use_spectral_norm, **kwargs) for i in range(len(resolutions))] + + self.discriminators = nn.ModuleList(discs) + + def forward(self, y): + outputs = [] + + for disc in self.discriminators: + outputs.append(disc(y)) + + return outputs + + +class FWGAN_disc_wrapper(nn.Module): + def __init__(self, disc): + super().__init__() + + self.disc = disc + + def forward(self, y, y_hat): + + out_real = self.disc(y) + out_fake = self.disc(y_hat) + + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for y_real, y_fake in zip(out_real, out_fake): + y_d_rs.append(y_real[-1]) + y_d_gs.append(y_fake[-1]) + fmap_rs.append(y_real[:-1]) + fmap_gs.append(y_fake[:-1]) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/dnn/torch/osce/models/lavoce.py b/dnn/torch/osce/models/lavoce.py new file mode 100644 index 00000000..1a9dc871 --- /dev/null +++ b/dnn/torch/osce/models/lavoce.py @@ -0,0 +1,254 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + + +import torch +from torch import nn +import torch.nn.functional as F + +import numpy as np + +from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d +from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d +from utils.layers.td_shaper import TDShaper +from utils.layers.noise_shaper import NoiseShaper +from utils.complexity import _conv1d_flop_count +from utils.endoscopy import write_data + +from models.nns_base import NNSBase +from models.lpcnet_feature_net import LPCNetFeatureNet +from .scale_embedding import ScaleEmbedding + +class LaVoce(nn.Module): + """ Linear-Adaptive VOCodEr """ + FEATURE_FRAME_SIZE=160 + FRAME_SIZE=80 + + def __init__(self, + num_features=20, + pitch_embedding_dim=64, + cond_dim=256, + pitch_max=300, + kernel_size=15, + preemph=0.85, + comb_gain_limit_db=-6, + global_gain_limits_db=[-6, 6], + conv_gain_limits_db=[-6, 6], + norm_p=2, + avg_pool_k=4, + pulses=False): + + super().__init__() + + + self.num_features = num_features + self.cond_dim = cond_dim + self.pitch_max = pitch_max + self.pitch_embedding_dim = pitch_embedding_dim + self.kernel_size = kernel_size + self.preemph = preemph + self.pulses = pulses + + assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0 + self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE + + # pitch embedding + self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim) + + # feature net + self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor) + + # noise shaper + self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE) + + # comb filters + left_pad = self.kernel_size // 2 + right_pad = self.kernel_size - 1 - left_pad + self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + + + self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af_mix = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, 
frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + + # spectral shaping + self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + + # non-linear transforms + self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True) + self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k) + self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k) + + # combinators + self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + + # feature transforms + self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2) + self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2) + self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2) + self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2) + self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2) + + + def create_phase_signals(self, periods, pulses=False): + + batch_size = periods.size(0) + progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1)) + progression = torch.repeat_interleave(progression, batch_size, 0) + + phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1) + chunks = [] + for sframe in range(periods.size(1)): + f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1) + + if pulses: + alpha = torch.cos(f) + chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE) + pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha) + pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha) + + chunk = torch.cat((pulse_a, pulse_b), dim = 1) + else: + chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE) + chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE) + + chunk = torch.cat((chunk_sin, chunk_cos), dim = 1) + + phase0 = phase0 + self.FRAME_SIZE * f + + chunks.append(chunk) + + phase_signals = torch.cat(chunks, dim=-1) + + return phase_signals + + def flop_count(self, rate=16000, verbose=False): + + frame_rate = rate / self.FRAME_SIZE + + # feature net + feature_net_flops = self.feature_net.flop_count(frame_rate) + comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate) + af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_prescale.flop_count(rate) + self.af_mix.flop_count(rate) + feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate) + + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate)) + + if verbose: + print(f"feature net: {feature_net_flops / 1e6} MFLOPS") + print(f"comb filters: {comb_flops / 1e6} MFLOPS") + print(f"adaptive conv: {af_flops / 
1e6} MFLOPS") + print(f"feature transforms: {feature_flops / 1e6} MFLOPS") + + return feature_net_flops + comb_flops + af_flops + feature_flops + + def feature_transform(self, f, layer): + f = f.permute(0, 2, 1) + f = F.pad(f, [1, 0]) + f = torch.tanh(layer(f)) + return f.permute(0, 2, 1) + + def forward(self, features, periods, debug=False): + + periods = periods.squeeze(-1) + pitch_embedding = self.pitch_embedding(periods) + + full_features = torch.cat((features, pitch_embedding), dim=-1) + cf = self.feature_net(full_features) + + # upsample periods + periods = torch.repeat_interleave(periods, self.upsamp_factor, 1) + + # pre-net + ref_phase = torch.tanh(self.create_phase_signals(periods)) + x = self.af_prescale(ref_phase, cf) + noise = self.noise_shaper(cf) + y = self.af_mix(torch.cat((x, noise), dim=1), cf) + + if debug: + ch0 = y[0,0,:].detach().cpu().numpy() + ch1 = y[0,1,:].detach().cpu().numpy() + ch0 = (2**15 * ch0 / np.max(ch0)).astype(np.int16) + ch1 = (2**15 * ch1 / np.max(ch1)).astype(np.int16) + write_data('prior_channel0', ch0, 16000) + write_data('prior_channel1', ch1, 16000) + + # temporal shaping + innovating + y1 = y[:, 0:1, :] + y2 = self.tdshape1(y[:, 1:2, :], cf) + y = torch.cat((y1, y2), dim=1) + y = self.af2(y, cf, debug=debug) + cf = self.feature_transform(cf, self.post_af2) + + y1 = y[:, 0:1, :] + y2 = self.tdshape2(y[:, 1:2, :], cf) + y = torch.cat((y1, y2), dim=1) + y = self.af3(y, cf, debug=debug) + cf = self.feature_transform(cf, self.post_af3) + + # spectral shaping + y = self.cf1(y, cf, periods, debug=debug) + cf = self.feature_transform(cf, self.post_cf1) + + y = self.cf2(y, cf, periods, debug=debug) + cf = self.feature_transform(cf, self.post_cf2) + + y = self.af1(y, cf, debug=debug) + cf = self.feature_transform(cf, self.post_af1) + + # final temporal env adjustment + y1 = y[:, 0:1, :] + y2 = self.tdshape3(y[:, 1:2, :], cf) + y = torch.cat((y1, y2), dim=1) + y = self.af4(y, cf, debug=debug) + + return y + + def process(self, features, periods, debug=False): + + self.eval() + device = next(iter(self.parameters())).device + with torch.no_grad(): + + # run model + f = features.unsqueeze(0).to(device) + p = periods.unsqueeze(0).to(device) + + y = self.forward(f, p, debug=debug).squeeze() + + # deemphasis + if self.preemph > 0: + for i in range(len(y) - 1): + y[i + 1] += self.preemph * y[i] + + # clip to valid range + out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short() + + return out
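As a side note on create_phase_signals above: the generator keeps the oscillator phase continuous across frames by accumulating the per-frame instantaneous frequency. A minimal, self-contained sketch of that mechanism, distilled from the method above (the frame size and period values are arbitrary examples):

import torch

FRAME_SIZE = 80
periods = torch.tensor([[100, 100, 50]])               # one batch item, three frames
prog = torch.arange(1, FRAME_SIZE + 1).float()

phase0 = torch.zeros(1)
chunks = []
for k in range(periods.size(1)):
    f = 2.0 * torch.pi / periods[:, k].unsqueeze(-1)   # radians per sample
    chunks.append(torch.sin(f * prog + phase0))        # (1, 80), continuous with previous frame
    phase0 = phase0 + FRAME_SIZE * f                   # carry accumulated phase forward

x = torch.cat(chunks, dim=-1)                          # (1, 240) sinusoid without phase jumps

Because phase0 is advanced by FRAME_SIZE * f at every step, a period change between frames bends the waveform instead of producing a discontinuity.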
\ No newline at end of file diff --git a/dnn/torch/osce/models/lpcnet_feature_net.py b/dnn/torch/osce/models/lpcnet_feature_net.py new file mode 100644 index 00000000..b637d748 --- /dev/null +++ b/dnn/torch/osce/models/lpcnet_feature_net.py @@ -0,0 +1,91 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import torch +from torch import nn +import torch.nn.functional as F + +from utils.complexity import _conv1d_flop_count + +class LPCNetFeatureNet(nn.Module): + + def __init__(self, + feature_dim=84, + num_channels=256, + upsamp_factor=2, + lookahead=True): + + super().__init__() + + self.feature_dim = feature_dim + self.num_channels = num_channels + self.upsamp_factor = upsamp_factor + self.lookahead = lookahead + + self.conv1 = nn.Conv1d(feature_dim, num_channels, 3) + self.conv2 = nn.Conv1d(num_channels, num_channels, 3) + + self.gru = nn.GRU(num_channels, num_channels, batch_first=True) + + self.tconv = nn.ConvTranspose1d(num_channels, num_channels, upsamp_factor, upsamp_factor) + + def flop_count(self, rate=100): + count = 0 + for conv in self.conv1, self.conv2, self.tconv: + count += _conv1d_flop_count(conv, rate) + + count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate + + return count + + + def forward(self, features, state=None): + """ features shape: (batch_size, num_frames, feature_dim) """ + + batch_size = features.size(0) + + if state is None: + state = torch.zeros((1, batch_size, self.num_channels), device=features.device) + + + features = features.permute(0, 2, 1) + if self.lookahead: + c = torch.tanh(self.conv1(F.pad(features, [1, 1]))) + c = torch.tanh(self.conv2(F.pad(c, [2, 0]))) + else: + c = torch.tanh(self.conv1(F.pad(features, [2, 0]))) + c = torch.tanh(self.conv2(F.pad(c, [2, 0]))) + + c = torch.tanh(self.tconv(c)) + + c = c.permute(0, 2, 1) + + c, _ = self.gru(c, state) + + return c
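For reference, a quick shape and complexity check for LPCNetFeatureNet under the defaults above (illustrative only; it assumes the repository's modules are on the import path): with lookahead padding both convolutions preserve the frame count, the ConvTranspose1d doubles it, and the GRU leaves it unchanged.

import torch
from models.lpcnet_feature_net import LPCNetFeatureNet

net = LPCNetFeatureNet(feature_dim=84, num_channels=256, upsamp_factor=2)
feats = torch.randn(4, 50, 84)             # (batch, frames, feature_dim)
cond = net(feats)
assert cond.shape == (4, 100, 256)         # frame count doubled by the ConvTranspose1d
print(f"{net.flop_count(rate=100) / 1e6:.1f} MFLOPS at 100 input frames per second")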
\ No newline at end of file diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py index 4524906d..2709274c 100644 --- a/dnn/torch/osce/models/no_lace.py +++ b/dnn/torch/osce/models/no_lace.py @@ -1,3 +1,31 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" import torch from torch import nn diff --git a/dnn/torch/osce/test_vocoder.py b/dnn/torch/osce/test_vocoder.py new file mode 100644 index 00000000..e71a5c37 --- /dev/null +++ b/dnn/torch/osce/test_vocoder.py @@ -0,0 +1,103 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import argparse + +import torch + +from scipy.io import wavfile + +from time import time + + +from models import model_dict +from utils.lpcnet_features import load_lpcnet_features +from utils import endoscopy + +debug = False +if debug: + args = type('dummy', (object,), + { + 'input' : 'testitems/all_0_orig.se', + 'checkpoint' : 'testout/checkpoints/checkpoint_epoch_5.pth', + 'output' : 'out.wav', + })() +else: + parser = argparse.ArgumentParser() + + parser.add_argument('input', type=str, help='path to input features') + parser.add_argument('checkpoint', type=str, help='checkpoint file') + parser.add_argument('output', type=str, help='output file') + parser.add_argument('--debug', action='store_true', help='enables debug output') + + + args = parser.parse_args() + + +torch.set_num_threads(2) + +input_folder = args.input +checkpoint_file = args.checkpoint + + +output_file = args.output +if not output_file.endswith('.wav'): + output_file += '.wav' + +checkpoint = torch.load(checkpoint_file, map_location="cpu") + +# check model +if not 'name' in checkpoint['setup']['model']: + print(f'warning: did not find model name entry in setup, using pitchpostfilter per default') + model_name = 'pitchpostfilter' +else: + model_name = checkpoint['setup']['model']['name'] + +model = model_dict[model_name](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs']) + +model.load_state_dict(checkpoint['state_dict']) + +# generate model input +setup = checkpoint['setup'] +testdata = load_lpcnet_features(input_folder) +features = testdata['features'] +periods = testdata['periods'] + +if args.debug: + endoscopy.init() + +start = time() +output = model.process(features, periods, debug=args.debug) +elapsed = time() - start +print(f"[timing] inference took {elapsed * 1000} ms") + +wavfile.write(output_file, 16000, output.cpu().numpy()) + +if args.debug: + endoscopy.close() diff --git a/dnn/torch/osce/train_vocoder.py b/dnn/torch/osce/train_vocoder.py new file mode 100644 index 00000000..f4d8157d --- /dev/null +++ b/dnn/torch/osce/train_vocoder.py @@ -0,0 +1,287 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/
+"""
+
+import os
+import argparse
+import sys
+
+import yaml
+
+try:
+    import git
+    has_git = True
+except ImportError:
+    has_git = False
+
+import torch
+from torch.optim.lr_scheduler import LambdaLR
+
+from scipy.io import wavfile
+
+import pesq
+
+from data import LPCNetVocodingDataset
+from models import model_dict
+from engine.vocoder_engine import train_one_epoch, evaluate
+
+
+from utils.lpcnet_features import load_lpcnet_features
+from utils.misc import count_parameters
+
+from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('setup', type=str, help='setup yaml file')
+parser.add_argument('output', type=str, help='output path')
+parser.add_argument('--device', type=str, help='compute device', default=None)
+parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
+parser.add_argument('--test-features', type=str, help='path to features for testing', default=None)
+parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of stdout')
+
+args = parser.parse_args()
+
+
+torch.set_num_threads(4)
+
+with open(args.setup, 'r') as f:
+    setup = yaml.load(f.read(), yaml.FullLoader)
+
+checkpoint_prefix = 'checkpoint'
+output_prefix = 'output'
+setup_name = 'setup.yml'
+output_file = 'out.txt'
+
+
+# check model
+if 'name' not in setup['model']:
+    print('warning: did not find model entry in setup, using default PitchPostFilter')
+    model_name = 'pitchpostfilter'
+else:
+    model_name = setup['model']['name']
+
+# prepare output folder
+if os.path.exists(args.output):
+    print("warning: output folder exists")
+
+    reply = input('continue? (y/n): ')
+    while reply not in {'y', 'n'}:
+        reply = input('continue? (y/n): ')
+
+    if reply == 'n':
+        os._exit(0)
+else:
+    os.makedirs(args.output, exist_ok=True)
+
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# add repo info to setup
+if has_git:
+    working_dir = os.path.split(__file__)[0]
+    try:
+        repo = git.Repo(working_dir)
+        setup['repo'] = dict()
+        hash = repo.head.object.hexsha
+        urls = list(repo.remote().urls)
+        is_dirty = repo.is_dirty()
+
+        if is_dirty:
+            print("warning: repo is dirty")
+
+        setup['repo']['hash'] = hash
+        setup['repo']['urls'] = urls
+        setup['repo']['dirty'] = is_dirty
+    except Exception:
+        has_git = False
+
+# dump setup
+with open(os.path.join(args.output, setup_name), 'w') as f:
+    yaml.dump(setup, f)
+
+ref = None
+# prepare inference test if wanted
+inference_test = False
+if args.test_features is not None:
+    test_features = load_lpcnet_features(args.test_features)
+    features = test_features['features']
+    periods = test_features['periods']
+    inference_folder = os.path.join(args.output, 'inference_test')
+    os.makedirs(inference_folder, exist_ok=True)
+    inference_test = True
+
+
+# training parameters
+batch_size = setup['training']['batch_size']
+epochs = setup['training']['epochs']
+lr = setup['training']['lr']
+lr_decay_factor = setup['training']['lr_decay_factor']
+
+# load training dataset
+data_config = setup['data']
+data = LPCNetVocodingDataset(setup['dataset'], **data_config)
+
+# load validation dataset if given
+if 'validation_dataset' in setup:
+    validation_data = LPCNetVocodingDataset(setup['validation_dataset'], **data_config)
+
+    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=8)
+
+    run_validation = True
+else:
+    run_validation = False
+
+# create model
+model = 
model_dict[model_name](*setup['model']['args'], **setup['model']['kwargs']) + +if args.initial_checkpoint is not None: + print(f"loading state dict from {args.initial_checkpoint}...") + chkpt = torch.load(args.initial_checkpoint, map_location='cpu') + model.load_state_dict(chkpt['state_dict']) + +# set compute device +if type(args.device) == type(None): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +else: + device = torch.device(args.device) + +# push model to device +model.to(device) + +# dataloader +dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8) + +# optimizer is introduced to trainable parameters +parameters = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.Adam(parameters, lr=lr) + +# learning rate scheduler +scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) + +# loss +w_l1 = setup['training']['loss']['w_l1'] +w_lm = setup['training']['loss']['w_lm'] +w_slm = setup['training']['loss']['w_slm'] +w_sc = setup['training']['loss']['w_sc'] +w_logmel = setup['training']['loss']['w_logmel'] +w_wsc = setup['training']['loss']['w_wsc'] +w_xcorr = setup['training']['loss']['w_xcorr'] +w_sxcorr = setup['training']['loss']['w_sxcorr'] +w_l2 = setup['training']['loss']['w_l2'] + +w_sum = w_l1 + w_lm + w_sc + w_logmel + w_wsc + w_slm + w_xcorr + w_sxcorr + w_l2 + +stftloss = MRSTFTLoss(sc_weight=w_sc, log_mag_weight=w_lm, wsc_weight=w_wsc, smooth_log_mag_weight=w_slm, sxcorr_weight=w_sxcorr).to(device) +logmelloss = MRLogMelLoss().to(device) + +def xcorr_loss(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = 1 - torch.sum(y_true * y_pred, dim=dims) / torch.sqrt(torch.sum(y_true ** 2, dim=dims) * torch.sum(y_pred ** 2, dim=dims) + 1e-9) + + return torch.mean(loss) + +def td_l2_norm(y_true, y_pred): + dims = list(range(1, len(y_true.shape))) + + loss = torch.mean((y_true - y_pred) ** 2, dim=dims) / (torch.mean(y_pred ** 2, dim=dims) ** .5 + 1e-6) + + return loss.mean() + +def td_l1(y_true, y_pred, pow=0): + dims = list(range(1, len(y_true.shape))) + tmp = torch.mean(torch.abs(y_true - y_pred), dim=dims) / ((torch.mean(torch.abs(y_pred), dim=dims) + 1e-9) ** pow) + + return torch.mean(tmp) + +def criterion(x, y): + + return (w_l1 * td_l1(x, y, pow=1) + stftloss(x, y) + w_logmel * logmelloss(x, y) + + w_xcorr * xcorr_loss(x, y) + w_l2 * td_l2_norm(x, y)) / w_sum + + + +# model checkpoint +checkpoint = { + 'setup' : setup, + 'state_dict' : model.state_dict(), + 'loss' : -1 +} + + +if not args.no_redirect: + print(f"re-directing output to {os.path.join(args.output, output_file)}") + sys.stdout = open(os.path.join(args.output, output_file), "w") + +print("summary:") + +print(f"{count_parameters(model.cpu()) / 1e6:5.3f} M parameters") +if hasattr(model, 'flop_count'): + print(f"{model.flop_count(16000) / 1e6:5.3f} MFLOPS") + +if ref is not None: + pass + +best_loss = 1e9 + +for ep in range(1, epochs + 1): + print(f"training epoch {ep}...") + new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler) + + + # save checkpoint + checkpoint['state_dict'] = model.state_dict() + checkpoint['loss'] = new_loss + + if run_validation: + print("running validation...") + validation_loss = evaluate(model, criterion, validation_dataloader, device) + checkpoint['validation_loss'] = validation_loss + + if validation_loss < best_loss: + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + 
f'_best.pth')) + best_loss = validation_loss + + if inference_test: + print("running inference test...") + out = model.process(features, periods).cpu().numpy() + wavfile.write(os.path.join(inference_folder, f'{model_name}_epoch_{ep}.wav'), 16000, out) + if ref is not None: + mos = pesq.pesq(16000, ref, out, mode='wb') + print(f"MOS (PESQ): {mos}") + + + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth')) + torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth')) + + + print() + +print('Done') diff --git a/dnn/torch/osce/utils/complexity.py b/dnn/torch/osce/utils/complexity.py index 79de22c5..4ee6e3f3 100644 --- a/dnn/torch/osce/utils/complexity.py +++ b/dnn/torch/osce/utils/complexity.py @@ -1,31 +1,4 @@ -""" -/* Copyright (c) 2023 Amazon - Written by Jan Buethe */ -/* - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER - OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ -""" def _conv1d_flop_count(layer, rate): return 2 * ((layer.in_channels + 1) * layer.out_channels * rate / layer.stride[0] ) * layer.kernel_size[0] diff --git a/dnn/torch/osce/utils/endoscopy.py b/dnn/torch/osce/utils/endoscopy.py index 141447e2..05dd4750 100644 --- a/dnn/torch/osce/utils/endoscopy.py +++ b/dnn/torch/osce/utils/endoscopy.py @@ -1,32 +1,3 @@ -""" -/* Copyright (c) 2023 Amazon - Written by Jan Buethe */ -/* - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER - OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ -""" - """ module for inspecting models during inference """ import os diff --git a/dnn/torch/osce/utils/layers/noise_shaper.py b/dnn/torch/osce/utils/layers/noise_shaper.py new file mode 100644 index 00000000..ba8a3af3 --- /dev/null +++ b/dnn/torch/osce/utils/layers/noise_shaper.py @@ -0,0 +1,100 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import torch +from torch import nn +import torch.nn.functional as F + +from utils.complexity import _conv1d_flop_count + +class NoiseShaper(nn.Module): + + def __init__(self, + feature_dim, + frame_size=160 + ): + """ + + Parameters: + ----------- + + feature_dim : int + dimension of input features + + frame_size : int + frame size + + """ + + super().__init__() + + self.feature_dim = feature_dim + self.frame_size = frame_size + + # feature transform + self.feature_alpha1 = nn.Conv1d(self.feature_dim, frame_size, 2) + self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2) + + + def flop_count(self, rate): + + frame_rate = rate / self.frame_size + + shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size + + return shape_flops + + + def forward(self, features): + """ creates temporally shaped noise + + + Parameters: + ----------- + features : torch.tensor + frame-wise features of shape (batch_size, num_frames, feature_dim) + + """ + + batch_size = features.size(0) + num_frames = features.size(1) + frame_size = self.frame_size + num_samples = num_frames * frame_size + + # feature path + f = F.pad(features.permute(0, 2, 1), [1, 0]) + alpha = F.leaky_relu(self.feature_alpha1(f), 0.2) + alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0]))) + alpha = alpha.permute(0, 2, 1) + + # signal generation + y = torch.randn((batch_size, num_frames, frame_size), dtype=features.dtype, device=features.device) + y = alpha * y + + return y.reshape(batch_size, 1, num_samples) diff --git a/dnn/torch/osce/utils/layers/silk_upsampler.py b/dnn/torch/osce/utils/layers/silk_upsampler.py index d5f396ed..0d20b8a6 100644 --- a/dnn/torch/osce/utils/layers/silk_upsampler.py +++ b/dnn/torch/osce/utils/layers/silk_upsampler.py @@ -1,3 +1,32 @@ +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + """ This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """ import torch diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py index 2ab12bad..73d66bd5 100644 --- a/dnn/torch/osce/utils/layers/td_shaper.py +++ b/dnn/torch/osce/utils/layers/td_shaper.py @@ -11,7 +11,8 @@ class TDShaper(nn.Module): feature_dim, frame_size=160, avg_pool_k=4, - innovate=False + innovate=False, + pool_after=False ): """ @@ -39,6 +40,7 @@ class TDShaper(nn.Module): self.frame_size = frame_size self.avg_pool_k = avg_pool_k self.innovate = innovate + self.pool_after = pool_after assert frame_size % avg_pool_k == 0 self.env_dim = frame_size // avg_pool_k + 1 @@ -71,8 +73,12 @@ class TDShaper(nn.Module): def envelope_transform(self, x): x = torch.abs(x) - x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k) - x = torch.log(x + .5**16) + if self.pool_after: + x = torch.log(x + .5**16) + x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k) + else: + x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k) + x = torch.log(x + .5**16) x = x.reshape(x.size(0), -1, self.env_dim - 1) avg_x = torch.mean(x, -1, keepdim=True) diff --git a/dnn/torch/osce/utils/lpcnet_features.py b/dnn/torch/osce/utils/lpcnet_features.py new file mode 100644 index 00000000..3d109fd3 --- /dev/null +++ b/dnn/torch/osce/utils/lpcnet_features.py @@ -0,0 +1,112 @@ +import os + +import torch +import numpy as np + +def load_lpcnet_features(feature_file, version=2): + if version == 2: + layout = { + 'cepstrum': [0,18], + 'periods': [18, 19], + 'pitch_corr': [19, 20], + 'lpc': [20, 36] + } + frame_length = 36 + + elif version == 1: + layout = { + 'cepstrum': [0,18], + 'periods': [36, 37], + 'pitch_corr': [37, 38], + 'lpc': [39, 55], + } + frame_length = 55 + else: + raise ValueError(f'unknown feature version: {version}') + + + raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32')) + raw_features = raw_features.reshape((-1, frame_length)) + + features = torch.cat( + [ + raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]], + raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]] + ], + dim=1 + ) + + lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]] + periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long() + + return {'features' : features, 'periods' : periods, 'lpcs' : lpcs} + + + +def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85): + ref_data = np.memmap(reference_data_path, dtype=np.int16) + signal = np.memmap(signal_path, dtype=np.int16) + + signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw' + signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape) + + + assert len(signal) % 160 == 0 + num_frames = len(signal) // 160 + mem = np.zeros(1) + for fr in range(len(signal)//160): + signal_preemph[fr * 160 : (fr + 1) * 160] = np.convolve(np.concatenate((mem, signal[fr * 160 : (fr + 1) * 160])), [1, -preemph_factor], mode='valid') + mem = signal[(fr + 1) * 160 - 1 : (fr + 1) * 160] + + new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape) + + new_data[:] = 0 + N = len(signal) - offset + new_data[1 : 2*N + 1: 2] = signal_preemph[offset:] + new_data[2 : 2*N + 2: 2] = signal_preemph[offset:] + + +def parse_warpq_scores(output_file): + """ extracts warpq scores from output file """ + + with open(output_file, "r") as f: + lines = f.readlines() + + scores = 
[float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")] + + return scores + + +def parse_stats_file(file): + + with open(file, "r") as f: + lines = f.readlines() + + mean = float(lines[0].split(":")[-1]) + bt_mean = float(lines[1].split(":")[-1]) + top_mean = float(lines[2].split(":")[-1]) + + return mean, bt_mean, top_mean + +def collect_test_stats(test_folder): + """ collects statistics for all discovered metrics from test folder """ + + metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'} + + results = dict() + + content = os.listdir(test_folder) + + stats_files = [file for file in content if file.startswith('stats_')] + + for file in stats_files: + metric = file[len("stats_") : -len(".txt")] + + if metric not in metrics: + print(f"warning: unknown metric {metric}") + + mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file)) + + results[metric] = [mean, bt_mean, top_mean] + + return results diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py index d4c03478..6fe3dfa8 100644 --- a/dnn/torch/osce/utils/misc.py +++ b/dnn/torch/osce/utils/misc.py @@ -39,4 +39,27 @@ def count_parameters(model, verbose=False): total += count - return total
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py
index d4c03478..6fe3dfa8 100644
--- a/dnn/torch/osce/utils/misc.py
+++ b/dnn/torch/osce/utils/misc.py
@@ -39,4 +39,27 @@ def count_parameters(model, verbose=False):
 
         total += count
 
-    return total
\ No newline at end of file
+    return total
+
+
+def retain_grads(module):
+    for p in module.parameters():
+        if p.requires_grad:
+            p.retain_grad()
+
+def get_grad_norm(module, p=2):
+    norm = 0
+    for param in module.parameters():
+        if param.requires_grad:
+            norm = norm + (torch.abs(param.grad) ** p).sum()
+
+    return norm ** (1/p)
+
+def create_weights(s_real, s_gen, alpha):
+    weights = []
+    with torch.no_grad():
+        for sr, sg in zip(s_real, s_gen):
+            weight = torch.exp(alpha * (sr[-1] - sg[-1]))
+            weights.append(weight)
+
+    return weights
\ No newline at end of file
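create_weights computes one weight per discriminator scale, exp(alpha * (D(real) - D(gen))), from the final score in each scale's output list, so larger weights fall on scales where the generator is currently losing. A sketch of how this might feed a feature-matching loss, under the assumption (not shown in this commit) that the discriminator returns, per scale, a list of intermediate feature maps ending with the score tensor:

import torch
from utils.misc import create_weights

def weighted_feature_matching_loss(disc, x_real, x_gen, alpha=0.2):
    # Assumed discriminator interface: disc(x) returns a list of per-scale
    # outputs, each a list of feature maps with the final score tensor last.
    scores_real = disc(x_real)
    scores_gen = disc(x_gen)

    # One weight per scale: exp(alpha * (D(real) - D(gen))), detached from the graph.
    weights = create_weights(scores_real, scores_gen, alpha)

    loss = 0
    for w, sr, sg in zip(weights, scores_real, scores_gen):
        for fr, fg in zip(sr[:-1], sg[:-1]):
            # Scalarize the per-scale weight before applying it to each layer's L1 term.
            loss = loss + w.mean() * torch.mean(torch.abs(fr - fg))
    return loss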
diff --git a/dnn/torch/osce/utils/silk_features.py b/dnn/torch/osce/utils/silk_features.py
index 071a6c26..2997ef5f 100644
--- a/dnn/torch/osce/utils/silk_features.py
+++ b/dnn/torch/osce/utils/silk_features.py
@@ -27,7 +27,6 @@
 */
 """
 
-
 import os
 
 import numpy as np
diff --git a/dnn/torch/osce/utils/spec.py b/dnn/torch/osce/utils/spec.py
index 7e41d84e..01b923ae 100644
--- a/dnn/torch/osce/utils/spec.py
+++ b/dnn/torch/osce/utils/spec.py
@@ -30,6 +30,7 @@ import math as m
 
 import numpy as np
 import scipy
+import torch
 
 def erb(f):
     return 24.7 * (4.37 * f + 1)
@@ -49,6 +50,20 @@ scale_dict = {
     'erb': [erb, inv_erb]
 }
 
+def gen_filterbank(N, Fs=16000, keep_size=False):
+    in_freq = (np.arange(N+1, dtype='float32')/N*Fs/2)[None,:]
+    M = N + 1 if keep_size else N
+    out_freq = (np.arange(M, dtype='float32')/N*Fs/2)[:,None]
+    #ERB from B.C.J Moore, An Introduction to the Psychology of Hearing, 5th Ed., page 73.
+    ERB_N = 24.7 + .108*in_freq
+    delta = np.abs(in_freq-out_freq)/ERB_N
+    center = (delta<.5).astype('float32')
+    R = -12*center*delta**2 + (1-center)*(3-12*delta)
+    RE = 10.**(R/10.)
+    norm = np.sum(RE, axis=1)
+    RE = RE/norm[:, np.newaxis]
+    return torch.from_numpy(RE)
+
 def create_filter_bank(num_bands, n_fft=320, fs=16000, scale='bark', round_center_bins=False, return_upper=False, normalize=False):
 
     f0 = 0
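gen_filterbank builds an N x (N+1) matrix of overlapping ERB-shaped rows (or (N+1) x (N+1) with keep_size=True), each row normalized to sum to one, so multiplying it with a one-sided magnitude spectrum yields an auditory-smoothed spectrum. A small sketch of one plausible application, with n_fft chosen for illustration:

import torch
from utils.spec import gen_filterbank

n_fft = 320                        # 20 ms frame at 16 kHz
N = n_fft // 2                     # 160 output bands below Nyquist
fb = gen_filterbank(N, Fs=16000)   # shape (160, 161), rows sum to 1

x = torch.randn(n_fft)
mag = torch.abs(torch.fft.rfft(x))  # one-sided spectrum, length N + 1 = 161

# Spread spectral energy across ERB-width bands (auditory smoothing).
smoothed = fb @ mag                 # shape (160,)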
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
index c9648f44..42137b26 100644
--- a/dnn/torch/osce/utils/templates.py
+++ b/dnn/torch/osce/utils/templates.py
@@ -140,8 +140,196 @@ nolace_setup = {
     }
 }
 
+nolace_setup_adv = {
+    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
+    'model': {
+        'name': 'nolace',
+        'args': [],
+        'kwargs': {
+            'avg_pool_k': 4,
+            'comb_gain_limit_db': 10,
+            'cond_dim': 256,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'hidden_feature_dim': 96,
+            'kernel_size': 15,
+            'num_features': 93,
+            'numbits_embedding_dim': 8,
+            'numbits_range': [50, 650],
+            'partial_lookahead': True,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'skip': 91
+        }
+    },
+    'data': {
+        'frames_per_sample': 100,
+        'no_pitch_value': 7,
+        'preemph': 0.85,
+        'skip': 91,
+        'pitch_hangover': 8,
+        'acorr_radius': 2,
+        'num_bands_clean_spec': 64,
+        'num_bands_noisy_spec': 18,
+        'noisy_spec_scale': 'opus',
+        'pitch_hangover': 8,
+    },
+    'discriminator': {
+        'args': [],
+        'kwargs': {
+            'architecture': 'free',
+            'design': 'f_down',
+            'fft_sizes_16k': [
+                64,
+                128,
+                256,
+                512,
+                1024,
+                2048,
+            ],
+            'freq_roi': [0, 7400],
+            'fs': 16000,
+            'max_channels': 256,
+            'noise_gain': 0.0,
+        },
+        'name': 'fdmresdisc',
+    },
+    'training': {
+        'adv_target': 'target_orig',
+        'batch_size': 64,
+        'epochs': 50,
+        'gen_lr_reduction': 1,
+        'lambda_feat': 1.0,
+        'lambda_reg': 0.6,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 10,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 20,
+            'w_sxcorr': 1,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+        },
+        'lr': 0.0001,
+        'lr_decay_factor': 2.5e-09,
+    }
+}
+
+
+lavoce_setup = {
+    'data': {
+        'frames_per_sample': 100,
+        'target': 'signal'
+    },
+    'dataset': '/local/datasets/lpcnet_large/training',
+    'model': {
+        'args': [],
+        'kwargs': {
+            'comb_gain_limit_db': 10,
+            'cond_dim': 256,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'kernel_size': 15,
+            'num_features': 19,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'pulses': True
+        },
+        'name': 'lavoce'
+    },
+    'training': {
+        'batch_size': 256,
+        'epochs': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 2,
+            'w_sxcorr': 1,
+            'w_wsc': 0,
+            'w_xcorr': 0
+        },
+        'lr': 0.0005,
+        'lr_decay_factor': 2.5e-05
+    },
+    'validation_dataset': '/local/datasets/lpcnet_large/validation'
+}
+
+lavoce_setup_adv = {
+    'data': {
+        'frames_per_sample': 100,
+        'target': 'signal'
+    },
+    'dataset': '/local/datasets/lpcnet_large/training',
+    'discriminator': {
+        'args': [],
+        'kwargs': {
+            'architecture': 'free',
+            'design': 'f_down',
+            'fft_sizes_16k': [
+                64,
+                128,
+                256,
+                512,
+                1024,
+                2048,
+            ],
+            'freq_roi': [0, 7400],
+            'fs': 16000,
+            'max_channels': 256,
+            'noise_gain': 0.0,
+        },
+        'name': 'fdmresdisc',
+    },
+    'model': {
+        'args': [],
+        'kwargs': {
+            'comb_gain_limit_db': 10,
+            'cond_dim': 256,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'kernel_size': 15,
+            'num_features': 19,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'pulses': True
+        },
+        'name': 'lavoce'
+    },
+    'training': {
+        'batch_size': 64,
+        'epochs': 50,
+        'gen_lr_reduction': 1,
+        'lambda_feat': 1.0,
+        'lambda_reg': 0.6,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 2,
+            'w_sxcorr': 1,
+            'w_wsc': 0,
+            'w_xcorr': 0
+        },
+        'lr': 0.0001,
+        'lr_decay_factor': 2.5e-09
+    },
+}
+
 setup_dict = {
     'lace': lace_setup,
-    'nolace': nolace_setup
+    'nolace': nolace_setup,
+    'nolace_adv': nolace_setup_adv,
+    'lavoce': lavoce_setup,
+    'lavoce_adv': lavoce_setup_adv
 }
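Each entry in setup_dict is a complete training configuration, so the new adversarial templates can be serialized to YAML and handed to the training scripts. A sketch; the output file name is hypothetical:

import yaml
from utils.templates import setup_dict

# Pick one of the new adversarial setups and write it out as a setup file.
setup = setup_dict['nolace_adv']

with open('setup_nolace_adv.yml', 'w') as f:
    yaml.dump(setup, f)

# The resulting file is the 'setup' argument expected by the training scripts,
# e.g.: python adv_train_model.py setup_nolace_adv.yml <output-dir>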