Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-12-21 23:34:33 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-12-21 23:34:33 +0300
commit627aa7f5b3688ba787c69e55e199ba82e2013be0 (patch)
tree7da937443dd9e435f790ef151f4d6dbd7e79baf5
parent7d328f5bfaa321d823ff4d11b62d5357c99e0693 (diff)
Packet loss generation model
-rw-r--r--dnn/torch/lossgen/lossgen.py28
-rwxr-xr-xdnn/torch/lossgen/process_data.sh17
-rw-r--r--dnn/torch/lossgen/test_lossgen.py45
-rw-r--r--dnn/torch/lossgen/train_lossgen.py96
4 files changed, 186 insertions, 0 deletions
diff --git a/dnn/torch/lossgen/lossgen.py b/dnn/torch/lossgen/lossgen.py
new file mode 100644
index 00000000..a1f2708b
--- /dev/null
+++ b/dnn/torch/lossgen/lossgen.py
@@ -0,0 +1,28 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class LossGen(nn.Module):
+ def __init__(self, gru1_size=16, gru2_size=16):
+ super(LossGen, self).__init__()
+
+ self.gru1_size = gru1_size
+ self.gru2_size = gru2_size
+ self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+ self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
+ self.dense_out = nn.Linear(self.gru2_size, 1)
+
+ def forward(self, loss, perc, states=None):
+ #print(states)
+ device = loss.device
+ batch_size = loss.size(0)
+ if states is None:
+ gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
+ gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
+ else:
+ gru1_state = states[0]
+ gru2_state = states[1]
+ x = torch.cat([loss, perc], dim=-1)
+ gru1_out, gru1_state = self.gru1(x, gru1_state)
+ gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
+ return self.dense_out(gru2_out), [gru1_state, gru2_state]
diff --git a/dnn/torch/lossgen/process_data.sh b/dnn/torch/lossgen/process_data.sh
new file mode 100755
index 00000000..308fd0aa
--- /dev/null
+++ b/dnn/torch/lossgen/process_data.sh
@@ -0,0 +1,17 @@
+#!/bin/sh
+
+#directory containing the loss files
+datadir=$1
+
+for i in $datadir/*_is_lost.txt
+do
+ perc=`cat $i | awk '{a+=$1}END{print a/NR}'`
+ echo $perc $i
+done > percentage_list.txt
+
+sort -n percentage_list.txt | awk '{print $2}' > percentage_sorted.txt
+
+for i in `cat percentage_sorted.txt`
+do
+ cat $i
+done > loss_sorted.txt
diff --git a/dnn/torch/lossgen/test_lossgen.py b/dnn/torch/lossgen/test_lossgen.py
new file mode 100644
index 00000000..0258d0e6
--- /dev/null
+++ b/dnn/torch/lossgen/test_lossgen.py
@@ -0,0 +1,45 @@
+import lossgen
+import os
+import argparse
+import torch
+import numpy as np
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('model', type=str, help='CELPNet model')
+parser.add_argument('percentage', type=float, help='percentage loss')
+parser.add_argument('output', type=str, help='path to output file (ascii)')
+
+parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)
+
+args = parser.parse_args()
+
+
+
+checkpoint = torch.load(args.model, map_location='cpu')
+
+model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+states=None
+last = torch.zeros((1,1,1))
+perc = torch.tensor((args.percentage,))[None,None,:]
+seq = torch.zeros((0,1,1))
+
+one = torch.ones((1,1,1))
+zero = torch.zeros((1,1,1))
+
+if __name__ == '__main__':
+ for i in range(args.length):
+ prob, states = model(last, perc, states=states)
+ prob = torch.sigmoid(prob)
+ states[0] = states[0].detach()
+ states[1] = states[1].detach()
+ loss = one if np.random.rand() < prob else zero
+ last = loss
+ seq = torch.cat([seq, loss])
+
+np.savetxt(args.output, seq[:,:,0].numpy().astype('int'), fmt='%d')
diff --git a/dnn/torch/lossgen/train_lossgen.py b/dnn/torch/lossgen/train_lossgen.py
new file mode 100644
index 00000000..f0f6dd75
--- /dev/null
+++ b/dnn/torch/lossgen/train_lossgen.py
@@ -0,0 +1,96 @@
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+from scipy.signal import lfilter
+import os
+import lossgen
+
+class LossDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ loss_file,
+ sequence_length=997):
+
+ self.sequence_length = sequence_length
+
+ self.loss = np.loadtxt(loss_file, dtype='float32')
+
+ self.nb_sequences = self.loss.shape[0]//self.sequence_length
+ self.loss = self.loss[:self.nb_sequences*self.sequence_length]
+ self.perc = lfilter(np.array([.001], dtype='float32'), np.array([1., -.999], dtype='float32'), self.loss)
+
+ self.loss = np.reshape(self.loss, (self.nb_sequences, self.sequence_length, 1))
+ self.perc = np.reshape(self.perc, (self.nb_sequences, self.sequence_length, 1))
+
+ def __len__(self):
+ return self.nb_sequences
+
+ def __getitem__(self, index):
+ r0 = np.random.normal(scale=.02, size=(1,1)).astype('float32')
+ r1 = np.random.normal(scale=.02, size=(self.sequence_length,1)).astype('float32')
+ return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
+
+
+adam_betas = [0.8, 0.99]
+adam_eps = 1e-8
+batch_size=512
+lr_decay = 0.0001
+lr = 0.001
+epsilon = 1e-5
+epochs = 20
+checkpoint_dir='checkpoint'
+os.makedirs(checkpoint_dir, exist_ok=True)
+checkpoint = dict()
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint['model_args'] = ()
+checkpoint['model_kwargs'] = {'gru1_size': 16, 'gru2_size': 48}
+model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+dataset = LossDataset('loss_sorted.txt')
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
+
+
+if __name__ == '__main__':
+ model.to(device)
+
+ for epoch in range(1, epochs + 1):
+
+ running_loss = 0
+
+ print(f"training epoch {epoch}...")
+ with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+ for i, (loss, perc) in enumerate(tepoch):
+ optimizer.zero_grad()
+ loss = loss.to(device)
+ perc = perc.to(device)
+
+ out, _ = model(loss, perc)
+ out = torch.sigmoid(out[:,:-1,:])
+ target = loss[:,1:,:]
+
+ loss = torch.mean(-target*torch.log(out+epsilon) - (1-target)*torch.log(1-out+epsilon))
+
+ loss.backward()
+ optimizer.step()
+
+ scheduler.step()
+
+ running_loss += loss.detach().cpu().item()
+ tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
+ )
+
+ # save checkpoint
+ checkpoint_path = os.path.join(checkpoint_dir, f'lossgen_{epoch}.pth')
+ checkpoint['state_dict'] = model.state_dict()
+ checkpoint['loss'] = running_loss / len(dataloader)
+ checkpoint['epoch'] = epoch
+ torch.save(checkpoint, checkpoint_path)