diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-12-21 23:34:33 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-12-21 23:34:33 +0300 |
commit | 627aa7f5b3688ba787c69e55e199ba82e2013be0 (patch) | |
tree | 7da937443dd9e435f790ef151f4d6dbd7e79baf5 | |
parent | 7d328f5bfaa321d823ff4d11b62d5357c99e0693 (diff) |
Packet loss generation model
-rw-r--r-- | dnn/torch/lossgen/lossgen.py | 28 | ||||
-rwxr-xr-x | dnn/torch/lossgen/process_data.sh | 17 | ||||
-rw-r--r-- | dnn/torch/lossgen/test_lossgen.py | 45 | ||||
-rw-r--r-- | dnn/torch/lossgen/train_lossgen.py | 96 |
4 files changed, 186 insertions, 0 deletions
class LossGen(nn.Module):
    """Two-layer GRU network that models packet-loss sequences.

    Each step consumes the previous loss flag plus a running
    loss-percentage estimate (2 input features) and emits one logit
    for the probability that the next packet is lost.
    """

    def __init__(self, gru1_size=16, gru2_size=16):
        super(LossGen, self).__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size
        # Input width is 2: [loss flag, percentage estimate].
        self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        """Run the model over a (batch, time, 1) pair of inputs.

        Returns (logits, [gru1_state, gru2_state]); pass the state list
        back in to continue a sequence across calls.
        """
        if states is None:
            # Fresh zero states, sized for this batch, on the input's device.
            dev = loss.device
            nb = loss.size(0)
            h1 = torch.zeros((1, nb, self.gru1_size), device=dev)
            h2 = torch.zeros((1, nb, self.gru2_size), device=dev)
        else:
            h1, h2 = states
        features = torch.cat([loss, perc], dim=-1)
        y1, h1 = self.gru1(features, h1)
        y2, h2 = self.gru2(y1, h2)
        return self.dense_out(y2), [h1, h2]
"""Generate a synthetic 0/1 packet-loss sequence from a trained LossGen model."""
import argparse

import numpy as np
import torch


def sample_sequence(model, percentage, length):
    """Autoregressively sample `length` loss flags from `model`.

    model      -- callable like LossGen: (last, perc, states=...) -> (logit, states)
    percentage -- target loss percentage fed as conditioning input
    length     -- number of packets to sample

    Returns a (length, 1, 1) float tensor of 0/1 values.
    """
    states = None
    last = torch.zeros((1, 1, 1))
    perc = torch.tensor((percentage,))[None, None, :]
    seq = torch.zeros((0, 1, 1))
    one = torch.ones((1, 1, 1))
    zero = torch.zeros((1, 1, 1))
    # no_grad replaces the previous manual state .detach() calls and keeps
    # the loop from accumulating an autograd graph.
    with torch.no_grad():
        for _ in range(length):
            logit, states = model(last, perc, states=states)
            prob = torch.sigmoid(logit)
            # Bernoulli draw: lost with probability `prob`.
            loss = one if np.random.rand() < prob else zero
            last = loss
            seq = torch.cat([seq, loss])
    return seq


def main():
    # Deferred import: lossgen is a sibling module of this script and is
    # only needed when actually running from the command line.
    import lossgen

    parser = argparse.ArgumentParser()
    # Fixed stale copy-paste help text ('CELPNet model').
    parser.add_argument('model', type=str, help='LossGen checkpoint')
    parser.add_argument('percentage', type=float, help='percentage loss')
    parser.add_argument('output', type=str, help='path to output file (ascii)')
    parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)
    args = parser.parse_args()

    checkpoint = torch.load(args.model, map_location='cpu')
    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    seq = sample_sequence(model, args.percentage, args.length)
    # Bug fix: savetxt previously ran at module level, outside the
    # __main__ guard, so importing this file wrote the output file.
    np.savetxt(args.output, seq[:, :, 0].numpy().astype('int'), fmt='%d')


if __name__ == '__main__':
    main()
"""Train the LossGen packet-loss model on a concatenated loss-trace file."""
import os

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from scipy.signal import lfilter


class LossDataset(torch.utils.data.Dataset):
    """Cuts a flat 0/1 loss trace into fixed-length training sequences.

    Also derives a smoothed running loss percentage (one-pole IIR,
    time constant ~1000 packets) that conditions the model.
    """

    def __init__(self, loss_file, sequence_length=997):
        self.sequence_length = sequence_length
        self.loss = np.loadtxt(loss_file, dtype='float32')

        # Drop the trailing partial sequence.
        self.nb_sequences = self.loss.shape[0] // self.sequence_length
        self.loss = self.loss[:self.nb_sequences * self.sequence_length]
        # Smoothed loss rate: y[n] = .999*y[n-1] + .001*x[n].
        self.perc = lfilter(np.array([.001], dtype='float32'),
                            np.array([1., -.999], dtype='float32'), self.loss)

        self.loss = np.reshape(self.loss, (self.nb_sequences, self.sequence_length, 1))
        self.perc = np.reshape(self.perc, (self.nb_sequences, self.sequence_length, 1))

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        # Jitter the percentage conditioning (per-sequence and per-step
        # noise) so the model doesn't overfit exact percentage values.
        r0 = np.random.normal(scale=.02, size=(1, 1)).astype('float32')
        r1 = np.random.normal(scale=.02, size=(self.sequence_length, 1)).astype('float32')
        return [self.loss[index, :, :], self.perc[index, :, :] + r0 + r1]


def main():
    # Deferred imports: the project model and the progress bar are only
    # needed when actually training, keeping this module importable.
    import lossgen
    import tqdm

    adam_betas = [0.8, 0.99]
    adam_eps = 1e-8
    batch_size = 512
    lr_decay = 0.0001
    lr = 0.001
    epsilon = 1e-5  # guards log() against sigmoid saturating at 0/1
    epochs = 20
    checkpoint_dir = 'checkpoint'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = dict()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    checkpoint['model_args'] = ()
    checkpoint['model_kwargs'] = {'gru1_size': 16, 'gru2_size': 48}
    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    # Bug fix: the dataset/model/dataloader used to be built at module
    # level, so importing this file loaded 'loss_sorted.txt'.
    dataset = LossDataset('loss_sorted.txt')
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, drop_last=True, num_workers=4)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)

    # Inverse-time learning-rate decay, stepped once per batch.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay * x))

    model.to(device)

    for epoch in range(1, epochs + 1):
        running_loss = 0

        print(f"training epoch {epoch}...")
        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
            for i, (seq_loss, perc) in enumerate(tepoch):
                optimizer.zero_grad()
                seq_loss = seq_loss.to(device)
                perc = perc.to(device)

                out, _ = model(seq_loss, perc)
                # Predict packet t+1 from packets up to t.
                out = torch.sigmoid(out[:, :-1, :])
                target = seq_loss[:, 1:, :]

                # Binary cross-entropy with epsilon clamping. Renamed from
                # `loss`, which shadowed the input batch variable.
                bce = torch.mean(-target * torch.log(out + epsilon)
                                 - (1 - target) * torch.log(1 - out + epsilon))

                bce.backward()
                optimizer.step()
                scheduler.step()

                running_loss += bce.detach().cpu().item()
                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}")

        # Save a full checkpoint (weights + ctor args) after every epoch.
        checkpoint_path = os.path.join(checkpoint_dir, f'lossgen_{epoch}.pth')
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = running_loss / len(dataloader)
        checkpoint['epoch'] = epoch
        torch.save(checkpoint, checkpoint_path)


if __name__ == '__main__':
    main()