Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-07-01 00:15:56 +0300
committerJean-Marc Valin <jmvalin@jmvalin.ca>2023-07-01 00:15:56 +0300
commit105e1d83fad6393b00edb7eb676be483eb4ee2d7 (patch)
tree74873b0b88818966a9f614f144bb2e55762338a6 /dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
parent178672ed1823f2a2fdc7e36e34578383f799f4f6 (diff)
Opus ng lace
Diffstat (limited to 'dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py')
-rw-r--r--dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py236
1 files changed, 236 insertions, 0 deletions
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
new file mode 100644
index 00000000..b146240e
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -0,0 +1,236 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class LimitedAdaptiveComb1d(nn.Module):
+ COUNTER = 1
+
+ def __init__(self,
+ kernel_size,
+ feature_dim,
+ frame_size=160,
+ overlap_size=40,
+ use_bias=True,
+ padding=None,
+ max_lag=256,
+ name=None,
+ gain_limit_db=10,
+ global_gain_limits_db=[-6, 6],
+ norm_p=2):
+ """
+
+ Parameters:
+ -----------
+
+ feature_dim : int
+ dimension of features from which kernels, biases and gains are computed
+
+ frame_size : int, optional
+ frame size, defaults to 160
+
+ overlap_size : int, optional
+ overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
+
+ use_bias : bool, optional
+ if true, biases will be added to output channels. Defaults to True
+
+ padding : List[int, int], optional
+ left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]
+
+ max_lag : int, optional
+ maximal pitch lag, defaults to 256
+
+ have_a0 : bool, optional
+ If true, the filter coefficient a0 will be learned as a positive gain (requires in_channels == out_channels). Otherwise, a0 is set to 0. Defaults to False
+
+ name: str or None, optional
+ specifies a name attribute for the module. If None the name is auto generated as comb_1d_COUNT, where COUNT is an instance counter for LimitedAdaptiveComb1d
+
+ """
+
+ super(LimitedAdaptiveComb1d, self).__init__()
+
+ self.in_channels = 1
+ self.out_channels = 1
+ self.feature_dim = feature_dim
+ self.kernel_size = kernel_size
+ self.frame_size = frame_size
+ self.overlap_size = overlap_size
+ self.use_bias = use_bias
+ self.max_lag = max_lag
+ self.limit_db = gain_limit_db
+ self.norm_p = norm_p
+
+ if name is None:
+ self.name = "limited_adaptive_comb1d_" + str(LimitedAdaptiveComb1d.COUNTER)
+ LimitedAdaptiveComb1d.COUNTER += 1
+ else:
+ self.name = name
+
+ # network for generating convolution weights
+ self.conv_kernel = nn.Linear(feature_dim, kernel_size)
+
+ if self.use_bias:
+ self.conv_bias = nn.Linear(feature_dim,1)
+
+ # comb filter gain
+ self.filter_gain = nn.Linear(feature_dim, 1)
+ self.log_gain_limit = gain_limit_db * 0.11512925464970229
+ with torch.no_grad():
+ self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit)
+
+ self.global_filter_gain = nn.Linear(feature_dim, 1)
+ log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229
+ self.filter_gain_a = (log_max - log_min) / 2
+ self.filter_gain_b = (log_max + log_min) / 2
+
+ if type(padding) == type(None):
+ self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+ else:
+ self.padding = padding
+
+ self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+ def forward(self, x, features, lags, debug=False):
+ """ adaptive 1d convolution
+
+
+ Parameters:
+ -----------
+ x : torch.tensor
+ input signal of shape (batch_size, in_channels, num_samples)
+
+ feathres : torch.tensor
+ frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+ lags: torch.LongTensor
+ frame-wise lags for comb-filtering
+
+ """
+
+ batch_size = x.size(0)
+ num_frames = features.size(1)
+ num_samples = x.size(2)
+ frame_size = self.frame_size
+ overlap_size = self.overlap_size
+ kernel_size = self.kernel_size
+ win1 = torch.flip(self.overlap_win, [0])
+ win2 = self.overlap_win
+
+ if num_samples // self.frame_size != num_frames:
+ raise ValueError('non matching sizes in AdaptiveConv1d.forward')
+
+ conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+ conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
+
+ if self.use_bias:
+ conv_biases = self.conv_bias(features).permute(0, 2, 1)
+
+ conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
+ # calculate gains
+ global_conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.global_filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+
+ if debug and batch_size == 1:
+ key = self.name + "_gains"
+ write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_kernels"
+ write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_lags"
+ write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ key = self.name + "_global_conv_gains"
+ write_data(key, global_conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+
+ # frame-wise convolution with overlap-add
+ output_frames = []
+ overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+ x = F.pad(x, self.padding)
+ x = F.pad(x, [self.max_lag, self.overlap_size])
+
+ idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
+ idx = torch.repeat_interleave(idx, batch_size, 0)
+ idx = torch.repeat_interleave(idx, self.in_channels, 1)
+
+
+ for i in range(num_frames):
+
+ cidx = idx + i * frame_size + self.max_lag - lags[..., i].view(batch_size, 1, 1)
+ xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
+
+ new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
+
+
+ if self.use_bias:
+ new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
+
+ offset = self.max_lag + self.padding[0]
+ new_chunk = global_conv_gains[:, :, i : i + 1] * (new_chunk * conv_gains[:, :, i : i + 1] + x[..., offset + i * frame_size : offset + (i + 1) * frame_size + overlap_size])
+
+ # overlapping part
+ output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+ # non-overlapping part
+ output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+ # mem for next frame
+ overlap_mem = new_chunk[:, :, frame_size :]
+
+ # concatenate chunks
+ output = torch.cat(output_frames, dim=-1)
+
+ return output
+
+ def flop_count(self, rate):
+ frame_rate = rate / self.frame_size
+ overlap = self.overlap_size
+ overhead = overlap / self.frame_size
+
+ count = 0
+
+ # kernel computation and filtering
+ count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+ count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+ count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+ # bias computation
+ if self.use_bias:
+ count += 2 * (frame_rate * self.feature_dim) + rate * (1 + overhead)
+
+ # a0 computation
+ count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+ # windowing
+ count += overlap * frame_rate * 3 * self.out_channels
+
+ return count