diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-11-08 16:03:39 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-12-20 11:42:44 +0300 |
commit | 7d328f5bfaa321d823ff4d11b62d5357c99e0693 (patch) | |
tree | 873593e93c87a7b9b1de7f710696502737f1922b /dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py | |
parent | 591c8bad70d8aa414729d1a243a6d930f64d6316 (diff) |
Merge LACE/NoLACE under OSCE frameworkopus-ng-lace-integration5
Diffstat (limited to 'dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py')
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py | 18 |
1 files changed, 2 insertions, 16 deletions
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py index b146240e..3bb6fa07 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py @@ -41,13 +41,13 @@ class LimitedAdaptiveComb1d(nn.Module): feature_dim, frame_size=160, overlap_size=40, - use_bias=True, padding=None, max_lag=256, name=None, gain_limit_db=10, global_gain_limits_db=[-6, 6], - norm_p=2): + norm_p=2, + **kwargs): """ Parameters: @@ -87,7 +87,6 @@ class LimitedAdaptiveComb1d(nn.Module): self.kernel_size = kernel_size self.frame_size = frame_size self.overlap_size = overlap_size - self.use_bias = use_bias self.max_lag = max_lag self.limit_db = gain_limit_db self.norm_p = norm_p @@ -101,8 +100,6 @@ class LimitedAdaptiveComb1d(nn.Module): # network for generating convolution weights self.conv_kernel = nn.Linear(feature_dim, kernel_size) - if self.use_bias: - self.conv_bias = nn.Linear(feature_dim,1) # comb filter gain self.filter_gain = nn.Linear(feature_dim, 1) @@ -154,9 +151,6 @@ class LimitedAdaptiveComb1d(nn.Module): conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size)) conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True)) - if self.use_bias: - conv_biases = self.conv_bias(features).permute(0, 2, 1) - conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit) # calculate gains global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b) @@ -190,10 +184,6 @@ class LimitedAdaptiveComb1d(nn.Module): new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1) - - if self.use_bias: - new_chunk = new_chunk + conv_biases[:, :, i : i + 1] - offset = self.max_lag + self.padding[0] new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size]) @@ -223,10 +213,6 @@ class LimitedAdaptiveComb1d(nn.Module): count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate) count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels - # bias computation - if self.use_bias: - count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead) - # a0 computation count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels |