Welcome to mirror list, hosted at ThFree Co, Russian Federation.

lossgen.py « lossgen « torch « dnn - gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: a1f2708bbd366380e9cb7e35673a19ee0c1edef3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch import nn
import torch.nn.functional as F

class LossGen(nn.Module):
    """Recurrent model: two stacked GRUs followed by a 1-unit linear layer.

    Each timestep's ``loss`` and ``perc`` features are concatenated and run
    through ``gru1`` and then ``gru2``. The output of ``gru2`` is projected to
    a single value per timestep. Hidden states are accepted and returned, so
    a long sequence can be processed in consecutive chunks.

    Args:
        gru1_size: Hidden size of the first GRU.
        gru2_size: Hidden size of the second GRU.
    """

    def __init__(self, gru1_size=16, gru2_size=16):
        super(LossGen, self).__init__()

        self.gru1_size = gru1_size
        self.gru2_size = gru2_size
        # Input size 2: the last dims of ``loss`` and ``perc`` are concatenated
        # in ``forward`` and must add up to 2 features.
        self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        """Run the model over a batch of sequences.

        Args:
            loss: Tensor concatenated with ``perc`` along the last dim. Since
                ``gru1`` takes 2 input features with ``batch_first=True``, this
                is presumably shaped (batch, seq, 1) -- TODO confirm with callers.
            perc: Tensor with the same leading dims as ``loss``.
            states: Optional ``[gru1_state, gru2_state]`` as returned by a
                previous call, each shaped (1, batch, hidden_size). When
                ``None``, both GRUs start from all-zero states.

        Returns:
            Tuple ``(out, states)``. ``out`` is the raw output of the linear
            layer (no activation is applied here), with a last dim of 1.
            ``states`` is the list ``[gru1_state, gru2_state]`` to pass to the
            next call to continue the sequence.
        """
        batch_size = loss.size(0)
        x = torch.cat([loss, perc], dim=-1)
        if states is None:
            # Match the GRU input's dtype *and* device (as nn.GRU does when no
            # hidden state is given). A default-dtype (float32) zero state would
            # clash with a model/input in another precision, e.g. after .double().
            gru1_state = x.new_zeros((1, batch_size, self.gru1_size))
            gru2_state = x.new_zeros((1, batch_size, self.gru2_size))
        else:
            gru1_state, gru2_state = states[0], states[1]
        gru1_out, gru1_state = self.gru1(x, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        return self.dense_out(gru2_out), [gru1_state, gru2_state]