diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-07-22 23:31:22 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-07-22 23:31:22 +0300 |
commit | f9aee675dcf60970dc0cdfc99cc7ac3b79b54e38 (patch) | |
tree | 2198a6de482283aaa34baa3c720e13d592260e09 /dnn/torch/osce/utils/layers | |
parent | 57ab4949a8b586ae5b1aaa9c37748a1f38b6e68d (diff) |
added ShapeNet and ShapeUp48 models
Diffstat (limited to 'dnn/torch/osce/utils/layers')
-rw-r--r-- | dnn/torch/osce/utils/layers/silk_upsampler.py | 138 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/td_shaper.py | 129 |
2 files changed, 267 insertions, 0 deletions
# diff --git a/dnn/torch/osce/utils/layers/silk_upsampler.py b/dnn/torch/osce/utils/layers/silk_upsampler.py
# new file mode 100644, index 00000000..d5f396ed
# --- /dev/null
# +++ b/dnn/torch/osce/utils/layers/silk_upsampler.py
""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

# Table of 12 fractional-delay FIR filters with 8 taps each, stored as Q15
# integers and scaled to floating point by the division below.
# NOTE(review): the names used in SilkUpsampler (frac_01_24 -> row 0,
# frac_09_24 -> row 4, frac_17_24 -> row 8) suggest that row i corresponds to
# a delay of (2 * i + 1) / 24 samples -- confirm against the SILK reference.
frac_fir = np.array(
    [
        [189, -600, 617, 30567, 2996, -1375, 425, -46],
        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
    ],
    dtype=np.float32
) / 2**15


# Coefficients of the recursive 2x upsampler branches, converted from Q16.
# The third coefficient of each branch is stored with an offset of -65536,
# i.e. as (c - 1) in floating point; get_impz adds the 1 back.
hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]


def get_impz(coeffs, n):
    """ computes the first n samples of the impulse response of one 2x upsampler branch

    The branch is a cascade of three first-order recursive sections, each
    keeping a single state value. The third section uses (1 + coeffs[2]) as
    its multiplier.

    Parameters:
    -----------
    coeffs : sequence of 3 floats
        section coefficients, e.g. hq_2x_up_c_even or hq_2x_up_c_odd

    n : int
        number of impulse response samples to compute

    Returns:
    --------
    numpy.ndarray
        array of shape (n,) holding the impulse response
    """

    s = 3*[0]
    y = np.zeros(n)

    # unit impulse: 1 for the first sample, 0 afterwards
    x = 1

    for i in range(n):
        # first section
        Y = x - s[0]
        X = Y * coeffs[0]
        tmp1 = s[0] + X
        s[0] = x + X

        # second section
        Y = tmp1 - s[1]
        X = Y * coeffs[1]
        tmp2 = s[1] + X
        s[1] = tmp1 + X

        # third section, coefficient stored with an offset of -1
        Y = tmp2 - s[2]
        X = Y * (1 + coeffs[2])
        tmp3 = s[2] + X
        s[2] = tmp2 + X

        y[i] = tmp3
        x = 0

    return y



class SilkUpsampler(nn.Module):
    """ Fixed (non-trainable) upsampler from 16 kHz to 24 or 48 kHz

    The signal is first upsampled by 2 using 128-tap FIR approximations of
    the recursive even/odd branches, then resampled by 3/2 using three
    fractional-delay filters from frac_fir. This yields 3 output samples per
    input sample; for fs_out=24000 every second output sample is kept.
    """

    SUPPORTED_TARGET_RATES = {24000, 48000}
    SUPPORTED_SOURCE_RATES = {16000}

    def __init__(self,
                 fs_in=16000,
                 fs_out=48000):
        """

        Parameters:
        -----------
        fs_in : int, optional
            input sampling rate in Hz, must be in SUPPORTED_SOURCE_RATES

        fs_out : int, optional
            output sampling rate in Hz, must be in SUPPORTED_TARGET_RATES

        Raises:
        -------
        ValueError
            if fs_in or fs_out is not supported
        """

        super().__init__()
        self.fs_in = fs_in
        self.fs_out = fs_out

        if fs_in not in self.SUPPORTED_SOURCE_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')


        if fs_out not in self.SUPPORTED_TARGET_RATES:
            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')


        # hq 2x upsampler as FIR approximation
        # (impulse responses are time-reversed since conv1d computes a cross-correlation)
        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
        hq_2x_up_odd = get_impz(hq_2x_up_c_odd, 128)[::-1].copy()

        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1), requires_grad=False)
        self.hq_2x_up_odd = nn.Parameter(torch.from_numpy(hq_2x_up_odd).float().view(1, 1, -1), requires_grad=False)

        # left-only padding of kernel_size - 1 keeps the filter causal and the length unchanged
        self.hq_2x_up_padding = [127, 0]

        # interpolation filters
        frac_01_24 = frac_fir[0]
        frac_17_24 = frac_fir[8]
        frac_09_24 = frac_fir[4]

        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_01_24).view(1, 1, -1), requires_grad=False)
        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_17_24).view(1, 1, -1), requires_grad=False)
        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_09_24).view(1, 1, -1), requires_grad=False)

        # the filter chain always produces 3 samples per input sample (48 kHz);
        # for 24 kHz every second sample is dropped in forward
        self.stride = 1 if fs_out == 48000 else 2

    def hq_2x_up(self, x):
        """ upsamples x by a factor of 2

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, num_channels, num_samples)

        Returns:
        --------
        torch.tensor
            signal of shape (batch_size, num_channels, 2 * num_samples) with
            even-branch and odd-branch outputs interleaved
        """

        num_channels = x.size(1)

        # one copy of the filter per channel for the grouped (depthwise) convolution
        weight_even = torch.repeat_interleave(self.hq_2x_up_even, num_channels, 0)
        weight_odd = torch.repeat_interleave(self.hq_2x_up_odd, num_channels, 0)

        x_pad = F.pad(x, self.hq_2x_up_padding)
        y_even = F.conv1d(x_pad, weight_even, groups=num_channels)
        y_odd = F.conv1d(x_pad, weight_odd, groups=num_channels)

        # interleave: (..., T, 2) -> (..., 2 * T)
        y = torch.cat((y_even.unsqueeze(-1), y_odd.unsqueeze(-1)), dim=-1).flatten(2)

        return y

    def interpolate_3_2(self, x):
        """ resamples x by a factor of 3/2

        For every two input samples three output samples are produced, one
        from each of the three fractional-delay filters.

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, num_channels, num_samples)

        Returns:
        --------
        torch.tensor
            signal of shape (batch_size, num_channels, 3 * (num_samples // 2))
        """

        num_channels = x.size(1)

        weight_01_24 = torch.repeat_interleave(self.frac_01_24, num_channels, 0)
        weight_17_24 = torch.repeat_interleave(self.frac_17_24, num_channels, 0)
        weight_09_24 = torch.repeat_interleave(self.frac_09_24, num_channels, 0)

        x_pad = F.pad(x, [8, 0])
        y_01_24 = F.conv1d(x_pad, weight_01_24, stride=2, groups=num_channels)
        y_17_24 = F.conv1d(x_pad, weight_17_24, stride=2, groups=num_channels)

        # third phase operates on the signal advanced by one sample; the sample
        # wrapped around by roll can only reach the last window, whose outputs
        # are discarded below
        y_09_24_sh1 = F.conv1d(torch.roll(x_pad, -1, -1), weight_09_24, stride=2, groups=num_channels)


        y = torch.cat(
            (y_01_24.unsqueeze(-1), y_17_24.unsqueeze(-1), y_09_24_sh1.unsqueeze(-1)),
            dim=-1).flatten(2)

        # the strided convolutions yield one window more than needed; drop its 3 outputs
        return y[..., :-3]

    def forward(self, x):
        """ upsamples x from fs_in to fs_out

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, num_channels, num_samples)

        Returns:
        --------
        torch.tensor
            upsampled signal with 3 * num_samples samples for fs_out=48000;
            for fs_out=24000 every second of these samples is kept
        """

        y_2x = self.hq_2x_up(x)
        y_3x = self.interpolate_3_2(y_2x)

        return y_3x[:, :, ::self.stride]


# diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
# new file mode 100644, index 00000000..2ab12bad
# --- /dev/null
# +++ b/dnn/torch/osce/utils/layers/td_shaper.py
import torch
from torch import nn
import torch.nn.functional as F

from utils.complexity import _conv1d_flop_count

class TDShaper(nn.Module):
    """ Frame-wise temporal shaping of a signal with feature-driven gains

    Per-sample gains are predicted for every frame from the input features
    and a log-domain temporal envelope of the signal; the signal is then
    multiplied by these gains. With innovate=True an additional predicted
    signal, scaled by a second set of predicted gains, is added.
    """

    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False
    ):
        """

        Parameters:
        -----------


        feature_dim : int
            dimension of input features

        frame_size : int
            frame size

        avg_pool_k : int, optional
            kernel size and stride for avg pooling

        innovate : bool, optional
            if True, an additional additive signal path is created

        Raises:
        -------
        ValueError
            if frame_size is not a multiple of avg_pool_k

        """

        super().__init__()


        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate

        # explicit check instead of assert so that validation survives python -O
        if frame_size % avg_pool_k != 0:
            raise ValueError(
                f'TDShaper requires frame_size to be a multiple of avg_pool_k, '
                f'got frame_size={frame_size} and avg_pool_k={avg_pool_k}')

        # pooled envelope values per frame plus one entry for their mean
        self.env_dim = frame_size // avg_pool_k + 1

        # feature transform
        self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)

        if self.innovate:
            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)

            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)


    def flop_count(self, rate):
        """ returns a flop count estimate for processing a signal at the given sample rate

        Parameters:
        -----------
        rate : int or float
            sample rate of the processed signal

        Returns:
        --------
        flop count as reported by _conv1d_flop_count for all convolutions
        plus a per-sample cost for the remaining operations
        """

        frame_rate = rate / self.frame_size

        # NOTE(review): the per-sample constants 11 and 22 are taken over as-is; their derivation is not visible here
        shape_flops = sum(_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum(_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """ computes a frame-wise log-domain temporal envelope

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        Returns:
        --------
        torch.tensor
            tensor of shape (batch_size, num_frames, env_dim) holding the
            mean-removed log envelope of each frame followed by its mean
        """

        x = torch.abs(x)
        x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)

        # small offset keeps the logarithm finite for all-zero segments
        x = torch.log(x + .5**16)

        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ innovate signal parts with temporal shaping


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        debug : bool, optional
            currently unused, kept for interface compatibility

        Returns:
        --------
        torch.tensor
            shaped signal of shape (batch_size, 1, num_samples)

        Raises:
        -------
        ValueError
            if num_samples does not equal num_frames * frame_size

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)

        # fail early with a clear message instead of an obscure shape error further down
        if num_samples != num_frames * self.frame_size:
            raise ValueError(
                f'TDShaper expects num_samples == num_frames * frame_size, '
                f'got num_samples={num_samples}, num_frames={num_frames}, frame_size={self.frame_size}')

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path
        # (left padding by one frame makes the kernel-size-2 convolutions causal)
        f = torch.cat((features, tenv), dim=-1)
        f = F.pad(f.permute(0, 2, 1), [1, 0])
        alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)