
gitlab.xiph.org/xiph/opus.git
author    Jan Buethe <jbuethe@amazon.de>  2024-01-10 19:24:15 +0300
committer Jan Buethe <jbuethe@amazon.de>  2024-01-19 18:44:26 +0300
commit    3d01ceca1d06352d5d485e60b8bedfe284c5a7be (patch)
tree      fdd1b5eaa21c99a07d4c76941e7bd6336b8e5f88
parent    88f3df329e811cb394933c640edfa0d1e47b7d87 (diff)
split sparsification density according to layers
-rw-r--r--  dnn/torch/osce/models/no_lace.py              | 15
-rw-r--r--  dnn/torch/osce/models/silk_feature_net_pl.py  | 21
-rw-r--r--  dnn/torch/osce/train_model.py                 |  8
3 files changed, 31 insertions(+), 13 deletions(-)
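
Note: the commit replaces the single sparsification density shared by all sparsified layers with a per-layer list. SilkFeatureNetPL reads entries 0-3 and the NoLACE post-processing layers read entries 4-8, while a plain number is still broadcast to a list so existing setups keep working. A minimal sketch of that broadcast pattern follows; expand_density is a hypothetical helper for illustration only, the patch performs the same check inline in each constructor.

import numbers

def expand_density(sparsification_density, num_layers):
    # A scalar density means "use the same density for every sparsified layer".
    if isinstance(sparsification_density, numbers.Number):
        sparsification_density = num_layers * [sparsification_density]
    return sparsification_density

# NoLACE broadcasts to 10 entries, SilkFeatureNetPL to 4 (values are illustrative):
print(expand_density(0.4, 10))
print(expand_density([0.8, 0.5, 0.4, 0.4], 4))
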
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 19bc6caf..5ac3cbfa 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -27,6 +27,8 @@
*/
"""
+import numbers
+
import torch
from torch import nn
import torch.nn.functional as F
@@ -93,6 +95,9 @@ class NoLACE(NNSBase):
self.hidden_feature_dim = hidden_feature_dim
self.partial_lookahead = partial_lookahead
+ if isinstance(sparsification_density, numbers.Number):
+ sparsification_density = 10 * [sparsification_density]
+
norm = weight_norm if apply_weight_norm else lambda x, name=None: x
# pitch embedding
@@ -142,11 +147,11 @@ class NoLACE(NNSBase):
if sparsify:
- mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.post_cf1, (sparsification_density[4], [8, 4]))
+ mark_for_sparsification(self.post_cf2, (sparsification_density[5], [8, 4]))
+ mark_for_sparsification(self.post_af1, (sparsification_density[6], [8, 4]))
+ mark_for_sparsification(self.post_af2, (sparsification_density[7], [8, 4]))
+ mark_for_sparsification(self.post_af3, (sparsification_density[8], [8, 4]))
self.sparsifier = create_sparsifier(self, *sparsification_schedule)
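
Note: with a scalar density the new indexing is equivalent to the previous behaviour, because the broadcast puts the same value at every index the post_* layers read. A small sanity-check sketch (the density value is illustrative, not taken from the repository):

d = 0.4
density = 10 * [d]
# Indices 4..8 feed post_cf1, post_cf2, post_af1, post_af2 and post_af3.
assert [density[i] for i in range(4, 9)] == 5 * [d]
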
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index 75421449..fa476a4e 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -28,6 +28,8 @@
"""
import sys
sys.path.append('../dnntools')
+import numbers
+
import torch
from torch import nn
@@ -51,6 +53,9 @@ class SilkFeatureNetPL(nn.Module):
super(SilkFeatureNetPL, self).__init__()
+ if isinstance(sparsification_density, numbers.Number):
+ sparsification_density = 4 * [sparsification_density]
+
self.feature_dim = feature_dim
self.num_channels = num_channels
self.hidden_feature_dim = hidden_feature_dim
@@ -69,17 +74,17 @@ class SilkFeatureNetPL(nn.Module):
if sparsify:
- mark_for_sparsification(self.conv2, (sparsification_density, [8, 4]))
- mark_for_sparsification(self.tconv, (sparsification_density, [8, 4]))
+ mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
+ mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(
self.gru,
{
- 'W_ir' : (sparsification_density, [8, 4], False),
- 'W_iz' : (sparsification_density, [8, 4], False),
- 'W_in' : (sparsification_density, [8, 4], False),
- 'W_hr' : (sparsification_density, [8, 4], True),
- 'W_hz' : (sparsification_density, [8, 4], True),
- 'W_hn' : (sparsification_density, [8, 4], True),
+ 'W_ir' : (sparsification_density[2], [8, 4], False),
+ 'W_iz' : (sparsification_density[2], [8, 4], False),
+ 'W_in' : (sparsification_density[2], [8, 4], False),
+ 'W_hr' : (sparsification_density[3], [8, 4], True),
+ 'W_hz' : (sparsification_density[3], [8, 4], True),
+ 'W_hn' : (sparsification_density[3], [8, 4], True),
}
)
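
Note: taken together, the two model files imply the following index-to-layer convention for the density list. The dictionary below is a hypothetical summary for reference, not code from the patch; index 9 of the 10-entry NoLACE broadcast is not referenced in this diff.

# Which entry of sparsification_density each sparsified weight matrix reads.
DENSITY_INDEX = {
    0: "SilkFeatureNetPL.conv2",
    1: "SilkFeatureNetPL.tconv",
    2: "SilkFeatureNetPL.gru input weights (W_ir, W_iz, W_in)",
    3: "SilkFeatureNetPL.gru recurrent weights (W_hr, W_hz, W_hn)",
    4: "NoLACE.post_cf1",
    5: "NoLACE.post_cf2",
    6: "NoLACE.post_af1",
    7: "NoLACE.post_af2",
    8: "NoLACE.post_af3",
}
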
diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py
index 34cc638c..9b83c14d 100644
--- a/dnn/torch/osce/train_model.py
+++ b/dnn/torch/osce/train_model.py
@@ -27,9 +27,13 @@
*/
"""
+seed=1888
+
import os
import argparse
import sys
+import random
+random.seed(seed)
import yaml
@@ -40,9 +44,12 @@ except:
has_git = False
import torch
+torch.manual_seed(seed)
+torch.backends.cudnn.benchmark = False
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
+np.random.seed(seed)
from scipy.io import wavfile
@@ -71,6 +78,7 @@ parser.add_argument('--no-redirect', action='store_true', help='disables re-dire
args = parser.parse_args()
+
torch.set_num_threads(4)
with open(args.setup, 'r') as f:
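
Note: the train_model.py hunks also pin the random seeds (Python's random module, NumPy and PyTorch) and turn off cuDNN benchmarking so that training runs are reproducible. A consolidated sketch of the same idea is below; seed_everything is a hypothetical helper, the patch itself seeds inline at import time with seed=1888.

import random
import numpy as np
import torch

def seed_everything(seed: int = 1888) -> None:
    # Seed the three RNG sources touched by the patch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Select cuDNN kernels by heuristic instead of the run-to-run autotuner.
    torch.backends.cudnn.benchmark = False

seed_everything()
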