gitlab.xiph.org/xiph/opus.git
Diffstat (limited to 'dnn/torch/osce/utils/layers/td_shaper.py')
 dnn/torch/osce/utils/layers/td_shaper.py | 32 +++++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 11 deletions(-)
diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
index 73d66bd5..fa7bf348 100644
--- a/dnn/torch/osce/utils/layers/td_shaper.py
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -3,6 +3,7 @@ from torch import nn
 import torch.nn.functional as F
 
 from utils.complexity import _conv1d_flop_count
+from utils.softquant import soft_quant
 
 class TDShaper(nn.Module):
     COUNTER = 1
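The new import pulls in soft_quant from utils/softquant.py, which is not part of this diff, so its implementation is not shown here. The only API this change establishes is module-in, module-out: soft_quant(conv) returns the conv with quantization applied to its weights. As a rough sketch of the general technique only (the repository's actual implementation may well differ), soft quantization for quantization-aware training typically fakes a quantization grid in the forward pass while letting gradients flow through unchanged:

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

# Hypothetical stand-in for utils.softquant.soft_quant, for illustration only.
class FakeQuant(nn.Module):
    def __init__(self, scale=1/128):
        super().__init__()
        self.scale = scale

    def forward(self, w):
        # Forward pass sees weights rounded to the grid; backward pass
        # sees the identity (straight-through estimator).
        w_q = torch.round(w / self.scale) * self.scale
        return w + (w_q - w).detach()

def soft_quant_sketch(module, name='weight'):
    parametrize.register_parametrization(module, name, FakeQuant())
    return module

conv = soft_quant_sketch(nn.Conv1d(8, 16, 2))  # same call shape as in this diff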
@@ -12,7 +13,9 @@ class TDShaper(nn.Module):
                  frame_size=160,
                  avg_pool_k=4,
                  innovate=False,
-                 pool_after=False
+                 pool_after=False,
+                 softquant=False,
+                 apply_weight_norm=False
     ):
         """
 
@@ -45,23 +48,29 @@
         assert frame_size % avg_pool_k == 0
         self.env_dim = frame_size // avg_pool_k + 1
 
+        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x
+
         # feature transform
-        self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
-        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
+        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
+        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
+        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))
+
+        if softquant:
+            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)
 
         if self.innovate:
-            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
-            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
+            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
 
-            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
-            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
+            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
+            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))
 
 
     def flop_count(self, rate):
 
         frame_rate = rate / self.frame_size
 
-        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
+        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
 
         if self.innovate:
             inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
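The substantive change in this hunk is the split of feature_alpha1, formerly one Conv1d over the concatenated feature-plus-envelope input, into feature_alpha1_f and feature_alpha1_t, which the forward pass (next hunk) applies to the two streams separately and sums. Since a Conv1d kernel is simply sliced along its input channels, the split form is equivalent to the concatenated form up to bias bookkeeping (both new convs carry a bias, which is harmless for freshly trained weights), and it lets soft quantization target the feature branch alone, as the softquant block above does. A self-contained equivalence check, with a hypothetical feature dimension of 20 (env_dim 41 and frame_size 160 follow from the defaults):

import torch
from torch import nn

F_DIM, E_DIM, FRAME, K = 20, 41, 160, 2  # F_DIM is made up for illustration

full   = nn.Conv1d(F_DIM + E_DIM, FRAME, K)
conv_f = nn.Conv1d(F_DIM, FRAME, K)              # keeps the single bias
conv_t = nn.Conv1d(E_DIM, FRAME, K, bias=False)  # bias must appear only once

with torch.no_grad():
    conv_f.weight.copy_(full.weight[:, :F_DIM])  # first block of input channels
    conv_f.bias.copy_(full.bias)
    conv_t.weight.copy_(full.weight[:, F_DIM:])  # remaining input channels

f = torch.randn(1, F_DIM, 10)
t = torch.randn(1, E_DIM, 10)

y_cat   = full(torch.cat((f, t), dim=1))
y_split = conv_f(f) + conv_t(t)
print(torch.allclose(y_cat, y_split, atol=1e-6))  # True

The flop_count update is the bookkeeping consequence: the two smaller convs replace the single wide one in the summation, and total shaping FLOPs stay essentially unchanged.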
@@ -110,9 +119,10 @@
         tenv = self.envelope_transform(x)
 
         # feature path
-        f = torch.cat((features, tenv), dim=-1)
-        f = F.pad(f.permute(0, 2, 1), [1, 0])
-        alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
+        f = F.pad(features.permute(0, 2, 1), [1, 0])
+        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
+        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
+        alpha = F.leaky_relu(alpha, 0.2)
         alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
         alpha = alpha.permute(0, 2, 1)
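In the rewritten feature path, features and tenv are no longer concatenated; each is permuted to (batch, channels, frames) and left-padded by one frame before its kernel-size-2 conv, so every output frame depends only on the current and previous frame and the frame count is preserved. The two conv outputs are summed before the leaky ReLU, exactly where the single concatenated conv used to sit, and the final torch.exp keeps the gain alpha strictly positive. A small shape check of the causal-padding pattern:

import torch
from torch import nn
import torch.nn.functional as F

conv = nn.Conv1d(4, 4, 2)           # kernel size 2, as in the diff
x = torch.randn(1, 4, 10)           # (batch, channels, frames)
y = conv(F.pad(x, [1, 0]))          # pad one frame on the left only
print(y.shape)                      # torch.Size([1, 4, 10]): causal, same length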