| author | Jan Buethe <jbuethe@amazon.de> | 2024-01-05 15:26:25 +0300 |
|---|---|---|
| committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-05 15:26:25 +0300 |
| commit | 11128e7c2a3de9f9c1b398a8c721847aac0a7273 (patch) | |
| tree | 09d67375ab275974363a2ede3d8d1e76856b3cfd | |
| parent | 30244724e15302c9d3033012c612844c591ace86 (diff) | |
added sparsification to LACE model
-rw-r--r-- | dnn/torch/dnntools/dnntools/quantization/softquant.py | 3
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/__init__.py | 4
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py | 139
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/utils.py | 43
-rw-r--r-- | dnn/torch/osce/engine/engine.py | 6
-rw-r--r-- | dnn/torch/osce/models/lace.py | 12
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 22

7 files changed, 223 insertions(+), 6 deletions(-)
diff --git a/dnn/torch/dnntools/dnntools/quantization/softquant.py b/dnn/torch/dnntools/dnntools/quantization/softquant.py
index 5fca5b2a..ca82ceca 100644
--- a/dnn/torch/dnntools/dnntools/quantization/softquant.py
+++ b/dnn/torch/dnntools/dnntools/quantization/softquant.py
@@ -24,6 +24,7 @@ def q_scaled_noise(module, weight):
     if isinstance(module, torch.nn.Conv1d):
         w = weight.permute(0, 2, 1).flatten(1)
         noise = torch.rand_like(w) - 0.5
+        noise[w == 0] = 0 # ignore zero entries from sparsification
         scale = compute_optimal_scale(w)
         noise = noise * scale.unsqueeze(-1)
         noise = noise.reshape(weight.size(0), weight.size(2), weight.size(1)).permute(0, 2, 1)
@@ -31,11 +32,13 @@
         i, o, k = weight.shape
         w = weight.permute(2, 1, 0).reshape(k * o, i)
         noise = torch.rand_like(w) - 0.5
+        noise[w == 0] = 0 # ignore zero entries from sparsification
         scale = compute_optimal_scale(w)
         noise = noise * scale.unsqueeze(-1)
         noise = noise.reshape(k, o, i).permute(2, 1, 0)
     elif len(weight.shape) == 2:
         noise = torch.rand_like(weight) - 0.5
+        noise[weight == 0] = 0 # ignore zero entries from sparsification
         scale = compute_optimal_scale(weight)
         noise = noise * scale.unsqueeze(-1)
     else:
diff --git a/dnn/torch/dnntools/dnntools/sparsification/__init__.py b/dnn/torch/dnntools/dnntools/sparsification/__init__.py
index 409e4977..fcc91746 100644
--- a/dnn/torch/dnntools/dnntools/sparsification/__init__.py
+++ b/dnn/torch/dnntools/dnntools/sparsification/__init__.py
@@ -1,4 +1,6 @@
 from .gru_sparsifier import GRUSparsifier
 from .conv1d_sparsifier import Conv1dSparsifier
+from .conv_transpose1d_sparsifier import ConvTranspose1dSparsifier
 from .linear_sparsifier import LinearSparsifier
-from .common import sparsify_matrix, calculate_gru_flops_per_step
\ No newline at end of file
+from .common import sparsify_matrix, calculate_gru_flops_per_step
+from .utils import mark_for_sparsification, create_sparsifier
\ No newline at end of file
diff --git a/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py
new file mode 100644
index 00000000..73f2fc3d
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py
@@ -0,0 +1,139 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+from .common import sparsify_matrix
+
+
+class ConvTranspose1dSparsifier:
+    def __init__(self, task_list, start, stop, interval, exponent=3):
+        """ Sparsifier for torch.nn.ConvTranspose1d layers
+
+        Parameters:
+        -----------
+        task_list : list
+            task_list contains a list of tuples (conv, params), where conv is an instance
+            of torch.nn.ConvTranspose1d and params is a tuple (density, [m, n]),
+            where density is the target density in [0, 1] and [m, n] is the shape of the
+            sub-blocks to which sparsification is applied.
+
+        start : int
+            training step after which sparsification will be started.
+
+        stop : int
+            training step after which sparsification will be completed.
+
+        interval : int
+            sparsification interval for steps between start and stop. After stop, sparsification
+            will be carried out after every call to ConvTranspose1dSparsifier.step()
+
+        exponent : float
+            interpolation exponent for the sparsification schedule. In step i sparsification will
+            be carried out with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+        Example:
+        --------
+        >>> import torch
+        >>> conv = torch.nn.ConvTranspose1d(8, 16, 8)
+        >>> params = (0.2, [8, 4])
+        >>> sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 50)
+        >>> for i in range(100):
+        ...     sparsifier.step()
+        """
+        # just copying parameters...
+        self.start = start
+        self.stop = stop
+        self.interval = interval
+        self.exponent = exponent
+        self.task_list = task_list
+
+        # ... and setting counter to 0
+        self.step_counter = 0
+
+    def step(self, verbose=False):
+        """ carries out sparsification step
+
+        Call this function after optimizer.step in your
+        training loop.
+
+        Parameters:
+        ----------
+        verbose : bool
+            if true, densities are printed out
+
+        Returns:
+        --------
+        None
+
+        """
+        # compute current interpolation factor
+        self.step_counter += 1
+
+        if self.step_counter < self.start:
+            return
+        elif self.step_counter < self.stop:
+            # update only every self.interval-th step
+            if self.step_counter % self.interval:
+                return
+
+            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+        else:
+            alpha = 0
+
+        with torch.no_grad():
+            for conv, params in self.task_list:
+                # reshape weight
+                i, o, k = conv.weight.shape
+                w = conv.weight.permute(2, 1, 0).reshape(k * o, i)
+                target_density, block_size = params
+                density = alpha + (1 - alpha) * target_density
+                w = sparsify_matrix(w, density, block_size)
+                w = w.reshape(k, o, i).permute(2, 1, 0)
+                conv.weight[:] = w
+
+                if verbose:
+                    print(f"convtranspose1d_sparsifier[{self.step_counter}]: {density=}")
+
+
+if __name__ == "__main__":
+    print("Testing sparsifier")
+
+    import torch
+    conv = torch.nn.ConvTranspose1d(8, 16, 4, 4)
+    params = (0.2, [8, 4])
+
+    sparsifier = ConvTranspose1dSparsifier([(conv, params)], 0, 100, 5)
+
+    for i in range(100):
+        sparsifier.step(verbose=True)
+
+    print(conv.weight)
diff --git a/dnn/torch/dnntools/dnntools/sparsification/utils.py b/dnn/torch/dnntools/dnntools/sparsification/utils.py
new file mode 100644
index 00000000..da9dc89e
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/utils.py
@@ -0,0 +1,43 @@
+import torch
+
+from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier
+
+def mark_for_sparsification(module, params):
+    setattr(module, 'sparsify', True)
+    setattr(module, 'sparsification_params', params)
+    return module
+
+def create_sparsifier(module, start, stop, interval):
+    sparsifier_list = []
+    for m in module.modules():
+        if hasattr(m, 'sparsify'):
+            if isinstance(m, torch.nn.GRU):
+                sparsifier_list.append(
+                    GRUSparsifier([(m, m.sparsification_params)], start, stop, interval)
+                )
+            elif isinstance(m, torch.nn.Linear):
+                sparsifier_list.append(
+                    LinearSparsifier([(m, m.sparsification_params)], start, stop, interval)
+                )
+            elif isinstance(m, torch.nn.Conv1d):
+                sparsifier_list.append(
+                    Conv1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
+                )
+            elif isinstance(m, torch.nn.ConvTranspose1d):
+                sparsifier_list.append(
+                    ConvTranspose1dSparsifier([(m, m.sparsification_params)], start, stop, interval)
+                )
+            else:
+                print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.")
+
+    def sparsify(verbose=False):
+        for sparsifier in sparsifier_list:
+            sparsifier.step(verbose)
+
+    return sparsify
+
+def estimate_parameters(module):
+    num_zero_parameters = 0
+    if hasattr(module, 'sparsify'):
+        if isinstance(module, torch.nn.Conv1d):
+            pass
diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py
index 7688e9b4..8e5731de 100644
--- a/dnn/torch/osce/engine/engine.py
+++ b/dnn/torch/osce/engine/engine.py
@@ -46,6 +46,10 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
             # update learning rate
             scheduler.step()
 
+            # sparsification
+            if hasattr(model, 'sparsify'):
+                model.sparsify(verbose=True)
+
             # update running loss
             running_loss += float(loss.cpu())
 
@@ -73,8 +77,6 @@ def evaluate(model, criterion, dataloader, device, log_interval=10):
 
         for i, batch in enumerate(tepoch):
 
-
-
             # push batch to device
             for key in batch:
                 batch[key] = batch[key].to(device)
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 5bc9fe41..78c1a717 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -41,6 +41,12 @@ from models.silk_feature_net_pl import SilkFeatureNetPL
 from models.silk_feature_net import SilkFeatureNet
 from .scale_embedding import ScaleEmbedding
 
+import sys
+sys.path.append('../dnntools')
+
+from dnntools.sparsification import create_sparsifier
+
+
 class LACE(NNSBase):
     """ Linear-Adaptive Coding Enhancer """
     FRAME_SIZE=80
@@ -61,7 +67,8 @@ class LACE(NNSBase):
                  hidden_feature_dim=64,
                  partial_lookahead=True,
                  norm_p=2,
-                 softquant=False):
+                 softquant=False,
+                 sparsify=False):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -99,6 +106,9 @@ class LACE(NNSBase):
         # spectral shaping
         self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant)
 
+        if sparsify:
+            self.sparsify = create_sparsifier(self, 500, 2000, 100)
+
 
     def flop_count(self, rate=16000, verbose=False):
         frame_rate = rate / self.FRAME_SIZE
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index 81beb36d..72f8531c 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -26,7 +26,8 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
 """
 
-
+import sys
+sys.path.append('../dnntools')
 
 import torch
 from torch import nn
@@ -35,13 +36,15 @@ import torch.nn.functional as F
 from utils.complexity import _conv1d_flop_count
 from utils.softquant import soft_quant
+from dnntools.sparsification import mark_for_sparsification
 
 class SilkFeatureNetPL(nn.Module):
     """ feature net with partial lookahead """
     def __init__(self,
                  feature_dim=47,
                  num_channels=256,
                  hidden_feature_dim=64,
-                 softquant=False):
+                 softquant=False,
+                 sparsify=True):
 
         super(SilkFeatureNetPL, self).__init__()
@@ -60,6 +63,21 @@ class SilkFeatureNetPL(nn.Module):
             self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
 
+        if sparsify:
+            mark_for_sparsification(self.conv2, (0.25, [8, 4]))
+            mark_for_sparsification(self.tconv, (0.25, [8, 4]))
+            mark_for_sparsification(
+                self.gru,
+                {
+                    'W_ir' : (0.25, [8, 4], False),
+                    'W_iz' : (0.25, [8, 4], False),
+                    'W_in' : (0.25, [8, 4], False),
+                    'W_hr' : (0.125, [8, 4], True),
+                    'W_hz' : (0.125, [8, 4], True),
+                    'W_hn' : (0.125, [8, 4], True),
+                }
+            )
+
     def flop_count(self, rate=200):
         count = 0
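
Usage note: the pieces above fit together as follows. Layers are tagged with `mark_for_sparsification`, `create_sparsifier` walks the module tree and returns a closure that steps one sparsifier per tagged layer, and the training engine calls that closure after every optimizer step. Below is a minimal sketch of the resulting training-loop pattern; the toy model, dummy loss, and input shapes are illustrative stand-ins, not part of this commit, while the schedule values (start=500, stop=2000, interval=100) are the ones LACE passes.

```python
import torch
from dnntools.sparsification import mark_for_sparsification, create_sparsifier

# Toy stand-in for a model like LACE; layer sizes are made up for illustration.
model = torch.nn.Sequential(
    torch.nn.Conv1d(8, 16, 3),
    torch.nn.ConvTranspose1d(16, 16, 4, 4),
)

# Tag layers for sparsification: 25% target density over 8x4 sub-blocks,
# the same params silk_feature_net_pl.py uses for conv2 and tconv.
for layer in model:
    mark_for_sparsification(layer, (0.25, [8, 4]))

# Prune from step 500 to step 2000, every 100 steps in between.
sparsify = create_sparsifier(model, start=500, stop=2000, interval=100)

optimizer = torch.optim.Adam(model.parameters())
for step in range(2500):
    x = torch.randn(1, 8, 80)
    loss = model(x).pow(2).mean()  # dummy loss, just to drive weight updates
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sparsify()  # as engine.py does via model.sparsify after optimizer.step
```

Because the density schedule interpolates from 1 down to the target with exponent 3, the pruned blocks are removed gradually, which gives the remaining weights time to adapt before the final density is reached.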