author | Jean-Marc Valin <jmvalin@amazon.com> | 2024-01-16 02:10:21 +0300
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2024-01-16 02:11:47 +0300
commit | 26ddfd713537accce773acc12f565021f4f6d28c (patch)
tree | b99f9904fff3492dd328b6863990b697bc9f59ef
parent | 6ad03ae03e3b37dc472c291e4e77997bf64e6965 (diff)
PyTorch code for training the PLC model
Should match the TF2 code, but mostly untested
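Not part of the commit: given the argument parser in train_plc.py below, training would be launched along the lines of `python train_plc.py features.f32 loss.s8 output_dir --epochs 20` (file names hypothetical), where the features are stored as raw float32 and the packet-loss trace as raw int8, one flag per frame.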
-rw-r--r-- | dnn/torch/plc/plc.py | 144
-rw-r--r-- | dnn/torch/plc/plc_dataset.py | 56
-rw-r--r-- | dnn/torch/plc/train_plc.py | 145
3 files changed, 345 insertions, 0 deletions
diff --git a/dnn/torch/plc/plc.py b/dnn/torch/plc/plc.py
new file mode 100644
index 00000000..f08e564d
--- /dev/null
+++ b/dnn/torch/plc/plc.py
@@ -0,0 +1,144 @@
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+import math
+
+fid_dict = {}
+def dump_signal(x, filename):
+    # Debugging helper; currently disabled by the early return.
+    return
+    if filename in fid_dict:
+        fid = fid_dict[filename]
+    else:
+        fid = open(filename, "w")
+        fid_dict[filename] = fid
+    x = x.detach().numpy().astype('float32')
+    x.tofile(fid)
+
+
+class IDCT(nn.Module):
+    def __init__(self, N, device=None):
+        super(IDCT, self).__init__()
+
+        self.N = N
+        n = torch.arange(N, device=device)
+        k = torch.arange(N, device=device)
+        # Orthonormal inverse DCT-II basis, applied as a linear layer.
+        self.table = torch.cos(torch.pi/N * (n[:,None]+.5) * k[None,:])
+        self.table[:,0] = self.table[:,0] * math.sqrt(.5)
+        self.table = self.table / math.sqrt(N/2)
+
+    def forward(self, x):
+        return F.linear(x, self.table, None)
+
+def plc_loss(N, device=None, alpha=1.0, bias=1.):
+    idct = IDCT(18, device=device)
+    def loss(y_true, y_pred):
+        # Last channel of y_true is the loss mask (1 = frame received).
+        mask = y_true[:,:,-1:]
+        y_true = y_true[:,:,:-1]
+        e = (y_pred - y_true)*mask
+        # The first 18 channels are cepstral coefficients; transform the
+        # cepstral error back to the band domain.
+        e_bands = idct(e[:,:,:-2])
+        # One-sided band penalty weighted by the voicing (last feature).
+        bias_mask = torch.clamp(4*y_true[:,:,-1:], min=0., max=1.)
+        l1_loss = torch.mean(torch.abs(e))
+        ceps_loss = torch.mean(torch.abs(e[:,:,:-2]))
+        band_loss = torch.mean(torch.abs(e_bands))
+        biased_loss = torch.mean(bias_mask*torch.clamp(e_bands, min=0.))
+        pitch_loss1 = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]), max=1.))
+        pitch_loss = torch.mean(torch.clamp(torch.abs(e[:,:,18:19]), max=.4))
+        voice_bias = torch.mean(torch.clamp(-e[:,:,-1:], min=0.))
+        tot = l1_loss + 0.1*voice_bias + alpha*(band_loss + bias*biased_loss) + pitch_loss1 + 8*pitch_loss
+        return tot, l1_loss, ceps_loss, band_loss, pitch_loss
+    return loss
+
+
+# weight initialization and clipping
+def init_weights(module):
+    if isinstance(module, nn.GRU):
+        for p in module.named_parameters():
+            if p[0].startswith('weight_hh_'):
+                nn.init.orthogonal_(p[1])
+
+
+class GLU(nn.Module):
+    def __init__(self, feat_size):
+        super(GLU, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+        out = x * torch.sigmoid(self.gate(x))
+        return out
+
+class FWConv(nn.Module):
+    def __init__(self, in_size, out_size, kernel_size=2):
+        super(FWConv, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.in_size = in_size
+        self.kernel_size = kernel_size
+        self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
+        self.glu = GLU(out_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x, state):
+        xcat = torch.cat((state, x), -1)
+        out = self.glu(torch.tanh(self.conv(xcat)))
+        return out, xcat[:,self.in_size:]
+
+def n(x):
+    # Add quantization-like noise while keeping values in [-1, 1].
+    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
+
+class PLC(nn.Module):
+    def __init__(self, features_in=57, features_out=20, cond_size=128, gru_size=128):
+        super(PLC, self).__init__()
+
+        self.features_in = features_in
+        self.features_out = features_out
+        self.cond_size = cond_size
+        self.gru_size = gru_size
+
+        self.dense_in = nn.Linear(self.features_in, self.cond_size)
+        self.gru1 = nn.GRU(self.cond_size, self.gru_size, batch_first=True)
+        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
+        self.dense_out = nn.Linear(self.gru_size, features_out)
+
+        self.apply(init_weights)
+        nb_params = sum(p.numel() for p in self.parameters())
+        print(f"plc model: {nb_params} weights")
+
+    def forward(self, features, lost, states=None):
+        device = features.device
+        batch_size = features.size(0)
+        if states is None:
+            gru1_state = torch.zeros((1, batch_size, self.gru_size), device=device)
+            gru2_state = torch.zeros((1, batch_size, self.gru_size), device=device)
+        else:
+            gru1_state = states[0]
+            gru2_state = states[1]
+        x = torch.cat([features, lost], dim=-1)
+        x = torch.tanh(self.dense_in(x))
+        gru1_out, gru1_state = self.gru1(x, gru1_state)
+        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
+        return self.dense_out(gru2_out), [gru1_state, gru2_state]
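As a quick sanity check of the shape conventions above (a minimal sketch, not part of the commit; variable names are mine): y_pred carries the model's 20 output features, while y_true additionally carries a trailing loss mask, matching what PLCDataset produces below.

import torch
import plc

model = plc.PLC()                    # features_in=57, features_out=20
loss_fn = plc.plc_loss(18)

features = torch.randn(4, 15, 56)    # 36 Burg features + 20 regular features
lost = torch.ones(4, 15, 1)          # 1 = packet received
out, states = model(features, lost)  # out: (4, 15, 20)

# 20 target features plus the loss mask as the last channel
target = torch.cat([torch.randn(4, 15, 20), torch.ones(4, 15, 1)], dim=-1)
tot, l1, ceps, band, pitch = loss_fn(target, out)
print(out.shape, tot.item())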
diff --git a/dnn/torch/plc/plc_dataset.py b/dnn/torch/plc/plc_dataset.py
new file mode 100644
index 00000000..f5e4747f
--- /dev/null
+++ b/dnn/torch/plc/plc_dataset.py
@@ -0,0 +1,56 @@
+import torch
+import numpy as np
+
+class PLCDataset(torch.utils.data.Dataset):
+    def __init__(self,
+                 feature_file,
+                 loss_file,
+                 sequence_length=1000,
+                 nb_features=20,
+                 nb_burg_features=36,
+                 lpc_order=16):
+
+        self.features_in = nb_features + nb_burg_features
+        self.nb_burg_features = nb_burg_features
+        total_features = self.features_in + lpc_order
+        self.sequence_length = sequence_length
+        self.nb_features = nb_features
+
+        self.features = np.memmap(feature_file, dtype='float32', mode='r')
+        self.lost = np.memmap(loss_file, dtype='int8', mode='r')
+        self.lost = self.lost.astype('float32')
+
+        self.nb_sequences = self.features.shape[0]//self.sequence_length//total_features
+
+        self.features = self.features[:self.nb_sequences*self.sequence_length*total_features]
+        self.features = self.features.reshape((self.nb_sequences, self.sequence_length, total_features))
+        # Drop the trailing LPC coefficients; keep Burg + regular features.
+        self.features = self.features[:,:,:self.features_in]
+
+        #self.lost = self.lost[:(len(self.lost)//features.shape[1]-1)*features.shape[1]]
+        #self.lost = self.lost.reshape((-1, self.sequence_length))
+
+    def __len__(self):
+        return self.nb_sequences
+
+    def __getitem__(self, index):
+        features = self.features[index, :, :]
+        # Randomly drop the Burg features for ~10% of frames (1 = available, 0 = dropped).
+        burg_lost = (np.random.rand(features.shape[0]) > .1).astype('float32')
+        burg_lost = np.reshape(burg_lost, (features.shape[0], 1))
+        burg_mask = np.tile(burg_lost, (1, self.nb_burg_features))
+
+        # Draw a random window from the packet-loss trace (1 = received, 0 = lost).
+        lost_offset = np.random.randint(0, high=self.lost.shape[0]-self.sequence_length)
+        lost = self.lost[lost_offset:lost_offset+self.sequence_length]
+        lost = np.reshape(lost, (features.shape[0], 1))
+        lost_mask = np.tile(lost, (1, features.shape[-1]))
+        in_features = features*lost_mask
+        in_features[:,:self.nb_burg_features] = in_features[:,:self.nb_burg_features]*burg_mask
+
+        # For the first frame after a loss, we don't have valid features, but the Burg estimate is valid.
+        #in_features[:,1:,self.nb_burg_features:] = in_features[:,1:,self.nb_burg_features:]*lost_mask[:,:-1,self.nb_burg_features:]
+        out_lost = np.copy(lost)
+        #out_lost[:,1:,:] = out_lost[:,1:,:]*out_lost[:,:-1,:]
+
+        out_features = np.concatenate([features[:,self.nb_burg_features:], 1.-out_lost], axis=-1)
+        burg_sign = 2*burg_lost - 1
+        # last dim is 1 for received packet, 0 for lost packet, and -1 when just the Burg info is missing
+        return in_features*lost_mask, lost*burg_sign, out_features
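A minimal sketch of how the dataset might be exercised (not part of the commit; file names hypothetical). With the defaults, the .f32 file must hold 72 floats per frame (36 Burg + 20 features + 16 LPC) and the .s8 file one received/lost flag per frame:

from plc_dataset import PLCDataset

dataset = PLCDataset("features.f32", "loss.s8", sequence_length=15)
in_features, lost, out_features = dataset[0]
print(in_features.shape)   # (15, 56)
print(lost.shape)          # (15, 1)
print(out_features.shape)  # (15, 21): 20 features + loss mask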
diff --git a/dnn/torch/plc/train_plc.py b/dnn/torch/plc/train_plc.py
new file mode 100644
index 00000000..97be2c04
--- /dev/null
+++ b/dnn/torch/plc/train_plc.py
@@ -0,0 +1,145 @@
+import os
+import argparse
+import random
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+
+import plc
+from plc_dataset import PLCDataset
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('loss', type=str, help='path to packet loss file in .s8 format')
+parser.add_argument('output', type=str, help='path to output folder')
+
+parser.add_argument('--suffix', type=str, help="model name suffix", default="")
+parser.add_argument('--cuda-visible-devices', type=str, help="comma-separated list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None)
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 128", default=128)
+model_group.add_argument('--gru-size', type=int, help="GRU size, default: 128", default=128)
+
+training_group = parser.add_argument_group(title="training parameters")
+training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
+training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
+training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 20', default=20)
+training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 15', default=15)
+training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-4', default=1e-4)
+training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
+
+args = parser.parse_args()
+
+if args.cuda_visible_devices is not None:
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# checkpoints
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+checkpoint = dict()
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+
+# training parameters
+batch_size = args.batch_size
+lr = args.lr
+epochs = args.epochs
+sequence_length = args.sequence_length
+lr_decay = args.lr_decay
+
+adam_betas = [0.8, 0.95]
+adam_eps = 1e-8
+features_file = args.features
+loss_file = args.loss
+
+# model parameters
+cond_size = args.cond_size
+
+
+checkpoint['batch_size'] = batch_size
+checkpoint['lr'] = lr
+checkpoint['lr_decay'] = lr_decay
+checkpoint['epochs'] = epochs
+checkpoint['sequence_length'] = sequence_length
+checkpoint['adam_betas'] = adam_betas
+
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint['model_args'] = ()
+checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gru_size': args.gru_size}
+print(checkpoint['model_kwargs'])
+model = plc.PLC(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+if args.initial_checkpoint is not None:
+    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
+    model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+checkpoint['state_dict'] = model.state_dict()
+
+
+dataset = PLCDataset(features_file, loss_file, sequence_length=sequence_length)
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler (stepped once per batch, not per epoch)
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
+
+states = None
+
+plc_loss = plc.plc_loss(18, device=device)
+if __name__ == '__main__':
+    model.to(device)
+
+    for epoch in range(1, epochs + 1):
+
+        running_loss = 0
+        running_l1_loss = 0
+        running_ceps_loss = 0
+        running_band_loss = 0
+        running_pitch_loss = 0
+
+        print(f"training epoch {epoch}...")
+        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+            for i, (features, lost, target) in enumerate(tepoch):
+                optimizer.zero_grad()
+                features = features.to(device)
+                lost = lost.to(device)
+                target = target.to(device)
+
+                # GRU states start from zero for each batch; they are not
+                # carried across batches.
+                out, states = model(features, lost)
+
+                loss, l1_loss, ceps_loss, band_loss, pitch_loss = plc_loss(target, out)
+
+                loss.backward()
+                optimizer.step()
+
+                #model.clip_weights()
+
+                scheduler.step()
+
+                running_loss += loss.detach().cpu().item()
+                running_l1_loss += l1_loss.detach().cpu().item()
+                running_ceps_loss += ceps_loss.detach().cpu().item()
+                running_band_loss += band_loss.detach().cpu().item()
+                running_pitch_loss += pitch_loss.detach().cpu().item()
+                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
+                                   l1_loss=f"{running_l1_loss/(i+1):8.5f}",
+                                   ceps_loss=f"{running_ceps_loss/(i+1):8.5f}",
+                                   band_loss=f"{running_band_loss/(i+1):8.5f}",
+                                   pitch_loss=f"{running_pitch_loss/(i+1):8.5f}",
+                                   )
+
+        # save checkpoint
+        checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
+        checkpoint['state_dict'] = model.state_dict()
+        checkpoint['loss'] = running_loss / len(dataloader)
+        checkpoint['epoch'] = epoch
+        torch.save(checkpoint, checkpoint_path)
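At inference time the model would presumably be run frame by frame, carrying the GRU states forward and signalling losses through the lost flag, since PLC(features, lost, states) accepts and returns the recurrent state. A rough sketch (not part of the commit; in the real pipeline, concealment features during a loss would come from the model's previous prediction rather than random data):

import torch
import plc

model = plc.PLC()
model.eval()

states = None
with torch.no_grad():
    for received in [True, True, False, False, True]:
        features = torch.randn(1, 1, 56)  # placeholder frame features
        lost = torch.ones(1, 1, 1) if received else torch.zeros(1, 1, 1)
        out, states = model(features, lost, states)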