gitlab.xiph.org/xiph/opus.git
Diffstat (limited to 'dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py')
 dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py | 136 ++++++++++++++++++++
 1 file changed, 136 insertions(+)
diff --git a/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
new file mode 100644
index 00000000..59251ddd
--- /dev/null
+++ b/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
@@ -0,0 +1,136 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+
+from .base_sparsifier import BaseSparsifier
+from .common import sparsify_matrix
+
+
+class LinearSparsifier(BaseSparsifier):
+    def __init__(self, task_list, start, stop, interval, exponent=3):
+        """ Sparsifier for torch.nn.Linear layers
+
+        Parameters:
+        -----------
+        task_list : list
+            list of tuples (linear, params), where linear is an instance of torch.nn.Linear
+            and params is a tuple (density, [m, n]), where density is the target density
+            in [0, 1] and [m, n] is the shape of the sub-blocks to which sparsification
+            is applied
+
+        start : int
+            training step after which sparsification will be started.
+
+        stop : int
+            training step after which sparsification will be completed.
+
+        interval : int
+            sparsification interval for steps between start and stop. After stop, sparsification
+            will be carried out on every call to LinearSparsifier.step()
+
+        exponent : float
+            interpolation exponent for the sparsification schedule. In step i sparsification
+            will be carried out with density (alpha + (1 - alpha) * target_density), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
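+
+            For example, with start=0, stop=100 and exponent=3, step i=50 yields
+            alpha = ((100 - 50) / (100 - 0)) ** 3 = 0.125 and density = 0.125 + 0.875 * target_density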
+
+        Example:
+        --------
+        >>> import torch
+        >>> linear = torch.nn.Linear(8, 16)
+        >>> params = (0.2, [8, 4])
+        >>> sparsifier = LinearSparsifier([(linear, params)], 0, 100, 50)
+        >>> for i in range(100):
+        ...     sparsifier.step()
+        """
+
+        super().__init__(task_list, start, stop, interval, exponent=exponent)
+
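+        # per-layer pruning masks from the previous sparsification step, used to detect weight resurrection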
+        self.last_masks = {}
+
+    def sparsify(self, alpha, verbose=False):
+        """ carries out sparsification step
+
+        Call this function after optimizer.step in your
+        training loop.
+
+        Parameters:
+        -----------
+        alpha : float
+            density interpolation parameter (1: dense, 0: target density)
+        verbose : bool
+            if true, densities are printed out
+
+        Returns:
+        --------
+        None
+
+        """
+
+        with torch.no_grad():
+            for linear, params in self.task_list:
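+                # with weight normalization (e.g. torch.nn.utils.weight_norm), the raw weights live in weight_v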
+                if hasattr(linear, 'weight_v'):
+                    weight = linear.weight_v
+                else:
+                    weight = linear.weight
+                target_density, block_size = params
+                density = alpha + (1 - alpha) * target_density
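+                # prune in place; keep the block mask to detect weight resurrection below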
+                weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)
+
+                last_mask = self.last_masks.get(linear)
+                # a weight "resurrects" if an entry pruned in a previous step becomes nonzero again
+                if last_mask is not None:
+                    if not torch.all(last_mask * new_mask == new_mask) and verbose:
+                        print("weight resurrection in linear.weight")
+
+                self.last_masks[linear] = new_mask
+
+                if verbose:
+                    print(f"linear_sparsifier[{self.step_counter}]: {density=}")
+
+
+if __name__ == "__main__":
+ print("Testing sparsifier")
+
+ import torch
+ linear = torch.nn.Linear(8, 16)
+ params = (0.2, [4, 2])
+
+ sparsifier = LinearSparsifier([(linear, params)], 0, 100, 5)
+
+ for i in range(100):
+ sparsifier.step(verbose=True)
+
+ print(linear.weight)
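
A usage note beyond the commit itself: below is a minimal training-loop sketch showing where LinearSparsifier.step() fits. It assumes the dnntools package is importable and that BaseSparsifier.step() counts steps, applies the start/stop/interval schedule, and calls sparsify(); the model, optimizer, and loss are placeholders.

import torch
from dnntools.sparsification.linear_sparsifier import LinearSparsifier

model = torch.nn.Sequential(torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 8))
# target density 0.5 for the first Linear layer, pruned in 8x4 sub-blocks
# (the 128x64 weight matrix tiles evenly into such blocks)
tasks = [(model[0], (0.5, [8, 4]))]
sparsifier = LinearSparsifier(tasks, start=100, stop=1000, interval=10)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(2000):
    x = torch.randn(32, 64)        # placeholder batch
    loss = model(x).pow(2).mean()  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sparsifier.step()              # prune after the optimizer update, as the docstring advises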