author | Jan Buethe <jbuethe@amazon.de> | 2023-09-05 13:29:38 +0300
committer | Jan Buethe <jbuethe@amazon.de> | 2023-09-05 13:29:38 +0300
commit | 35ee397e060283d30c098ae5e17836316bbec08b (patch)
tree | 4a81b86f8c0738bbdc7147214c53fda54cd0f3f3 /dnn/torch/lpcnet/utils/sparsification
parent | 90a171c1c2c9839b561f8446ad2bbfe48eacf255 (diff)
added LPCNet torch implementation
Signed-off-by: Jan Buethe <jbuethe@amazon.de>
Diffstat (limited to 'dnn/torch/lpcnet/utils/sparsification')
-rw-r--r-- | dnn/torch/lpcnet/utils/sparsification/__init__.py | 2
-rw-r--r-- | dnn/torch/lpcnet/utils/sparsification/common.py | 92
-rw-r--r-- | dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py | 158
3 files changed, 252 insertions, 0 deletions
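For orientation, a minimal usage sketch of the new utilities (not part of the commit): it assumes the dnn/torch/lpcnet directory is on the Python path so the package imports as utils.sparsification, and it uses a placeholder input and loss. The sparsification dictionary is the one from the GRUSparsifier docstring example below.

    import torch
    # import path is an assumption; it requires dnn/torch/lpcnet on sys.path
    from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step

    gru = torch.nn.GRU(10, 20)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    # ramp in sparsity between steps 0 and 100, re-applying block masks every 10 steps
    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
    optimizer  = torch.optim.Adam(gru.parameters(), lr=1e-3)

    for step in range(200):
        optimizer.zero_grad()
        out, _ = gru(torch.randn(50, 1, 10))  # placeholder batch
        loss = out.pow(2).mean()              # placeholder loss
        loss.backward()
        optimizer.step()
        sparsifier.step()                     # as documented: call after optimizer.step

    # rough per-step complexity of the sparsified GRU
    print(calculate_gru_flops_per_step(gru, sparsify_dict))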
diff --git a/dnn/torch/lpcnet/utils/sparsification/__init__.py b/dnn/torch/lpcnet/utils/sparsification/__init__.py
new file mode 100644
index 00000000..ebfa9d9a
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/__init__.py
@@ -0,0 +1,2 @@
+from .gru_sparsifier import GRUSparsifier
+from .common import sparsify_matrix, calculate_gru_flops_per_step
\ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/common.py b/dnn/torch/lpcnet/utils/sparsification/common.py
new file mode 100644
index 00000000..34989d4b
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/common.py
@@ -0,0 +1,92 @@
+import torch
+
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
+    """ sparsifies matrix with specified block size
+
+    Parameters:
+    -----------
+    matrix : torch.tensor
+        matrix to sparsify
+    density : float
+        target density
+    block_size : [int, int]
+        block size dimensions
+    keep_diagonal : bool
+        If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
+    """
+
+    m, n = matrix.shape
+    m1, n1 = block_size
+
+    if m % m1 or n % n1:
+        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
+
+    # extract diagonal if keep_diagonal = True
+    if keep_diagonal:
+        if m != n:
+            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
+
+        to_spare = torch.diag(torch.diag(matrix))
+        matrix = matrix - to_spare
+    else:
+        to_spare = torch.zeros_like(matrix)
+
+    # calculate energy in sub-blocks
+    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
+    x = x ** 2
+    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
+
+    number_of_blocks = (m * n) // (m1 * n1)
+    number_of_survivors = round(number_of_blocks * density)
+
+    # masking threshold
+    if number_of_survivors == 0:
+        threshold = 0
+    else:
+        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
+
+    # create mask
+    mask = torch.ones_like(block_energies)
+    mask[block_energies < threshold] = 0
+    mask = torch.repeat_interleave(mask, m1, dim=0)
+    mask = torch.repeat_interleave(mask, n1, dim=1)
+
+    # perform masking
+    masked_matrix = mask * matrix + to_spare
+
+    if return_mask:
+        return masked_matrix, mask
+    else:
+        return masked_matrix
+
+def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
+    input_size = gru.input_size
+    hidden_size = gru.hidden_size
+    flops = 0
+
+    input_density = (
+        sparsification_dict.get('W_ir', [1])[0]
+        + sparsification_dict.get('W_in', [1])[0]
+        + sparsification_dict.get('W_iz', [1])[0]
+    ) / 3
+
+    recurrent_density = (
+        sparsification_dict.get('W_hr', [1])[0]
+        + sparsification_dict.get('W_hn', [1])[0]
+        + sparsification_dict.get('W_hz', [1])[0]
+    ) / 3
+
+    # input matrix vector multiplications
+    if not drop_input:
+        flops += 2 * 3 * input_size * hidden_size * input_density
+
+    # recurrent matrix vector multiplications
+    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
+
+    # biases
+    flops += 6 * hidden_size
+
+    # activations estimated by 10 flops per activation
+    flops += 30 * hidden_size
+
+    return flops
\ No newline at end of file
diff --git a/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
new file mode 100644
index 00000000..865f3a7d
--- /dev/null
+++ b/dnn/torch/lpcnet/utils/sparsification/gru_sparsifier.py
@@ -0,0 +1,158 @@
+import torch
+
+from .common import sparsify_matrix
+
+
+class GRUSparsifier:
+    def __init__(self, task_list, start, stop, interval, exponent=3):
+        """ Sparsifier for torch.nn.GRUs
+
+        Parameters:
+        -----------
+        task_list : list
+            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
+            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
+            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
+            update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
+            where density is the target density in [0, 1], [m, n] is the shape of sub-blocks to which
+            sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
+            should be kept.
+
+        start : int
+            training step after which sparsification will be started.
+
+        stop : int
+            training step after which sparsification will be completed.
+
+        interval : int
+            sparsification interval for steps between start and stop. After stop, sparsification will be
+            carried out on every call to GRUSparsifier.step()
+
+        exponent : float
+            Interpolation exponent for the sparsification schedule. In step i sparsification will be carried out
+            with density (alpha + target_density * (1 - alpha)), where
+            alpha = ((stop - i) / (stop - start)) ** exponent
+
+        Example:
+        --------
+        >>> import torch
+        >>> gru = torch.nn.GRU(10, 20)
+        >>> sparsify_dict = {
+        ...     'W_ir' : (0.5, [2, 2], False),
+        ...     'W_iz' : (0.6, [2, 2], False),
+        ...     'W_in' : (0.7, [2, 2], False),
+        ...     'W_hr' : (0.1, [4, 4], True),
+        ...     'W_hz' : (0.2, [4, 4], True),
+        ...     'W_hn' : (0.3, [4, 4], True),
+        ... }
+        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
+        >>> for i in range(100):
+        ...     sparsifier.step()
+        """
+        # just copying parameters...
+        self.start = start
+        self.stop = stop
+        self.interval = interval
+        self.exponent = exponent
+        self.task_list = task_list
+
+        # ... and setting counter to 0
+        self.step_counter = 0
+
+        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
+
+    def step(self, verbose=False):
+        """ carries out sparsification step
+
+        Call this function after optimizer.step in your
+        training loop.
+
+        Parameters:
+        -----------
+        verbose : bool
+            if true, densities are printed out
+
+        Returns:
+        --------
+        None
+
+        """
+        # compute current interpolation factor
+        self.step_counter += 1
+
+        if self.step_counter < self.start:
+            return
+        elif self.step_counter < self.stop:
+            # update only every self.interval-th step
+            if self.step_counter % self.interval:
+                return
+
+            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
+        else:
+            alpha = 0
+
+
+        with torch.no_grad():
+            for gru, params in self.task_list:
+                hidden_size = gru.hidden_size
+
+                # input weights
+                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
+                    if key in params:
+                        density = alpha + (1 - alpha) * params[key][0]
+                        if verbose:
+                            print(f"[{self.step_counter}]: {key} density: {density}")
+
+                        gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
+                            gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, :],
+                            density, # density
+                            params[key][1], # block_size
+                            params[key][2], # keep_diagonal (might want to set this to False)
+                            return_mask=True
+                        )
+
+                        if self.last_masks[key] is not None:
+                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+                                print(f"sparsification mask {key} changed for gru {gru}")
+
+                        self.last_masks[key] = new_mask
+
+                # recurrent weights
+                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
+                    if key in params:
+                        density = alpha + (1 - alpha) * params[key][0]
+                        if verbose:
+                            print(f"[{self.step_counter}]: {key} density: {density}")
+                        gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
+                            gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, :],
+                            density,
+                            params[key][1], # block_size
+                            params[key][2], # keep_diagonal (might want to set this to False)
+                            return_mask=True
+                        )
+
+                        if self.last_masks[key] is not None:
+                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
+                                print(f"sparsification mask {key} changed for gru {gru}")
+
+                        self.last_masks[key] = new_mask
+
+
+
+if __name__ == "__main__":
+    print("Testing sparsifier")
+
+    gru = torch.nn.GRU(10, 20)
+    sparsify_dict = {
+        'W_ir' : (0.5, [2, 2], False),
+        'W_iz' : (0.6, [2, 2], False),
+        'W_in' : (0.7, [2, 2], False),
+        'W_hr' : (0.1, [4, 4], True),
+        'W_hz' : (0.2, [4, 4], True),
+        'W_hn' : (0.3, [4, 4], True),
+    }
+
+    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
+
+    for i in range(100):
+        sparsifier.step(verbose=True)
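The two helpers in common.py can also be used on their own. A short, hedged illustration (the matrix size is made up for the example, and the import path again assumes dnn/torch/lpcnet is on the Python path):

    import torch
    # hypothetical import path; adjust to wherever the package lives on sys.path
    from utils.sparsification import sparsify_matrix, calculate_gru_flops_per_step

    w = torch.randn(16, 16)
    # keep roughly 25% of the 4x4 blocks by energy, preserving the diagonal
    w_sparse, mask = sparsify_matrix(w, 0.25, [4, 4], keep_diagonal=True, return_mask=True)
    print(mask.mean())  # fraction of entries kept by the block mask

    # dense baseline: densities default to 1 when no sparsification_dict is passed
    gru = torch.nn.GRU(10, 20)
    print(calculate_gru_flops_per_step(gru))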