Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/osce/train_model.py')
-rw-r--r--dnn/torch/osce/train_model.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py
index 6e2514b9..e8f94dcc 100644
--- a/dnn/torch/osce/train_model.py
+++ b/dnn/torch/osce/train_model.py
@@ -27,9 +27,13 @@
*/
"""
+seed=1888
+
import os
import argparse
import sys
+import random
+random.seed(seed)
import yaml
@@ -40,9 +44,12 @@ except:
has_git = False
import torch
+torch.manual_seed(seed)
+torch.backends.cudnn.benchmark = False
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
+np.random.seed(seed)
from scipy.io import wavfile
@@ -54,7 +61,7 @@ from engine.engine import train_one_epoch, evaluate
from utils.silk_features import load_inference_data
-from utils.misc import count_parameters
+from utils.misc import count_parameters, count_nonzero_parameters
from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
@@ -71,6 +78,7 @@ parser.add_argument('--no-redirect', action='store_true', help='disables re-dire
args = parser.parse_args()
+
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
@@ -98,7 +106,7 @@ if os.path.exists(args.output):
reply = input('continue? (y/n): ')
if reply == 'n':
- os._exit()
+ os._exit(0)
else:
os.makedirs(args.output, exist_ok=True)
@@ -109,7 +117,7 @@ os.makedirs(checkpoint_dir, exist_ok=True)
if has_git:
working_dir = os.path.split(__file__)[0]
try:
- repo = git.Repo(working_dir)
+ repo = git.Repo(working_dir, search_parent_directories=True)
setup['repo'] = dict()
hash = repo.head.object.hexsha
urls = list(repo.remote().urls)
@@ -117,6 +125,8 @@ if has_git:
if is_dirty:
print("warning: repo is dirty")
+ with open(os.path.join(args.output, 'repo.diff'), "w") as f:
+ f.write(repo.git.execute(["git", "diff"]))
setup['repo']['hash'] = hash
setup['repo']['urls'] = urls
@@ -292,6 +302,6 @@ for ep in range(1, epochs + 1):
torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
- print()
+ print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
print('Done')