gitlab.xiph.org/xiph/opus.git
author     Jan Buethe <jbuethe@amazon.de>   2024-01-05 15:26:25 +0300
committer  Jan Buethe <jbuethe@amazon.de>   2024-01-05 15:26:25 +0300
commit     11128e7c2a3de9f9c1b398a8c721847aac0a7273 (patch)
tree       09d67375ab275974363a2ede3d8d1e76856b3cfd
parent     30244724e15302c9d3033012c612844c591ace86 (diff)
added sparsification to LACE model
-rw-r--r--  dnn/torch/dnntools/dnntools/quantization/softquant.py                      |   3
-rw-r--r--  dnn/torch/dnntools/dnntools/sparsification/__init__.py                     |   4
-rw-r--r--  dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py  | 139
-rw-r--r--  dnn/torch/dnntools/dnntools/sparsification/utils.py                        |  43
-rw-r--r--  dnn/torch/osce/engine/engine.py                                            |   6
-rw-r--r--  dnn/torch/osce/models/lace.py                                              |  12
-rw-r--r--  dnn/torch/osce/models/silk_feature_net_pl.py                               |  22
7 files changed, 223 insertions, 6 deletions
diff --git a/dnn/torch/dnntools/dnntools/quantization/softquant.py b/dnn/torch/dnntools/dnntools/quantization/softquant.py
index 5fca5b2a..ca82ceca 100644
--- a/dnn/torch/dnntools/dnntools/quantization/softquant.py
+++ b/dnn/torch/dnntools/dnntools/quantization/softquant.py
@@ -24,6 +24,7 @@ def q_scaled_noise(module, weight):
if isinstance(module, torch.nn.Conv1d):
w = weight.permute(0, 2, 1).flatten(1)
noise = torch.rand_like(w) - 0.5
+ noise[w == 0] = 0 # ignore zero entries from sparsification
scale = compute_optimal_scale(w)
noise = noise * scale.unsqueeze(-1)
noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
@@ -31,11 +32,13 @@ def q_scaled_noise(module, weight):
i, o, k = weight.shape
w = weight.permute(2, 1, 0).reshape(k * o, i)
noise = torch.rand_like(w) - 0.5
+ noise[w == 0] = 0 # ignore zero entries from sparsification
scale = compute_optimal_scale(w)
noise = noise * scale.unsqueeze(-1)
noise = noise.reshape(k, o, i).permute(2, 1, 0)
elif len(weight.shape) == 2:
noise = torch.rand_like(weight) - 0.5
+ noise[weight == 0] = 0 # ignore zero entries from sparsification
scale = compute_optimal_scale(weight)
noise = noise * scale.unsqueeze(-1)
else:
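
The three additions above make the same fix: the uniform quantization noise is zeroed wherever the (reshaped) weight is already zero, so entries pruned by sparsification are not perturbed back to non-zero values during soft quantization. A minimal standalone sketch of the idea (masked_quant_noise is a hypothetical helper name, not part of this patch):

    import torch

    def masked_quant_noise(w: torch.Tensor) -> torch.Tensor:
        # uniform noise in [-0.5, 0.5), zeroed at positions that sparsification set to zero,
        # so adding it to w cannot revive weights that are meant to stay pruned
        noise = torch.rand_like(w) - 0.5
        noise[w == 0] = 0
        return noise
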
diff --git a/dnn/torch/dnntools/dnntools/sparsification/__init__.py b/dnn/torch/dnntools/dnntools/sparsification/__init__.py
index 409e4977..fcc91746 100644
--- a/dnn/torch/dnntools/dnntools/sparsification/__init__.py
+++ b/dnn/torch/dnntools/dnntools/sparsification/__init__.py
@@ -1,4 +1,6 @@
from .gru_sparsifier import GRUSparsifier
from .conv1d_sparsifier import Conv1dSparsifier
+from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
from .linear_sparsifier import LinearSparsifier
-from .common import sparsify_matrix, calculate_gru_flops_per_step
\ No newline at end of file
+from .common import sparsify_matrix, calculate_gru_flops_per_step
+from .utils import mark_for_sparsification, create_sparsifier
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py
new file mode 100644
index 00000000..73f2fc3d
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py
@@ -0,0 +1,139 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+from .common import sparsify_matrix
+
+
+class ConvTranspose1dSparsifier:
+ def __init__(self, task_list, start, stop, interval, exponent=3):
+ """ Sparsifier for torch.nn.GRUs
+
+ Parameters:
+ -----------
+ task_list : list
+ task_list contains a list of tuples (conv, params), where conv is an instance
+ of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
+ where density is the target density in [0, 1] and [m, n] is the shape of the sub-blocks
+ to which sparsification is applied.
+
+ start : int
+ training step after which sparsification will be started.
+
+ stop : int
+ training step after which sparsification will be completed.
+
+ interval : int
+ sparsification interval for steps between start and stop. After stop, sparsification will be
+ carried out on every call to ConvTranspose1dSparsifier.step()
+
+ exponent : float
+ Interpolation exponent for the sparsification schedule. In step i sparsification will be carried out
+ with density (alpha + target_density * (1 - alpha)), where
+ alpha = ((stop - i) / (stop - start)) ** exponent
+
+ Example:
+ --------
+ >>> import torch
+ >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
+ >>> params = (0.2, [8, 4])
+ >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
+ >>> for i in range(100):
+ ... sparsifier.step()
+ """
+ # just copying parameters...
+ self.start = start
+ self.stop = stop
+ self.interval = interval
+ self.exponent = exponent
+ self.task_list = task_list
+
+ # ... and setting counter to 0
+ self.step_counter = 0
+
+ def step(self, verbose=False):
+ """ carries out sparsification step
+
+ Call this function after optimizer.step in your
+ training loop.
+
+ Parameters:
+ ----------
+ verbose : bool
+ if true, densities are printed out
+
+ Returns:
+ --------
+ None
+
+ """
+ # compute current interpolation factor
+ self.step_counter += 1
+
+ if self.step_counter < self.start:
+ return
+ elif self.step_counter < self.stop:
+ # prune only every self.interval-th step
+ if self.step_counter % self.interval:
+ return
+
+ alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+ else:
+ alpha = 0
+
+
+ with torch.no_grad():
+ for conv, params in self.task_list:
+ # reshape weight
+ i, o, k = conv.weight.shape
+ w = conv.weight.permute(2, 1, 0).reshape(k * o, i)
+ target_density, block_size = params
+ density = alpha + (1 - alpha) * target_density
+ w = sparsify_matrix(w, density, block_size)
+ w = w.reshape(k, o, i).permute(2, 1, 0)
+ conv.weight[:] = w
+
+ if verbose:
+ print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ import torch
+ conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
+ params = (0.2, [8, 4])
+
+ sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
+
+ print(conv.weight)
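
The class delegates the actual pruning to sparsify_matrix from .common, which is not part of this commit. As a rough illustration of what a (density, [m, n]) parameter pair means for block-wise sparsification, the sketch below keeps the sub-blocks with the largest L1 norm and zeroes the rest; it is an assumption about the behaviour, not the dnntools implementation:

    import torch

    def block_sparsify(w: torch.Tensor, density: float, block=(8, 4)) -> torch.Tensor:
        # illustrative block-magnitude pruning; assumes w's shape is divisible by the block shape
        m, n = block
        rows, cols = w.shape
        # L1 norm of each (m, n) block, shape (rows // m, cols // n)
        norms = w.reshape(rows // m, m, cols // n, n).abs().sum(dim=(1, 3))
        keep = max(1, int(round(density * norms.numel())))
        mask = torch.zeros_like(norms)
        mask.view(-1)[norms.flatten().topk(keep).indices] = 1.0
        # expand the block mask back to the full weight shape
        mask = mask.repeat_interleave(m, dim=0).repeat_interleave(n, dim=1)
        return w * mask
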
diff --git a/dnn/torch/dnntools/dnntools/sparsification/utils.py b/dnn/torch/dnntools/dnntools/sparsification/utils.py
new file mode 100644
index 00000000..da9dc89e
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/utils.py
@@ -0,0 +1,43 @@
+import torch
+
+from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
+
+def mark_for_sparsification(module, params):
+ setattr(module, 'sparsify', True)
+ setattr(module, 'sparsification_params', params)
+ return module
+
+def create_sparsifier(module, start, stop, interval):
+ sparsifier_list = []
+ for m in module.modules():
+ if hasattr(m, 'sparsify'):
+ if isinstance(m, torch.nn.GRU):
+ sparsifier_list.append(
+ GRUSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ elif isinstance(m, torch.nn.Linear):
+ sparsifier_list.append(
+ LinearSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ elif isinstance(m, torch.nn.Conv1d):
+ sparsifier_list.append(
+ Conv1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ elif isinstance(m, torch.nn.ConvTranspose1d):
+ sparsifier_list.append(
+ ConvTranspose1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
+ )
+ else:
+ print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")
+
+ def sparsify(verbose=False):
+ for sparsifier in sparsifier_list:
+ sparsifier.step(verbose)
+
+ return sparsify
+
+def estimate_parameters(module):
+ num_zero_parameters = 0
+ if hasattr(module, 'sparsify'):
+ if isinstance(module, torch.nn.Conv1d):
+ pass
diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py
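
A minimal usage sketch for the two helpers above (the toy module and hyperparameters are illustrative, not taken from this patch):

    import torch
    from dnntools.sparsification import mark_for_sparsification, create_sparsifier

    net = torch.nn.Sequential(
        torch.nn.Conv1d(16, 32, 3),
        torch.nn.ConvTranspose1d(32, 16, 4, 4),
    )
    # attach target density and block shape to the layers that should be pruned
    mark_for_sparsification(net[0], (0.5, [8, 4]))
    mark_for_sparsification(net[1], (0.25, [8, 4]))

    # returns a closure that steps all matching sparsifiers at once;
    # call it after optimizer.step() in the training loop
    sparsify = create_sparsifier(net, start=100, stop=1000, interval=50)
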
index 7688e9b4..8e5731de 100644
--- a/dnn/torch/osce/engine/engine.py
+++ b/dnn/torch/osce/engine/engine.py
@@ -46,6 +46,10 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
# update learning rate
scheduler.step()
+ # sparsification
+ if hasattr(model, 'sparsify'):
+ model.sparsify(verbose=True)
+
# update running loss
running_loss += float(loss.cpu())
@@ -73,8 +77,6 @@ def evaluate(model, criterion, dataloader, device, log_interval=10):
for i, batch in enumerate(tepoch):
-
-
# push batch to device
for key in batch:
batch[key] = batch[key].to(device)
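
The training loop now invokes the sparsification closure right after the scheduler step, guarded by hasattr so models built without sparsification are unaffected. A self-contained sketch of the resulting step order (the toy model and hyperparameters are illustrative):

    import torch

    model = torch.nn.Conv1d(4, 4, 3)   # stand-in for the OSCE model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    for step in range(20):
        loss = model(torch.randn(1, 4, 32)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        # added by this commit: prune marked layers once per training step
        if hasattr(model, 'sparsify'):
            model.sparsify(verbose=True)
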
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 5bc9fe41..78c1a717 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -41,6 +41,12 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding
+import sys
+sys.path.append('../dnntools')
+
+from dnntools.sparsification import create_sparsifier
+
+
class LACE(NNSBase):
""" Linear-Adaptive Coding Enhancer """
FRAME_SIZE=80
@@ -61,7 +67,8 @@ class LACE(NNSBase):
hidden_feature_dim=64,
partial_lookahead=True,
norm_p=2,
- softquant=False):
+ softquant=False,
+ sparsify=False):
super().__init__(skip=skip, preemph=preemph)
@@ -99,6 +106,9 @@ class LACE(NNSBase):
# spectral shaping
self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant)
+ if sparsify:
+ self.sparsify = create_sparsifier(self, 500, 2000, 100)
+
def flop_count(self, rate=16000, verbose=False):
frame_rate = rate / self.FRAME_SIZE
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index 81beb36d..72f8531c 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -26,7 +26,8 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
-
+import sys
+sys.path.append('../dnntools')
import torch
from torch import nn
@@ -35,13 +36,15 @@ import torch.nn.functional as F
from utils.complexity import _conv1d_flop_count
from utils.softquant import soft_quant
+from dnntools.sparsification import mark_for_sparsification
class SilkFeatureNetPL(nn.Module):
""" feature net with partial lookahead """
def __init__(self,
feature_dim=47,
num_channels=256,
hidden_feature_dim=64,
- softquant=False):
+ softquant=False,
+ sparsify=True):
super(SilkFeatureNetPL, self).__init__()
@@ -60,6 +63,21 @@ class SilkFeatureNetPL(nn.Module):
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
+ if sparsify:
+ mark_for_sparsification(self.conv2, (0.25, [8, 4]))
+ mark_for_sparsification(self.tconv, (0.25, [8, 4]))
+ mark_for_sparsification(
+ self.gru,
+ {
+ 'W_ir' : (0.25, [8, 4], False),
+ 'W_iz' : (0.25, [8, 4], False),
+ 'W_in' : (0.25, [8, 4], False),
+ 'W_hr' : (0.125, [8, 4], True),
+ 'W_hz' : (0.125, [8, 4], True),
+ 'W_hn' : (0.125, [8, 4], True),
+ }
+ )
+
def flop_count(self, rate=200):
count = 0
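
For the GRU, the sparsification parameters are given per sub-matrix: the keys name the input (W_i*) and recurrent (W_h*) weights of the reset, update and new gates, and each value is (density, block shape, flag); the recurrent matrices get a lower density and the trailing boolean is presumably a keep-diagonal switch interpreted by GRUSparsifier. An illustrative standalone version of the same marking (layer sizes are assumptions, not taken from this patch):

    import torch
    from dnntools.sparsification import mark_for_sparsification, create_sparsifier

    gru = torch.nn.GRU(64, 64, batch_first=True)
    mark_for_sparsification(gru, {
        'W_ir': (0.25, [8, 4], False), 'W_iz': (0.25, [8, 4], False), 'W_in': (0.25, [8, 4], False),
        'W_hr': (0.125, [8, 4], True), 'W_hz': (0.125, [8, 4], True), 'W_hn': (0.125, [8, 4], True),
    })
    sparsify = create_sparsifier(gru, start=500, stop=2000, interval=100)
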