diff options
Diffstat (limited to 'dnn/torch/osce/train_model.py')
-rw-r--r-- | dnn/torch/osce/train_model.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py index 6e2514b9..e8f94dcc 100644 --- a/dnn/torch/osce/train_model.py +++ b/dnn/torch/osce/train_model.py @@ -27,9 +27,13 @@ */ """ +seed=1888 + import os import argparse import sys +import random +random.seed(seed) import yaml @@ -40,9 +44,12 @@ except: has_git = False import torch +torch.manual_seed(seed) +torch.backends.cudnn.benchmark = False from torch.optim.lr_scheduler import LambdaLR import numpy as np +np.random.seed(seed) from scipy.io import wavfile @@ -54,7 +61,7 @@ from engine.engine import train_one_epoch, evaluate from utils.silk_features import load_inference_data -from utils.misc import count_parameters +from utils.misc import count_parameters, count_nonzero_parameters from losses.stft_loss import MRSTFTLoss, MRLogMelLoss @@ -71,6 +78,7 @@ parser.add_argument('--no-redirect', action='store_true', help='disables re-dire args = parser.parse_args() + torch.set_num_threads(4) with open(args.setup, 'r') as f: @@ -98,7 +106,7 @@ if os.path.exists(args.output): reply = input('continue? (y/n): ') if reply == 'n': - os._exit() + os._exit(0) else: os.makedirs(args.output, exist_ok=True) @@ -109,7 +117,7 @@ os.makedirs(checkpoint_dir, exist_ok=True) if has_git: working_dir = os.path.split(__file__)[0] try: - repo = git.Repo(working_dir) + repo = git.Repo(working_dir, search_parent_directories=True) setup['repo'] = dict() hash = repo.head.object.hexsha urls = list(repo.remote().urls) @@ -117,6 +125,8 @@ if has_git: if is_dirty: print("warning: repo is dirty") + with open(os.path.join(args.output, 'repo.diff'), "w") as f: + f.write(repo.git.execute(["git", "diff"])) setup['repo']['hash'] = hash setup['repo']['urls'] = urls @@ -292,6 +302,6 @@ for ep in range(1, epochs + 1): torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth')) - print() + print(f"non-zero parameters: {count_nonzero_parameters(model)}\n") print('Done') |