Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/lossgen/lossgen.py')
-rw-r--r--dnn/torch/lossgen/lossgen.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/dnn/torch/lossgen/lossgen.py b/dnn/torch/lossgen/lossgen.py
new file mode 100644
index 00000000..a1f2708b
--- /dev/null
+++ b/dnn/torch/lossgen/lossgen.py
@@ -0,0 +1,28 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class LossGen(nn.Module):
+ def __init__(self, gru1_size=16, gru2_size=16):
+ super(LossGen, self).__init__()
+
+ self.gru1_size = gru1_size
+ self.gru2_size = gru2_size
+ self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+ self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
+ self.dense_out = nn.Linear(self.gru2_size, 1)
+
+ def forward(self, loss, perc, states=None):
+ #print(states)
+ device = loss.device
+ batch_size = loss.size(0)
+ if states is None:
+ gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
+ gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
+ else:
+ gru1_state = states[0]
+ gru2_state = states[1]
+ x = torch.cat([loss, perc], dim=-1)
+ gru1_out, gru1_state = self.gru1(x, gru1_state)
+ gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
+ return self.dense_out(gru2_out), [gru1_state, gru2_state]