diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-09-13 17:31:29 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-09-13 17:31:29 +0300 |
commit | e7beaec3fb49df389b077799c5d1778ccb68610e (patch) | |
tree | d118a8dcc6b0ae3c3cf54d0b49225463b34aae0d | |
parent | b24c7b433ae9db990dbd52eb0f1b357568fb484c (diff) |
integrated JM's FFT ada conv
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
-rw-r--r-- | dnn/torch/osce/utils/ada_conv.py | 71 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py | 33 |
2 files changed, 79 insertions, 25 deletions
# diff --git a/dnn/torch/osce/utils/ada_conv.py b/dnn/torch/osce/utils/ada_conv.py
# new file mode 100644
# index 00000000..b5b93f87
# --- /dev/null
# +++ b/dnn/torch/osce/utils/ada_conv.py
# @@ -0,0 +1,71 @@
"""
/* Copyright (c) 2023 Amazon
   Written by Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
import torch.nn.functional as F


# x is (batch, nb_in_channels, nb_frames*frame_size)
# kernels is (batch, nb_out_channels, nb_in_channels, nb_frames, coeffs)
def adaconv_kernel(x, kernels, half_window, fft_size=256):
    """Apply frame-wise adaptive filtering with cross-fading, computed via FFT.

    Every frame of ``x`` is filtered with its own kernel. The start of each
    output frame is cross-faded with the tail produced by the previous
    frame's kernel, using ``half_window`` as the rising ramp and
    ``1 - half_window`` as the falling ramp. The first frame has no
    predecessor, so its first ``overlap_size`` samples are only scaled by
    ``half_window``.

    Args:
        x: input of shape (batch, nb_in_channels, nb_frames * frame_size).
        kernels: filter taps of shape
            (batch, nb_out_channels, nb_in_channels, nb_frames, kernel_size).
        half_window: 1-D rising cross-fade ramp; its length is the overlap
            size.
        fft_size: length of the FFT used for the circular convolution. Must
            be at least ``3 * frame_size``; odd values are supported.

    Returns:
        Tensor of shape (batch, nb_out_channels, nb_frames * frame_size).

    Raises:
        ValueError: if the overlap is longer than a frame, if the kernel
            needs more than one frame of history, or if ``fft_size`` is too
            small to hold the previous, current and next frame.
    """
    device = x.device
    overlap_size = half_window.size(-1)
    nb_batches, nb_out_channels, nb_in_channels, nb_frames, kernel_size = kernels.shape

    x = x.reshape(nb_batches, 1, nb_in_channels, nb_frames, -1)
    frame_size = x.size(-1)

    if overlap_size > frame_size:
        raise ValueError(
            f"overlap size ({overlap_size}) must not exceed frame size ({frame_size})"
        )
    # Only one previous frame is packed in front of the current one, so a
    # longer kernel would wrap around the circular buffer and silently read
    # the wrong samples.
    if kernel_size - 1 > frame_size:
        raise ValueError(
            f"kernel size ({kernel_size}) must not exceed frame size + 1 ({frame_size + 1})"
        )
    # The cross-fade below slices [2 * frame_size, 3 * frame_size) out of
    # every FFT buffer.
    if fft_size < 3 * frame_size:
        raise ValueError(
            f"fft_size ({fft_size}) must be at least 3 * frame size ({3 * frame_size})"
        )

    tail_size = fft_size - 2 * frame_size - overlap_size

    # build window: [zeros, rising window, ones, falling window, zeros]
    window = torch.cat(
        [
            torch.zeros(frame_size, device=device),
            half_window,
            torch.ones(frame_size - overlap_size, device=device),
            1 - half_window,
            torch.zeros(tail_size, device=device),
        ]
    )

    # Per-frame buffer layout: [previous frame, current frame,
    # first overlap_size samples of the next frame, zero padding].
    x_prev = torch.cat([torch.zeros_like(x[:, :, :, :1, :]), x[:, :, :, :-1, :]], dim=-2)
    x_next = torch.cat(
        [x[:, :, :, 1:, :overlap_size], torch.zeros_like(x[:, :, :, -1:, :overlap_size])],
        dim=-2,
    )
    x_tail = torch.zeros(nb_batches, 1, nb_in_channels, nb_frames, tail_size, device=device)
    x_padded = torch.cat([x_prev, x, x_next, x_tail], dim=-1)

    # The taps are flipped before zero-padding to fft_size.
    k_tail = torch.zeros(
        nb_batches, nb_out_channels, nb_in_channels, nb_frames, fft_size - kernel_size,
        device=device,
    )
    k_padded = torch.cat([torch.flip(kernels, [-1]), k_tail], dim=-1)

    # compute convolution
    X = torch.fft.rfft(x_padded, dim=-1)
    K = torch.fft.rfft(k_padded, dim=-1)

    # n must be given explicitly: without it irfft returns fft_size - 1
    # samples whenever fft_size is odd.
    out = torch.fft.irfft(X * K, n=fft_size, dim=-1)
    # combine in channels
    out = torch.sum(out, dim=2)
    # apply the cross-fading
    out = window.reshape(1, 1, 1, -1) * out
    previous_tail = torch.cat(
        [
            torch.zeros(nb_batches, nb_out_channels, 1, frame_size, device=device),
            out[:, :, :-1, 2 * frame_size:3 * frame_size],
        ],
        dim=-2,
    )
    crossfaded = out[:, :, :, frame_size:2 * frame_size] + previous_tail

    return crossfaded.reshape(nb_batches, nb_out_channels, -1)
\ No newline at end of file diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py index 5992296f..073ea1b1 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py @@ -33,6 +33,9 @@ import torch.nn.functional as F from utils.endoscopy import write_data +from utils.ada_conv import adaconv_kernel + + class LimitedAdaptiveConv1d(nn.Module): COUNTER = 1 @@ -184,39 +187,19 @@ class LimitedAdaptiveConv1d(nn.Module): conv_biases = self.conv_bias(features).permute(0, 2, 1) # calculate gains - conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features).permute(0, 2, 1)) + self.filter_gain_b) + conv_gains = torch.exp(self.filter_gain_a * torch.tanh(self.filter_gain(features)) + self.filter_gain_b) if debug and batch_size == 1: key = self.name + "_gains" - write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) + write_data(key, conv_gains.permute(0, 2, 1).detach().squeeze().cpu().numpy(), 16000 // self.frame_size) key = self.name + "_kernels" write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size) - # frame-wise convolution with overlap-add - output_frames = [] - overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device) - x = F.pad(x, self.padding) - x = F.pad(x, [0, self.overlap_size]) - - for i in range(num_frames): - xx = x[:, :, i * frame_size : (i + 1) * frame_size + kernel_size - 1 + overlap_size].reshape((1, batch_size * self.in_channels, -1)) - new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1) - - if self.use_bias: - new_chunk = new_chunk + conv_biases[:, :, i : i + 1] - - new_chunk = new_chunk * conv_gains[:, :, i : i + 1] - - # overlapping part - 
output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2) + conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1) - # non-overlapping part - output_frames.append(new_chunk[:, :, overlap_size : frame_size]) + conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4) - # mem for next frame - overlap_mem = new_chunk[:, :, frame_size :] + output = adaconv_kernel(x, conv_kernels, win1, fft_size=256) - # concatenate chunks - output = torch.cat(output_frames, dim=-1) return output
\ No newline at end of file |