diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-09 17:22:48 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:44:26 +0300 |
commit | 88f3df329e811cb394933c640edfa0d1e47b7d87 (patch) | |
tree | e9623299003c6c3429f0474d025daa59217c14e1 /dnn/torch/osce/utils/layers | |
parent | 72ea20de26d8276ca21cf0d269a01a7c138d1f3f (diff) |
sparsification now compatible with weight_norm
Diffstat (limited to 'dnn/torch/osce/utils/layers')
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py | 9 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py | 7 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/td_shaper.py | 19 |
3 files changed, 22 insertions, 13 deletions
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py index a116fd72..0d87ca19 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py @@ -49,6 +49,7 @@ class LimitedAdaptiveComb1d(nn.Module): global_gain_limits_db=[-6, 6], norm_p=2, softquant=False, + apply_weight_norm=False, **kwargs): """ @@ -99,20 +100,22 @@ class LimitedAdaptiveComb1d(nn.Module): else: self.name = name + norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x + # network for generating convolution weights - self.conv_kernel = nn.Linear(feature_dim, kernel_size) + self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size)) if softquant: self.conv_kernel = soft_quant(self.conv_kernel) # comb filter gain - self.filter_gain = nn.Linear(feature_dim, 1) + self.filter_gain = norm(nn.Linear(feature_dim, 1)) self.log_gain_limit = gain_limit_db * 0.11512925464970229 with torch.no_grad(): self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit) - self.global_filter_gain = nn.Linear(feature_dim, 1) + self.global_filter_gain = norm(nn.Linear(feature_dim, 1)) log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229 self.filter_gain_a = (log_max - log_min) / 2 self.filter_gain_b = (log_max + log_min) / 2 diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py index af6ec0ec..55df8c14 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py @@ -52,6 +52,7 @@ class LimitedAdaptiveConv1d(nn.Module): shape_gain_db=0, norm_p=2, softquant=False, + apply_weight_norm=False, **kwargs): """ @@ -101,14 +102,16 @@ class LimitedAdaptiveConv1d(nn.Module): else: self.name = name + norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x + # network for generating convolution weights - self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size) + self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size)) if softquant: self.conv_kernel = soft_quant(self.conv_kernel) self.shape_gain = min(1, 10**(shape_gain_db / 20)) - self.filter_gain = nn.Linear(feature_dim, out_channels) + self.filter_gain = norm(nn.Linear(feature_dim, out_channels)) log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229 self.filter_gain_a = (log_max - log_min) / 2 self.filter_gain_b = (log_max + log_min) / 2 diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py index 0d2052ee..fa7bf348 100644 --- a/dnn/torch/osce/utils/layers/td_shaper.py +++ b/dnn/torch/osce/utils/layers/td_shaper.py @@ -14,7 +14,8 @@ class TDShaper(nn.Module): avg_pool_k=4, innovate=False, pool_after=False, - softquant=False + softquant=False, + apply_weight_norm=False ): """ @@ -47,20 +48,22 @@ class TDShaper(nn.Module): assert frame_size % avg_pool_k == 0 self.env_dim = frame_size // avg_pool_k + 1 + norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x + # feature transform - self.feature_alpha1_f = nn.Conv1d(self.feature_dim, frame_size, 2) - self.feature_alpha1_t = nn.Conv1d(self.env_dim, frame_size, 2) - self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2) + self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2)) + self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2)) + self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2)) if softquant: self.feature_alpha1_f = soft_quant(self.feature_alpha1_f) if self.innovate: - self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) - self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) + self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)) + self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)) - self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2) - self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2) + self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2)) + self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2)) def flop_count(self, rate): |