diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-12-18 14:19:55 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:42:54 +0300 |
commit | bfcbbbcb38ce399a6f4cf629c80097bc7d1dadbc (patch) | |
tree | 002d54b3a7d8ae44a641623c95234b06f3d2be40 /dnn/torch/osce | |
parent | 4f311a1ad44f1b7bd60e32984ca0604c46b6c593 (diff) |
added softquant option and tanh activation on the tconv layer output
Diffstat (limited to 'dnn/torch/osce')
-rw-r--r-- | dnn/torch/osce/models/lace.py | 11 | ||||
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 37 | ||||
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 14 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py | 5 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py | 5 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/td_shaper.py | 19 | ||||
-rw-r--r-- | dnn/torch/osce/utils/softquant.py | 55 | ||||
-rw-r--r-- | dnn/torch/osce/utils/templates.py | 15 |
8 files changed, 129 insertions, 32 deletions
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py index 58293de4..5bc9fe41 100644 --- a/dnn/torch/osce/models/lace.py +++ b/dnn/torch/osce/models/lace.py @@ -60,7 +60,8 @@ class LACE(NNSBase): numbits_embedding_dim=8, hidden_feature_dim=64, partial_lookahead=True, - norm_p=2): + norm_p=2, + softquant=False): super().__init__(skip=skip, preemph=preemph) @@ -85,18 +86,18 @@ class LACE(NNSBase): # feature net if partial_lookahead: - self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim) + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant) else: self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) # comb filters left_pad = self.kernel_size // 2 right_pad = self.kernel_size - 1 - left_pad - self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) - self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, 
global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) # spectral shaping - self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) def flop_count(self, rate=16000, verbose=False): diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py index 0e0fb1b3..088340e5 100644 --- a/dnn/torch/osce/models/no_lace.py +++ b/dnn/torch/osce/models/no_lace.py @@ -37,6 +37,7 @@ from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d from utils.layers.td_shaper import TDShaper from utils.complexity import _conv1d_flop_count +from utils.softquant import soft_quant from models.nns_base import NNSBase from models.silk_feature_net_pl import SilkFeatureNetPL @@ -64,7 +65,8 @@ class NoLACE(NNSBase): partial_lookahead=True, norm_p=2, avg_pool_k=4, - pool_after=False): + pool_after=False, + softquant=False): super().__init__(skip=skip, preemph=preemph) @@ -89,28 +91,28 @@ class NoLACE(NNSBase): # feature net if partial_lookahead: - self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim) + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant) else: self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) # comb filters left_pad = self.kernel_size // 2 right_pad = self.kernel_size - 1 - left_pad - self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, 
padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) - self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p) + self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) # spectral shaping - self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) # non-linear transforms - self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after) - self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after) - self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after) + self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant) + self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, 
softquant=softquant) + self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant) # combinators - self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) - self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) - self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p) + self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) # feature transforms self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2) @@ -119,6 +121,19 @@ class NoLACE(NNSBase): self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2) self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2) + if softquant: + self.post_cf1 = soft_quant(self.post_cf1) + self.post_cf2 = soft_quant(self.post_cf2) + self.post_af1 = soft_quant(self.post_af1) + self.post_af2 = soft_quant(self.post_af2) + self.post_af3 = soft_quant(self.post_af3) + + + + + + + def flop_count(self, rate=16000, verbose=False): diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py 
b/dnn/torch/osce/models/silk_feature_net_pl.py index ae37951c..81beb36d 100644 --- a/dnn/torch/osce/models/silk_feature_net_pl.py +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -33,13 +33,15 @@ from torch import nn import torch.nn.functional as F from utils.complexity import _conv1d_flop_count +from utils.softquant import soft_quant class SilkFeatureNetPL(nn.Module): """ feature net with partial lookahead """ def __init__(self, feature_dim=47, num_channels=256, - hidden_feature_dim=64): + hidden_feature_dim=64, + softquant=False): super(SilkFeatureNetPL, self).__init__() @@ -50,9 +52,15 @@ class SilkFeatureNetPL(nn.Module): self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1) self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2) self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4) - self.gru = nn.GRU(num_channels, num_channels, batch_first=True) + if softquant: + self.conv2 = soft_quant(self.conv2) + self.tconv = soft_quant(self.tconv) + self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0']) + + + def flop_count(self, rate=200): count = 0 for conv in self.conv1, self.conv2, self.tconv: @@ -82,7 +90,7 @@ class SilkFeatureNetPL(nn.Module): c = torch.tanh(self.conv2(F.pad(c, [1, 0]))) # upsampling - c = self.tconv(c) + c = torch.tanh(self.tconv(c)) c = c.permute(0, 2, 1) c, _ = self.gru(c, state) diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py index 3bb6fa07..a116fd72 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py @@ -32,6 +32,7 @@ from torch import nn import torch.nn.functional as F from utils.endoscopy import write_data +from utils.softquant import soft_quant class LimitedAdaptiveComb1d(nn.Module): COUNTER = 1 @@ -47,6 +48,7 @@ class LimitedAdaptiveComb1d(nn.Module): gain_limit_db=10, global_gain_limits_db=[-6, 6], norm_p=2, + softquant=False, 
**kwargs): """ @@ -100,6 +102,9 @@ class LimitedAdaptiveComb1d(nn.Module): # network for generating convolution weights self.conv_kernel = nn.Linear(feature_dim, kernel_size) + if softquant: + self.conv_kernel = soft_quant(self.conv_kernel) + # comb filter gain self.filter_gain = nn.Linear(feature_dim, 1) diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py index a17b0e9b..af6ec0ec 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py @@ -34,7 +34,7 @@ import torch.nn.functional as F from utils.endoscopy import write_data from utils.ada_conv import adaconv_kernel - +from utils.softquant import soft_quant class LimitedAdaptiveConv1d(nn.Module): COUNTER = 1 @@ -51,6 +51,7 @@ class LimitedAdaptiveConv1d(nn.Module): gain_limits_db=[-6, 6], shape_gain_db=0, norm_p=2, + softquant=False, **kwargs): """ @@ -102,6 +103,8 @@ class LimitedAdaptiveConv1d(nn.Module): # network for generating convolution weights self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size) + if softquant: + self.conv_kernel = soft_quant(self.conv_kernel) self.shape_gain = min(1, 10**(shape_gain_db / 20)) diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py index 73d66bd5..7080a1b3 100644 --- a/dnn/torch/osce/utils/layers/td_shaper.py +++ b/dnn/torch/osce/utils/layers/td_shaper.py @@ -3,6 +3,7 @@ from torch import nn import torch.nn.functional as F from utils.complexity import _conv1d_flop_count +from utils.softquant import soft_quant class TDShaper(nn.Module): COUNTER = 1 @@ -12,7 +13,8 @@ class TDShaper(nn.Module): frame_size=160, avg_pool_k=4, innovate=False, - pool_after=False + pool_after=False, + softquant=False ): """ @@ -46,9 +48,13 @@ class TDShaper(nn.Module): self.env_dim = frame_size // avg_pool_k + 1 # feature transform - self.feature_alpha1 = nn.Conv1d(self.feature_dim + 
self.env_dim, frame_size, 2) + self.feature_alpha1_f = nn.Conv1d(self.feature_dim, frame_size, 2) + self.feature_alpha1_t = nn.Conv1d(self.env_dim, frame_size, 2) self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2) + if soft_quant: + self.feature_alpha1_f = soft_quant(self.feature_alpha1_f) + if self.innovate: self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) @@ -61,7 +67,7 @@ class TDShaper(nn.Module): frame_rate = rate / self.frame_size - shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size + shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size if self.innovate: inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size @@ -110,9 +116,10 @@ class TDShaper(nn.Module): tenv = self.envelope_transform(x) # feature path - f = torch.cat((features, tenv), dim=-1) - f = F.pad(f.permute(0, 2, 1), [1, 0]) - alpha = F.leaky_relu(self.feature_alpha1(f), 0.2) + f = F.pad(features.permute(0, 2, 1), [1, 0]) + t = F.pad(tenv.permute(0, 2, 1), [1, 0]) + alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t) + alpha = F.leaky_relu(alpha, 0.2) alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0]))) alpha = alpha.permute(0, 2, 1) diff --git a/dnn/torch/osce/utils/softquant.py b/dnn/torch/osce/utils/softquant.py new file mode 100644 index 00000000..3917488b --- /dev/null +++ b/dnn/torch/osce/utils/softquant.py @@ -0,0 +1,55 @@ +import torch + + +class SoftQuant: + name: str + + def __init__(self, names: str, scale: float) -> None: + self.names = names + self.quantization_noise = None + self.scale = scale + + def __call__(self, 
module, inputs, *args): + if self.quantization_noise is None: + self.quantization_noise = dict() + for name in self.names: + weight = getattr(module, name) + self.quantization_noise[name] = \ + self.scale * weight.abs().max() * 2 * (torch.rand_like(weight) - 0.5) + with torch.no_grad(): + weight.data[:] = weight + self.quantization_noise[name] + else: + for name in self.names: + weight = getattr(module, name) + with torch.no_grad(): + weight.data[:] = weight - self.quantization_noise[name] + self.quantization_noise = None + + def apply(module, names=['weight'], scale=0.5/127): + fn = SoftQuant(names, scale) + + for name in names: + if not hasattr(module, name): + raise ValueError("") + + module.register_forward_pre_hook(fn) + module.register_forward_hook(fn) + + module + + return fn + + +def soft_quant(module, names=['weight'], scale=0.5/127): + fn = SoftQuant.apply(module, names, scale) + return module + +def remove_soft_quant(module, names=['weight']): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SoftQuant) and hook.names == names: + del module._forward_pre_hooks[k] + for k, hook in module._forward_hooks.items(): + if isinstance(hook, SoftQuant) and hook.names == names: + del module._forward_hooks[k] + + return module
\ No newline at end of file diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py index 42137b26..0d731127 100644 --- a/dnn/torch/osce/utils/templates.py +++ b/dnn/torch/osce/utils/templates.py @@ -50,7 +50,8 @@ lace_setup = { 'pitch_embedding_dim': 64, 'pitch_max': 300, 'preemph': 0.85, - 'skip': 91 + 'skip': 91, + 'softquant': True } }, 'data': { @@ -63,7 +64,7 @@ lace_setup = { 'num_bands_clean_spec': 64, 'num_bands_noisy_spec': 18, 'noisy_spec_scale': 'opus', - 'pitch_hangover': 8, + 'pitch_hangover': 0, }, 'training': { 'batch_size': 256, @@ -106,7 +107,8 @@ nolace_setup = { 'pitch_embedding_dim': 64, 'pitch_max': 300, 'preemph': 0.85, - 'skip': 91 + 'skip': 91, + 'softquant': True } }, 'data': { @@ -119,7 +121,7 @@ nolace_setup = { 'num_bands_clean_spec': 64, 'num_bands_noisy_spec': 18, 'noisy_spec_scale': 'opus', - 'pitch_hangover': 8, + 'pitch_hangover': 0, }, 'training': { 'batch_size': 256, @@ -160,7 +162,8 @@ nolace_setup_adv = { 'pitch_embedding_dim': 64, 'pitch_max': 300, 'preemph': 0.85, - 'skip': 91 + 'skip': 91, + 'softquant': True } }, 'data': { @@ -173,7 +176,7 @@ nolace_setup_adv = { 'num_bands_clean_spec': 64, 'num_bands_noisy_spec': 18, 'noisy_spec_scale': 'opus', - 'pitch_hangover': 8, + 'pitch_hangover': 0, }, 'discriminator': { 'args': [], |