1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
|
import torch
from torch import nn
import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
class TDShaper(nn.Module):
    # NOTE(review): COUNTER appears unused inside this class; kept for external compatibility.
    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 kernel_size=2,
                 tanh_activation=False,
                 ):
        """Temporal shaping module.

        Predicts a per-sample multiplicative envelope from frame-wise features
        combined with the log temporal envelope of the input signal, and
        optionally adds a learned innovation component.

        Parameters:
        -----------
        feature_dim : int
            dimension of input features
        frame_size : int
            frame size in samples; must be divisible by avg_pool_k
        avg_pool_k : int, optional
            kernel size and stride for avg pooling of the envelope
        innovate : bool, optional
            if True, an additional innovation branch is added to the output
        pool_after : bool, optional
            if True, the log is taken before average pooling; otherwise after
        kernel_size : int, optional
            kernel size of the two feature convolutions (made causal by left padding)
        tanh_activation : bool, optional
            if True, tanh is used between the convolutions instead of LeakyReLU(0.2)
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after
        self.kernel_size = kernel_size
        self.tanh_activation = tanh_activation

        assert frame_size % avg_pool_k == 0
        # pooled envelope samples per frame, plus one slot for the frame mean
        self.env_dim = frame_size // avg_pool_k + 1

        # feature transform
        self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, kernel_size)
        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, kernel_size)

        if self.innovate:
            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, kernel_size)
            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, kernel_size)

            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, kernel_size)
            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, kernel_size)

        # torch.tanh is a plain function, nn.LeakyReLU a module; both are callable
        self.activation = torch.tanh if self.tanh_activation else nn.LeakyReLU(0.2)

    def flop_count(self, rate):
        """Estimate FLOPs per second when running at sample rate `rate` (Hz)."""
        frame_rate = rate / self.frame_size
        # the constant terms account for the per-sample envelope/gain arithmetic
        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """Compute mean-removed log-envelope features from the input signal.

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        Returns:
        --------
        torch.tensor of shape (batch_size, num_frames, env_dim); the last
        channel holds the per-frame mean, the others the mean-removed envelope.
        """
        x = torch.abs(x)
        # .5**16 is a small floor that keeps the log finite on silence
        if self.pool_after:
            x = torch.log(x + .5**16)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        # group pooled samples frame-wise: (batch, num_frames, env_dim - 1)
        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, state=None, return_state=False, debug=False):
        """Shape the signal with a feature-predicted temporal envelope.

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples);
            num_samples must equal num_frames * frame_size
        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)
        state : torch.tensor, optional
            convolution history from a previous call (as returned with
            return_state=True); when None, zero left-padding is used instead
        return_state : bool, optional
            if True, also return the state for the next streaming call
        debug : bool, optional
            unused; kept for interface compatibility

        Returns:
        --------
        torch.tensor of shape (batch_size, 1, num_samples), or a tuple
        (output, new_state) when return_state is True.
        """
        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)
        # total left context consumed by the two causal convolutions
        padding = 2 * self.kernel_size - 2

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path
        f = torch.cat((features, tenv), dim=-1).permute(0, 2, 1)
        if state is not None:
            f = torch.cat((state, f), dim=-1)
        else:
            f = F.pad(f, [padding, 0])
        alpha = self.activation(self.feature_alpha1(f))
        # exp keeps the per-sample gains strictly positive
        alpha = torch.exp(self.feature_alpha2(alpha))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            inno_alpha = self.activation(self.feature_alpha1b(f))
            inno_alpha = torch.exp(self.feature_alpha2b(inno_alpha))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = self.activation(self.feature_alpha1c(f))
            inno_x = torch.tanh(self.feature_alpha2c(inno_x))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        y = y.reshape(batch_size, 1, num_samples)

        if return_state:
            # BUGFIX: slice explicitly from the right edge so that padding == 0
            # (kernel_size == 1) yields an EMPTY state; the former f[..., -padding:]
            # returned the whole feature tensor when padding was 0, which would
            # prepend the entire history on the next streaming call.
            new_state = f[..., f.size(-1) - padding:]
            return y, new_state
        else:
            return y
|