
gitlab.xiph.org/xiph/opus.git
author     Jan Buethe <jbuethe@amazon.de>   2024-01-09 17:22:48 +0300
committer  Jan Buethe <jbuethe@amazon.de>   2024-01-19 18:44:26 +0300
commit     88f3df329e811cb394933c640edfa0d1e47b7d87 (patch)
tree       e9623299003c6c3429f0474d025daa59217c14e1 /dnn/torch/osce/utils/layers
parent     72ea20de26d8276ca21cf0d269a01a7c138d1f3f (diff)

sparsification now compatible with weight_norm
Diffstat (limited to 'dnn/torch/osce/utils/layers')
 -rw-r--r--  dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py   9
 -rw-r--r--  dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py   7
 -rw-r--r--  dnn/torch/osce/utils/layers/td_shaper.py                 19
3 files changed, 22 insertions, 13 deletions
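
The change threads an apply_weight_norm flag through each layer constructor. When the flag is set, the small Linear/Conv1d sub-modules inside each layer are wrapped in torch.nn.utils.weight_norm; when it is not, a pass-through lambda with the same call signature is used, so the rest of the constructor stays unchanged. A minimal sketch of the pattern used in the hunks below (the helper name make_norm and the layer sizes are illustrative, not part of the commit):

import torch
import torch.nn as nn

def make_norm(apply_weight_norm: bool):
    # Return weight_norm when requested, otherwise a no-op with the
    # same (module, name=...) signature, matching the pattern in the diff.
    if apply_weight_norm:
        return torch.nn.utils.weight_norm
    return lambda module, name=None: module

norm = make_norm(apply_weight_norm=True)
conv_kernel = norm(nn.Linear(84, 15))                # reparameterized into weight_g / weight_v
plain_kernel = make_norm(False)(nn.Linear(84, 15))   # plain nn.Linear, untouched

Because weight_norm replaces a module's weight parameter with the pair weight_g / weight_v, sparsification code has to be aware of the reparameterization, which is presumably what the commit message refers to.
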
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
index a116fd72..0d87ca19 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -49,6 +49,7 @@ class LimitedAdaptiveComb1d(nn.Module):
                  global_gain_limits_db=[-6, 6],
                  norm_p=2,
                  softquant=False,
+                 apply_weight_norm=False,
                  **kwargs):
         """
@@ -99,20 +100,22 @@ class LimitedAdaptiveComb1d(nn.Module):
         else:
             self.name = name
+        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
+
         # network for generating convolution weights
-        self.conv_kernel = nn.Linear(feature_dim, kernel_size)
+        self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size))
         if softquant:
             self.conv_kernel = soft_quant(self.conv_kernel)
         # comb filter gain
-        self.filter_gain = nn.Linear(feature_dim, 1)
+        self.filter_gain = norm(nn.Linear(feature_dim, 1))
         self.log_gain_limit = gain_limit_db * 0.11512925464970229
         with torch.no_grad():
             self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
-        self.global_filter_gain = nn.Linear(feature_dim, 1)
+        self.global_filter_gain = norm(nn.Linear(feature_dim, 1))
         log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
         self.filter_gain_a = (log_max - log_min) / 2
         self.filter_gain_b = (log_max + log_min) / 2
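
A note on the constant 0.11512925464970229 that the gain limits are multiplied by: it is ln(10)/20, which converts a limit given in dB into a bound on the natural-log gain, so that exponentiating the clamped log gain reproduces the usual 10**(dB/20) amplitude factor. A quick check (illustrative only, not part of the commit):

import math

LN10_OVER_20 = math.log(10) / 20   # ln(10)/20 ~= 0.11512925464970229

def db_to_log_gain(db: float) -> float:
    # Natural-log gain equivalent of a dB value:
    # exp(db * ln(10)/20) == 10 ** (db / 20)
    return db * LN10_OVER_20

for db in (-6.0, 0.0, 6.0, 10.0):
    assert math.isclose(math.exp(db_to_log_gain(db)), 10.0 ** (db / 20.0))
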
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
index af6ec0ec..55df8c14 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -52,6 +52,7 @@ class LimitedAdaptiveConv1d(nn.Module):
                  shape_gain_db=0,
                  norm_p=2,
                  softquant=False,
+                 apply_weight_norm=False,
                  **kwargs):
         """
@@ -101,14 +102,16 @@ class LimitedAdaptiveConv1d(nn.Module):
         else:
             self.name = name
+        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
+
         # network for generating convolution weights
-        self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size)
+        self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size))
         if softquant:
             self.conv_kernel = soft_quant(self.conv_kernel)
         self.shape_gain = min(1, 10**(shape_gain_db / 20))
-        self.filter_gain = nn.Linear(feature_dim, out_channels)
+        self.filter_gain = norm(nn.Linear(feature_dim, out_channels))
         log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229
         self.filter_gain_a = (log_max - log_min) / 2
         self.filter_gain_b = (log_max + log_min) / 2
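
Why the wrapping matters for sparsification, the subject of the commit message: with weight_norm, a module's weight is no longer a leaf parameter but is recomputed from weight_g and weight_v before every forward pass, so a pruning mask written directly into weight would be overwritten on the next step. The sparsification code itself lives outside dnn/torch/osce/utils/layers and is not part of this diff; the snippet below is only an illustrative sketch of that interaction, not the project's actual sparsifier.

import torch
import torch.nn as nn

def apply_sparsification_mask(module: nn.Module, mask: torch.Tensor) -> None:
    # Illustrative only: zero out masked weights so the zeros survive weight_norm.
    # With weight_norm, weight = weight_g * weight_v / ||weight_v||, so zeroing
    # entries of weight_v zeroes the corresponding entries of the effective weight.
    with torch.no_grad():
        if hasattr(module, 'weight_v'):
            module.weight_v.mul_(mask)
        else:
            module.weight.mul_(mask)
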
diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
index 0d2052ee..fa7bf348 100644
--- a/dnn/torch/osce/utils/layers/td_shaper.py
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -14,7 +14,8 @@ class TDShaper(nn.Module):
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
-                softquant=False
+                softquant=False,
+                apply_weight_norm=False
                 ):
         """
@@ -47,20 +48,22 @@ class TDShaper(nn.Module):
         assert frame_size % avg_pool_k == 0
         self.env_dim = frame_size // avg_pool_k + 1
+        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
+
         # feature transform
-        self.feature_alpha1_f = nn.Conv1d(self.feature_dim, frame_size, 2)
-        self.feature_alpha1_t = nn.Conv1d(self.env_dim, frame_size, 2)
-        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
+        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
+        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
+        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))
         if softquant:
             self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)
         if self.innovate:
-            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
-            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
+            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
-            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
-            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
+            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
+            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))
     def flop_count(self, rate):
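
Since weight_norm is now optional and explicit, a model trained with the flag enabled can have the reparameterization folded back into a single weight tensor with torch.nn.utils.remove_weight_norm. A hedged usage sketch on a stand-in module (the sizes and the export/quantization motivation are assumptions, not taken from this commit):

import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

# Stand-in for one of the wrapped sub-modules above, e.g.
# self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size)).
conv_kernel = weight_norm(nn.Linear(84, 15))   # 84 and 15 are made-up sizes
assert hasattr(conv_kernel, 'weight_v')        # training sees weight_g / weight_v

# ... train ...

remove_weight_norm(conv_kernel)                # fold back into a plain weight,
assert not hasattr(conv_kernel, 'weight_v')    # e.g. before export or quantization
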