diff options
Diffstat (limited to 'dnn/torch/osce/utils/layers/td_shaper.py')
-rw-r--r-- | dnn/torch/osce/utils/layers/td_shaper.py | 32 |
1 file changed, 21 insertions(+), 11 deletions(-)
diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
index 73d66bd5..fa7bf348 100644
--- a/dnn/torch/osce/utils/layers/td_shaper.py
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -3,6 +3,7 @@ from torch import nn
 import torch.nn.functional as F
 
 from utils.complexity import _conv1d_flop_count
+from utils.softquant import soft_quant
 
 class TDShaper(nn.Module):
     COUNTER = 1
@@ -12,7 +13,9 @@ class TDShaper(nn.Module):
                  frame_size=160,
                  avg_pool_k=4,
                  innovate=False,
-                 pool_after=False
+                 pool_after=False,
+                 softquant=False,
+                 apply_weight_norm=False
                  ):
         """
 
@@ -45,23 +48,29 @@ class TDShaper(nn.Module):
         assert frame_size % avg_pool_k == 0
         self.env_dim = frame_size // avg_pool_k + 1
 
+        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
+
         # feature transform
-        self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
-        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
+        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
+        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
+        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))
+
+        if softquant:
+            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)
 
         if self.innovate:
-            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
-            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
+            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
 
-            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
-            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
+            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
+            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))
 
 
     def flop_count(self, rate):
 
         frame_rate = rate / self.frame_size
 
-        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
+        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
 
         if self.innovate:
             inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
@@ -110,9 +119,10 @@ class TDShaper(nn.Module):
         tenv = self.envelope_transform(x)
 
         # feature path
-        f = torch.cat((features, tenv), dim=-1)
-        f = F.pad(f.permute(0, 2, 1), [1, 0])
-        alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
+        f = F.pad(features.permute(0, 2, 1), [1, 0])
+        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
+        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
+        alpha = F.leaky_relu(alpha, 0.2)
         alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
         alpha = alpha.permute(0, 2, 1)