
gitlab.xiph.org/xiph/opus.git
author    Jan Buethe <jbuethe@amazon.de>  2023-09-05 13:29:38 +0300
committer Jan Buethe <jbuethe@amazon.de>  2023-09-05 13:29:38 +0300
commit    35ee397e060283d30c098ae5e17836316bbec08b (patch)
tree      4a81b86f8c0738bbdc7147214c53fda54cd0f3f3 /dnn/torch/lpcnet/utils/sparsification
parent    90a171c1c2c9839b561f8446ad2bbfe48eacf255 (diff)
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Diffstat (limited to 'dnn/torch/lpcnet/utils/sparsification')
-rw-r--r--  dnn/torch/lpcnet/utils/sparsification/__init__.py        |   2
-rw-r--r--  dnn/torch/lpcnet/utils/sparsification/common.py          |  92
-rw-r--r--  dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py  | 158
3 files changed, 252 insertions(+), 0 deletions(-)
diff --git a/dnn/torch/lpcnet/utils/sparsification/__init__.py b/dnn/torch/lpcnet/utils/sparsification/__init__.py
new file mode 100644
index 00000000..ebfa9d9a
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/__init__.py
@@ -0,0 +1,2 @@
+from .gru_sparsifier import GRUSparsifier
+from .common import sparsify_matrix, calculate_gru_flops_per_step
\ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py
new file mode 100644
index 00000000..34989d4b
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/common.py
@@ -0,0 +1,92 @@
+import torch
+
+def sparsify_matrix(matrix : torch.Tensor, density : float, block_size : list[int], keep_diagonal : bool=False, return_mask : bool=False):
+    """ sparsifies matrix with specified block size
+
+    Parameters:
+    -----------
+    matrix : torch.Tensor
+        matrix to sparsify
+    density : float
+        target density between 0 and 1
+    block_size : [int, int]
+        block size dimensions
+    keep_diagonal : bool
+        If True, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
+    """
+
+    m, n = matrix.shape
+    m1, n1 = block_size
+
+    if m % m1 or n % n1:
+        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
+
+    # extract diagonal if keep_diagonal = True
+    if keep_diagonal:
+        if m != n:
+            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
+
+        to_spare = torch.diag(torch.diag(matrix))
+        matrix = matrix - to_spare
+    else:
+        to_spare = torch.zeros_like(matrix)
+
+    # calculate energy in sub-blocks
+    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
+    x = x ** 2
+    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
+
+    number_of_blocks = (m * n) // (m1 * n1)
+    number_of_survivors = round(number_of_blocks * density)
+
+    # masking threshold
+    if number_of_survivors == 0:
+        threshold = float("inf")  # no surviving blocks: mask out everything
+    else:
+        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
+
+    # create mask
+    mask = torch.ones_like(block_energies)
+    mask[block_energies < threshold] = 0
+    mask = torch.repeat_interleave(mask, m1, dim=0)
+    mask = torch.repeat_interleave(mask, n1, dim=1)
+
+    # perform masking
+    masked_matrix = mask * matrix + to_spare
+
+    if return_mask:
+        return masked_matrix, mask
+    else:
+        return masked_matrix
+
+def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
+    input_size = gru.input_size
+    hidden_size = gru.hidden_size
+    flops = 0
+
+    input_density = (
+        sparsification_dict.get('W_ir', [1])[0]
+        + sparsification_dict.get('W_in', [1])[0]
+        + sparsification_dict.get('W_iz', [1])[0]
+    ) / 3
+
+    recurrent_density = (
+        sparsification_dict.get('W_hr', [1])[0]
+        + sparsification_dict.get('W_hn', [1])[0]
+        + sparsification_dict.get('W_hz', [1])[0]
+    ) / 3
+
+    # input matrix vector multiplications
+    if not drop_input:
+        flops += 2 * 3 * input_size * hidden_size * input_density
+
+    # recurrent matrix vector multiplications
+    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
+
+    # biases
+    flops += 6 * hidden_size
+
+    # activations estimated by 10 flops per activation
+    flops += 30 * hidden_size
+
+    return flops
\ No newline at end of file
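For orientation, here is a minimal usage sketch for the two helpers above. It is not part of the commit; it assumes dnn/torch/lpcnet is on PYTHONPATH so the package imports as utils.sparsification, and the shapes and densities are arbitrary examples.

import torch

from utils.sparsification import sparsify_matrix, calculate_gru_flops_per_step

# sparsify an 8x8 matrix in 2x2 blocks down to 50% density:
# of the 16 blocks, the 8 with the highest energy survive
weights = torch.randn(8, 8)
sparse_weights, mask = sparsify_matrix(weights, density=0.5, block_size=[2, 2], return_mask=True)
assert mask.sum() == 0.5 * weights.numel()

# dense GRU FLOPs estimate per step:
# 2*3*10*20 (input) + 2*3*20*20 (recurrent) + 6*20 (biases) + 30*20 (activations) = 4320
print(calculate_gru_flops_per_step(torch.nn.GRU(10, 20)))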
diff --git a/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
new file mode 100644
index 00000000..865f3a7d
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
@@ -0,0 +1,158 @@
+import torch
+
+from .common import sparsify_matrix
+
+
+class GRUSparsifier:
+    def __init__(self, task_list, start, stop, interval, exponent=3):
+        """ Sparsifier for torch.nn.GRUs
+
+        Parameters:
+        -----------
+        task_list : list
+            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
+            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
+            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
+            update, and new gates. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
+            where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks to which
+            sparsification is applied and keep_diagonal is a bool indicating whether the diagonal
+            should be kept.
+
+        start : int
+            training step after which sparsification will be started.
+
+        stop : int
+            training step after which sparsification will be completed.
+
+        interval : int
+            sparsification interval for steps between start and stop. After stop, sparsification will be
+            carried out on every call to GRUSparsifier.step()
+
+        exponent : float
+            Interpolation exponent for the sparsification schedule. In step i, sparsification will be carried
+            out with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+        Example:
+        --------
+        >>> import torch
+        >>> gru = torch.nn.GRU(10, 20)
+        >>> sparsify_dict = {
+        ...     'W_ir' : (0.5, [2, 2], False),
+        ...     'W_iz' : (0.6, [2, 2], False),
+        ...     'W_in' : (0.7, [2, 2], False),
+        ...     'W_hr' : (0.1, [4, 4], True),
+        ...     'W_hz' : (0.2, [4, 4], True),
+        ...     'W_hn' : (0.3, [4, 4], True),
+        ... }
+        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
+        >>> for i in range(100):
+        ...     sparsifier.step()
+        """
+        # just copying parameters...
+        self.start = start
+        self.stop = stop
+        self.interval = interval
+        self.exponent = exponent
+        self.task_list = task_list
+
+        # ... and setting counter to 0
+        self.step_counter = 0
+
+        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
+
+    def step(self, verbose=False):
+        """ carries out a sparsification step
+
+        Call this function after optimizer.step in your
+        training loop.
+
+        Parameters:
+        -----------
+        verbose : bool
+            if True, densities are printed out
+
+        Returns:
+        --------
+        None
+
+        """
+        # compute current interpolation factor
+        self.step_counter += 1
+
+        if self.step_counter < self.start:
+            return
+        elif self.step_counter < self.stop:
+            # sparsify only on every self.interval-th step
+            if self.step_counter % self.interval:
+                return
+
+            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+        else:
+            alpha = 0
+
+
+        with torch.no_grad():
+            for gru, params in self.task_list:
+                hidden_size = gru.hidden_size
+
+                # input weights
+                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
+                    if key in params:
+                        density = alpha + (1 - alpha) * params[key][0]
+                        if verbose:
+                            print(f"[{self.step_counter}]: {key} density: {density}")
+
+                        gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
+                            gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, :],
+                            density,         # density
+                            params[key][1],  # block_size
+                            params[key][2],  # keep_diagonal (might want to set this to False)
+                            return_mask=True
+                        )
+
+                        if self.last_masks[key] is not None:
+                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+                                print(f"sparsification mask {key} changed for gru {gru}")
+
+                        self.last_masks[key] = new_mask
+
+                # recurrent weights
+                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
+                    if key in params:
+                        density = alpha + (1 - alpha) * params[key][0]
+                        if verbose:
+                            print(f"[{self.step_counter}]: {key} density: {density}")
+                        gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
+                            gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, :],
+                            density,
+                            params[key][1],  # block_size
+                            params[key][2],  # keep_diagonal (might want to set this to False)
+                            return_mask=True
+                        )
+
+                        if self.last_masks[key] is not None:
+                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+                                print(f"sparsification mask {key} changed for gru {gru}")
+
+                        self.last_masks[key] = new_mask
+
+
+
+if __name__ == "__main__":
+    print("Testing sparsifier")
+
+    gru = torch.nn.GRU(10, 20)
+    sparsify_dict = {
+        'W_ir' : (0.5, [2, 2], False),
+        'W_iz' : (0.6, [2, 2], False),
+        'W_in' : (0.7, [2, 2], False),
+        'W_hr' : (0.1, [4, 4], True),
+        'W_hz' : (0.2, [4, 4], True),
+        'W_hn' : (0.3, [4, 4], True),
+    }
+
+    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
+
+    for i in range(100):
+        sparsifier.step(verbose=True)
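To close, a hedged sketch of how GRUSparsifier slots into a training loop, following the step() docstring (call it after optimizer.step). The model, loss, and data below are placeholders, not part of the commit, and the import again assumes dnn/torch/lpcnet is on PYTHONPATH.

import torch

from utils.sparsification import GRUSparsifier

gru = torch.nn.GRU(10, 20, batch_first=True)
sparsifier = GRUSparsifier([(gru, {'W_hr' : (0.1, [4, 4], True)})], start=0, stop=100, interval=10)
optimizer = torch.optim.Adam(gru.parameters(), lr=1e-3)

for step in range(200):
    x = torch.randn(16, 50, 10)  # placeholder input batch
    y, _ = gru(x)
    loss = y.pow(2).mean()       # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sparsifier.step()            # re-applies block masks; W_hr density anneals from 1 to 0.1 by step 100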