Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/lossgen/test_lossgen.py')
-rw-r--r--dnn/torch/lossgen/test_lossgen.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/dnn/torch/lossgen/test_lossgen.py b/dnn/torch/lossgen/test_lossgen.py
new file mode 100644
index 00000000..0258d0e6
--- /dev/null
+++ b/dnn/torch/lossgen/test_lossgen.py
@@ -0,0 +1,45 @@
+import lossgen
+import os
+import argparse
+import torch
+import numpy as np
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('model', type=str, help='CELPNet model')
+parser.add_argument('percentage', type=float, help='percentage loss')
+parser.add_argument('output', type=str, help='path to output file (ascii)')
+
+parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)
+
+args = parser.parse_args()
+
+
+
+checkpoint = torch.load(args.model, map_location='cpu')
+
+model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+states=None
+last = torch.zeros((1,1,1))
+perc = torch.tensor((args.percentage,))[None,None,:]
+seq = torch.zeros((0,1,1))
+
+one = torch.ones((1,1,1))
+zero = torch.zeros((1,1,1))
+
+if __name__ == '__main__':
+ for i in range(args.length):
+ prob, states = model(last, perc, states=states)
+ prob = torch.sigmoid(prob)
+ states[0] = states[0].detach()
+ states[1] = states[1].detach()
+ loss = one if np.random.rand() < prob else zero
+ last = loss
+ seq = torch.cat([seq, loss])
+
+np.savetxt(args.output, seq[:,:,0].numpy().astype('int'), fmt='%d')