Diffstat (limited to 'dnn/torch/lossgen/lossgen.py')
-rw-r--r-- | dnn/torch/lossgen/lossgen.py | 28
1 file changed, 28 insertions(+), 0 deletions(-)
diff --git a/dnn/torch/lossgen/lossgen.py b/dnn/torch/lossgen/lossgen.py
new file mode 100644
index 00000000..a1f2708b
--- /dev/null
+++ b/dnn/torch/lossgen/lossgen.py
@@ -0,0 +1,28 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class LossGen(nn.Module):
+    def __init__(self, gru1_size=16, gru2_size=16):
+        super(LossGen, self).__init__()
+
+        self.gru1_size = gru1_size
+        self.gru2_size = gru2_size
+        self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
+        self.dense_out = nn.Linear(self.gru2_size, 1)
+
+    def forward(self, loss, perc, states=None):
+        #print(states)
+        device = loss.device
+        batch_size = loss.size(0)
+        if states is None:
+            gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
+            gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
+        else:
+            gru1_state = states[0]
+            gru2_state = states[1]
+        x = torch.cat([loss, perc], dim=-1)
+        gru1_out, gru1_state = self.gru1(x, gru1_state)
+        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
+        return self.dense_out(gru2_out), [gru1_state, gru2_state]
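For context, LossGen stacks two small GRUs over a two-channel input (the concatenated loss and perc features) and maps the second GRU's output to one value per time step, returning the recurrent states so a sequence can be processed in chunks. Below is a minimal usage sketch: the (batch, seq, 1) input shapes follow from nn.GRU(2, ..., batch_first=True), but the sigmoid readout and the feature meanings in the comments are assumptions, not something this diff specifies.

import torch

# Sketch only: assumes LossGen from the file added above is in scope.
model = LossGen()

batch, seq = 4, 100
loss = torch.randint(0, 2, (batch, seq, 1)).float()  # assumed: per-step loss indicator
perc = torch.full((batch, seq, 1), 0.1)              # assumed: a loss-rate feature

# Full-sequence pass; states default to zeros when not supplied.
logits, states = model(loss, perc)
probs = torch.sigmoid(logits)   # assumption: the raw output is a per-step logit
print(probs.shape)              # torch.Size([4, 100, 1])

# Streaming: feed one step at a time, carrying the GRU states forward.
step_logits, states = model(loss[:, :1, :], perc[:, :1, :], states)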