diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-12-18 14:19:55 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:42:54 +0300 |
commit | bfcbbbcb38ce399a6f4cf629c80097bc7d1dadbc (patch) | |
tree | 002d54b3a7d8ae44a641623c95234b06f3d2be40 /dnn/torch/osce/utils | |
parent | 4f311a1ad44f1b7bd60e32984ca0604c46b6c593 (diff) |
added softquant option and activate tconv layer
Diffstat (limited to 'dnn/torch/osce/utils')
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py | 5 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py | 5 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/td_shaper.py | 19 | ||||
-rw-r--r-- | dnn/torch/osce/utils/softquant.py | 55 | ||||
-rw-r--r-- | dnn/torch/osce/utils/templates.py | 15 |
5 files changed, 86 insertions, 13 deletions
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py index 3bb6fa07..a116fd72 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py @@ -32,6 +32,7 @@ from torch import nn import torch.nn.functional as F from utils.endoscopy import write_data +from utils.softquant import soft_quant class LimitedAdaptiveComb1d(nn.Module): COUNTER = 1 @@ -47,6 +48,7 @@ class LimitedAdaptiveComb1d(nn.Module): gain_limit_db=10, global_gain_limits_db=[-6, 6], norm_p=2, + softquant=False, **kwargs): """ @@ -100,6 +102,9 @@ class LimitedAdaptiveComb1d(nn.Module): # network for generating convolution weights self.conv_kernel = nn.Linear(feature_dim, kernel_size) + if softquant: + self.conv_kernel = soft_quant(self.conv_kernel) + # comb filter gain self.filter_gain = nn.Linear(feature_dim, 1) diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py index a17b0e9b..af6ec0ec 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py @@ -34,7 +34,7 @@ import torch.nn.functional as F from utils.endoscopy import write_data from utils.ada_conv import adaconv_kernel - +from utils.softquant import soft_quant class LimitedAdaptiveConv1d(nn.Module): COUNTER = 1 @@ -51,6 +51,7 @@ class LimitedAdaptiveConv1d(nn.Module): gain_limits_db=[-6, 6], shape_gain_db=0, norm_p=2, + softquant=False, **kwargs): """ @@ -102,6 +103,8 @@ class LimitedAdaptiveConv1d(nn.Module): # network for generating convolution weights self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size) + if softquant: + self.conv_kernel = soft_quant(self.conv_kernel) self.shape_gain = min(1, 10**(shape_gain_db / 20)) diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py index 73d66bd5..7080a1b3 100644 --- a/dnn/torch/osce/utils/layers/td_shaper.py +++ b/dnn/torch/osce/utils/layers/td_shaper.py @@ -3,6 +3,7 @@ from torch import nn import torch.nn.functional as F from utils.complexity import _conv1d_flop_count +from utils.softquant import soft_quant class TDShaper(nn.Module): COUNTER = 1 @@ -12,7 +13,8 @@ class TDShaper(nn.Module): frame_size=160, avg_pool_k=4, innovate=False, - pool_after=False + pool_after=False, + softquant=False ): """ @@ -46,9 +48,13 @@ class TDShaper(nn.Module): self.env_dim = frame_size // avg_pool_k + 1 # feature transform - self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) + self.feature_alpha1_f = nn.Conv1d(self.feature_dim, frame_size, 2) + self.feature_alpha1_t = nn.Conv1d(self.env_dim, frame_size, 2) self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2) + if soft_quant: + self.feature_alpha1_f = soft_quant(self.feature_alpha1_f) + if self.innovate: self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) @@ -61,7 +67,7 @@ class TDShaper(nn.Module): frame_rate = rate / self.frame_size - shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size + shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size if self.innovate: inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size @@ -110,9 +116,10 @@ class TDShaper(nn.Module): tenv = self.envelope_transform(x) # feature path - f = torch.cat((features, tenv), dim=-1) - f = F.pad(f.permute(0, 2, 1), [1, 0]) - alpha = F.leaky_relu(self.feature_alpha1(f), 0.2) + f = F.pad(features.permute(0, 2, 1), [1, 0]) + t = F.pad(tenv.permute(0, 2, 1), [1, 0]) + alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t) + alpha = F.leaky_relu(alpha, 0.2) alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0]))) alpha = alpha.permute(0, 2, 1) diff --git a/dnn/torch/osce/utils/softquant.py b/dnn/torch/osce/utils/softquant.py new file mode 100644 index 00000000..3917488b --- /dev/null +++ b/dnn/torch/osce/utils/softquant.py @@ -0,0 +1,55 @@ +import torch + + +class SoftQuant: + name: str + + def __init__(self, names: str, scale: float) -> None: + self.names = names + self.quantization_noise = None + self.scale = scale + + def __call__(self, module, inputs, *args): + if self.quantization_noise is None: + self.quantization_noise = dict() + for name in self.names: + weight = getattr(module, name) + self.quantization_noise[name] = \ + self.scale * weight.abs().max() * 2 * (torch.rand_like(weight) - 0.5) + with torch.no_grad(): + weight.data[:] = weight + self.quantization_noise[name] + else: + for name in self.names: + weight = getattr(module, name) + with torch.no_grad(): + weight.data[:] = weight - self.quantization_noise[name] + self.quantization_noise = None + + def apply(module, names=['weight'], scale=0.5/127): + fn = SoftQuant(names, scale) + + for name in names: + if not hasattr(module, name): + raise ValueError("") + + module.register_forward_pre_hook(fn) + module.register_forward_hook(fn) + + module + + return fn + + +def soft_quant(module, names=['weight'], scale=0.5/127): + fn = SoftQuant.apply(module, names, scale) + return module + +def remove_soft_quant(module, names=['weight']): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SoftQuant) and hook.names == names: + del module._forward_pre_hooks[k] + for k, hook in module._forward_hooks.items(): + if isinstance(hook, SoftQuant) and hook.names == names: + del module._forward_hooks[k] + + return module
\ No newline at end of file diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py index 42137b26..0d731127 100644 --- a/dnn/torch/osce/utils/templates.py +++ b/dnn/torch/osce/utils/templates.py @@ -50,7 +50,8 @@ lace_setup = { 'pitch_embedding_dim': 64, 'pitch_max': 300, 'preemph': 0.85, - 'skip': 91 + 'skip': 91, + 'softquant': True } }, 'data': { @@ -63,7 +64,7 @@ lace_setup = { 'num_bands_clean_spec': 64, 'num_bands_noisy_spec': 18, 'noisy_spec_scale': 'opus', - 'pitch_hangover': 8, + 'pitch_hangover': 0, }, 'training': { 'batch_size': 256, @@ -106,7 +107,8 @@ nolace_setup = { 'pitch_embedding_dim': 64, 'pitch_max': 300, 'preemph': 0.85, - 'skip': 91 + 'skip': 91, + 'softquant': True } }, 'data': { @@ -119,7 +121,7 @@ nolace_setup = { 'num_bands_clean_spec': 64, 'num_bands_noisy_spec': 18, 'noisy_spec_scale': 'opus', - 'pitch_hangover': 8, + 'pitch_hangover': 0, }, 'training': { 'batch_size': 256, @@ -160,7 +162,8 @@ nolace_setup_adv = { 'pitch_embedding_dim': 64, 'pitch_max': 300, 'preemph': 0.85, - 'skip': 91 + 'skip': 91, + 'softquant': True } }, 'data': { @@ -173,7 +176,7 @@ nolace_setup_adv = { 'num_bands_clean_spec': 64, 'num_bands_noisy_spec': 18, 'noisy_spec_scale': 'opus', - 'pitch_hangover': 8, + 'pitch_hangover': 0, }, 'discriminator': { 'args': [], |