Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/lossgen/train_lossgen.py')
-rw-r--r--dnn/torch/lossgen/train_lossgen.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/dnn/torch/lossgen/train_lossgen.py b/dnn/torch/lossgen/train_lossgen.py
index 26e0f012..4dda3190 100644
--- a/dnn/torch/lossgen/train_lossgen.py
+++ b/dnn/torch/lossgen/train_lossgen.py
@@ -27,9 +27,11 @@ class LossDataset(torch.utils.data.Dataset):
return self.nb_sequences
def __getitem__(self, index):
- r0 = np.random.normal(scale=.02, size=(1,1)).astype('float32')
- r1 = np.random.normal(scale=.02, size=(self.sequence_length,1)).astype('float32')
- return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
+ r0 = np.random.normal(scale=.1, size=(1,1)).astype('float32')
+ r1 = np.random.normal(scale=.1, size=(self.sequence_length,1)).astype('float32')
+ perc = self.perc[index, :, :]
+ perc = perc + (r0+r1)*perc*(1-perc)
+ return [self.loss[index, :, :], perc]
adam_betas = [0.8, 0.98]
@@ -61,7 +63,7 @@ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lam
if __name__ == '__main__':
model.to(device)
-
+ states = None
for epoch in range(1, epochs + 1):
running_loss = 0
@@ -73,7 +75,8 @@ if __name__ == '__main__':
loss = loss.to(device)
perc = perc.to(device)
- out, _ = model(loss, perc)
+ out, states = model(loss, perc, states=states)
+ states = [state.detach() for state in states]
out = torch.sigmoid(out[:,:-1,:])
target = loss[:,1:,:]