Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Jan Buethe <jbuethe@amazon.de>, 2023-09-13 17:31:29 +0300
committer: Jan Buethe <jbuethe@amazon.de>, 2023-09-13 17:31:29 +0300
commit: e7beaec3fb49df389b077799c5d1778ccb68610e (patch)
tree: d118a8dcc6b0ae3c3cf54d0b49225463b34aae0d
parent: b24c7b433ae9db990dbd52eb0f1b357568fb484c (diff)
integrated JM's FFT ada conv
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
-rw-r--r-- dnn/torch/osce/utils/ada_conv.py 71
-rw-r--r-- dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py 33
2 files changed, 79 insertions, 25 deletions
diff --git a/dnn/torch/osce/utils/ada_conv.py b/dnn/torch/osce/utils/ada_conv.py
new file mode 100644
index 00000000..b5b93f87
--- /dev/null
+++ b/dnn/torch/osce/utils/ada_conv.py
@@ -0,0 +1,71 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jean-Marc Valin */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# x is (batch, nb_in_channels, nb_frames*frame_size)
+# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
def adaconv_kernel(x, kernels, half_window, fft_size=256):
    """Frame-wise adaptive convolution with cross-faded frame transitions.

    Every frame of ``x`` is filtered with its own kernel using an FFT-based
    circular convolution. The first ``overlap_size`` output samples of each
    frame are cross-faded with the previous frame's kernel applied to the
    same samples; the very first frame is faded in from zero.

    Args:
        x: input signal of shape (batch, nb_in_channels,
            nb_frames * frame_size).
        kernels: per-frame kernels of shape (batch, nb_out_channels,
            nb_in_channels, nb_frames, kernel_size). The kernel is applied
            as a causal correlation (it is flipped before the FFT), so the
            last tap multiplies the current sample.
        half_window: 1-D rising fade window of length ``overlap_size``; the
            falling half is derived as ``1 - half_window``.
        fft_size: FFT length. Must be at least
            ``2 * frame_size + overlap_size``.

    Returns:
        Tensor of shape (batch, nb_out_channels, nb_frames * frame_size).

    Raises:
        ValueError: if ``kernels`` is not 5-D, or if the frame, overlap,
            kernel and FFT sizes cannot produce an alias-free result.
    """
    if kernels.dim() != 5:
        raise ValueError(
            "kernels must have shape (batch, out_channels, in_channels, "
            f"frames, coeffs), got {tuple(kernels.shape)}"
        )
    (nb_batches, nb_out_channels, nb_in_channels,
     nb_frames, kernel_size) = kernels.shape
    overlap_size = half_window.size(-1)

    # split the signal into frames; the singleton axis broadcasts over
    # the output channels of the kernels
    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
    frame_size = x.size(-1)
    # each FFT block holds [previous frame, current frame, overlap look-ahead]
    block_size = 2 * frame_size + overlap_size

    if overlap_size > frame_size:
        raise ValueError(
            f"overlap size {overlap_size} exceeds frame size {frame_size}"
        )
    if kernel_size - 1 > frame_size:
        # only one frame of history is available per block
        raise ValueError(
            f"kernel size {kernel_size} needs more history than one frame "
            f"of {frame_size} samples provides"
        )
    if fft_size < block_size:
        raise ValueError(
            f"fft_size {fft_size} is too small: need at least "
            f"2 * frame_size + overlap_size = {block_size}"
        )

    # window over the used output region: [rising window, ones, falling window]
    window = torch.cat(
        [
            half_window,
            torch.ones(frame_size - overlap_size,
                       dtype=half_window.dtype, device=x.device),
            1 - half_window,
        ]
    )

    # neighbouring frames: history for the causal filter and look-ahead
    # samples for the cross-fade tail
    x_prev = torch.cat(
        [torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2
    )
    x_next = torch.cat(
        [
            x[:, :, :, 1:, :overlap_size],
            torch.zeros_like(x[:, :, :, -1:, :overlap_size]),
        ],
        dim=-2,
    )
    blocks = torch.cat([x_prev, x, x_next], dim=-1)

    # compute convolution; n=fft_size zero-pads both operands, and passing
    # it to irfft keeps the output length right for odd FFT sizes as well
    X = torch.fft.rfft(blocks, n=fft_size, dim=-1)
    K = torch.fft.rfft(torch.flip(kernels, [-1]), n=fft_size, dim=-1)
    out = torch.fft.irfft(X * K, n=fft_size, dim=-1)

    # combine in channels
    out = torch.sum(out, dim=2)

    # apply the cross-fading on the current frame plus its overlap tail
    faded = out[..., frame_size:block_size] * window
    # the tail of frame f - 1 lands on the start of frame f; the first
    # frame has no predecessor and receives zeros
    prev_tail = torch.cat(
        [
            faded.new_zeros(nb_batches, nb_out_channels, 1, overlap_size),
            faded[:, :, :-1, frame_size:],
        ],
        dim=-2,
    )
    crossfaded = torch.cat(
        [
            faded[..., :overlap_size] + prev_tail,
            faded[..., overlap_size:frame_size],
        ],
        dim=-1,
    )

    return crossfaded.reshape(nb_batches, nb_out_channels, -1)
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
index 5992296f..073ea1b1 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -33,6 +33,9 @@ import torch.nn.functional as F
from utils.endoscopy import write_data
+from utils.ada_conv import adaconv_kernel
+
+
class LimitedAdaptiveConv1d(nn.Module):
COUNTER = 1
@@ -184,39 +187,19 @@ class LimitedAdaptiveConv1d(nn.Module):
conv_biases = self.conv_bias(features).permute(0, 2, 1)
# calculate gains
- conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b)
+ conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b)
if debug and batch_size == 1:
key = self.name + "_gains"
- write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+ write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
key = self.name + "_kernels"
write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
- # frame-wise convolution with overlap-add
- output_frames = []
- overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
- x = F.pad(x, self.padding)
- x = F.pad(x, [0, self.overlap_size])
-
- for i in range(num_frames):
- xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1))
- new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1)
-
- if self.use_bias:
- new_chunk = new_chunk + conv_biases[:, :, i : i + 1]
-
- new_chunk = new_chunk * conv_gains[:, :, i : i + 1]
-
- # overlapping part
- output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+ conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)
- # non-overlapping part
- output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+ conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)
- # mem for next frame
- overlap_mem = new_chunk[:, :, frame_size :]
+ output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)
- # concatenate chunks
- output = torch.cat(output_frames, dim=-1)
return output \ No newline at end of file