diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-10 19:24:15 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:44:26 +0300 |
commit | 3d01ceca1d06352d5d485e60b8bedfe284c5a7be (patch) | |
tree | fdd1b5eaa21c99a07d4c76941e7bd6336b8e5f88 | |
parent | 88f3df329e811cb394933c640edfa0d1e47b7d87 (diff) |
split sparsification density according to layers
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 15 | ||||
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 21 | ||||
-rw-r--r-- | dnn/torch/osce/train_model.py | 8 |
3 files changed, 31 insertions, 13 deletions
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py index 19bc6caf..5ac3cbfa 100644 --- a/dnn/torch/osce/models/no_lace.py +++ b/dnn/torch/osce/models/no_lace.py @@ -27,6 +27,8 @@ */ """ +import numbers + import torch from torch import nn import torch.nn.functional as F @@ -93,6 +95,9 @@ class NoLACE(NNSBase): self.hidden_feature_dim = hidden_feature_dim self.partial_lookahead = partial_lookahead + if isinstance(sparsification_density, numbers.Number): + sparsification_density = 10 * [sparsification_density] + norm = weight_norm if apply_weight_norm else lambda x, name=None: x # pitch embedding @@ -142,11 +147,11 @@ class NoLACE(NNSBase): if sparsify: - mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4])) - mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4])) - mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4])) - mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4])) - mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4])) + mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4])) + mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4])) + mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4])) + mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4])) + mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4])) self.sparsifier = create_sparsifier(self, *sparsification_schedule) diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py index 75421449..fa476a4e 100644 --- a/dnn/torch/osce/models/silk_feature_net_pl.py +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -28,6 +28,8 @@ """ import sys sys.path.append('../dnntools') +import numbers + import torch from torch import nn @@ -51,6 +53,9 @@ class SilkFeatureNetPL(nn.Module): super(SilkFeatureNetPL, self).__init__() + if isinstance(sparsification_density, numbers.Number): + sparsification_density = 4 * [sparsification_density] + self.feature_dim = feature_dim self.num_channels = num_channels self.hidden_feature_dim = hidden_feature_dim @@ -69,17 +74,17 @@ class SilkFeatureNetPL(nn.Module): if sparsify: - mark_for_sparsification(self.conv2, (sparsification_density, [8, 4])) - mark_for_sparsification(self.tconv, (sparsification_density, [8, 4])) + mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4])) + mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4])) mark_for_sparsification( self.gru, { - 'W_ir' : (sparsification_density, [8, 4], False), - 'W_iz' : (sparsification_density, [8, 4], False), - 'W_in' : (sparsification_density, [8, 4], False), - 'W_hr' : (sparsification_density, [8, 4], True), - 'W_hz' : (sparsification_density, [8, 4], True), - 'W_hn' : (sparsification_density, [8, 4], True), + 'W_ir' : (sparsification_density[2], [8, 4], False), + 'W_iz' : (sparsification_density[2], [8, 4], False), + 'W_in' : (sparsification_density[2], [8, 4], False), + 'W_hr' : (sparsification_density[3], [8, 4], True), + 'W_hz' : (sparsification_density[3], [8, 4], True), + 'W_hn' : (sparsification_density[3], [8, 4], True), } ) diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py index 34cc638c..9b83c14d 100644 --- a/dnn/torch/osce/train_model.py +++ b/dnn/torch/osce/train_model.py @@ -27,9 +27,13 @@ */ """ +seed=1888 + import os import argparse import sys +import random +random.seed(seed) import yaml @@ -40,9 +44,12 @@ except: has_git = False import torch +torch.manual_seed(seed) +torch.backends.cudnn.benchmark = False from torch.optim.lr_scheduler import LambdaLR import numpy as np +np.random.seed(seed) from scipy.io import wavfile @@ -71,6 +78,7 @@ parser.add_argument('--no-redirect', action='store_true', help='disables re-dire args = parser.parse_args() + torch.set_num_threads(4) with open(args.setup, 'r') as f: