diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-09 17:22:48 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:44:26 +0300 |
commit | 88f3df329e811cb394933c640edfa0d1e47b7d87 (patch) | |
tree | e9623299003c6c3429f0474d025daa59217c14e1 /dnn/torch | |
parent | 72ea20de26d8276ca21cf0d269a01a7c138d1f3f (diff) |
sparsification now compatible with weight_norm
Diffstat (limited to 'dnn/torch')
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py | 10 | ||||
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py | 10 | ||||
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py | 16 | ||||
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py | 6 | ||||
-rw-r--r-- | dnn/torch/osce/export_model_weights.py | 24 | ||||
-rw-r--r-- | dnn/torch/osce/models/lace.py | 13 | ||||
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 38 | ||||
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 14 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py | 9 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py | 7 | ||||
-rw-r--r-- | dnn/torch/osce/utils/layers/td_shaper.py | 19 | ||||
-rw-r--r-- | dnn/torch/osce/utils/misc.py | 22 | ||||
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/common.py | 10 | ||||
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/torch/torch.py | 8 |
14 files changed, 134 insertions, 72 deletions
diff --git a/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py index d88f8a26..40d9993b 100644 --- a/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py +++ b/dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py @@ -114,13 +114,17 @@ class Conv1dSparsifier: with torch.no_grad(): for conv, params in self.task_list: # reshape weight - i, o, k = conv.weight.shape - w = conv.weight.permute(0, 2, 1).flatten(1) + if hasattr(conv, 'weight_v'): + weight = conv.weight_v + else: + weight = conv.weight + i, o, k = weight.shape + w = weight.permute(0, 2, 1).flatten(1) target_density, block_size = params density = alpha + (1 - alpha) * target_density w = sparsify_matrix(w, density, block_size) w = w.reshape(i, k, o).permute(0, 2, 1) - conv.weight[:] = w + weight[:] = w if verbose: print(f"conv1d_sparsier[{self.step_counter}]: {density=}") diff --git a/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py index 73f2fc3d..1333845a 100644 --- a/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py +++ b/dnn/torch/dnntools/dnntools/sparsification/conv_transpose1d_sparsifier.py @@ -112,13 +112,17 @@ class ConvTranspose1dSparsifier: with torch.no_grad(): for conv, params in self.task_list: # reshape weight - i, o, k = conv.weight.shape - w = conv.weight.permute(2, 1, 0).reshape(k * o, i) + if hasattr(conv, 'weight_v'): + weight = conv.weight_v + else: + weight = conv.weight + i, o, k = weight.shape + w = weight.permute(2, 1, 0).reshape(k * o, i) target_density, block_size = params density = alpha + (1 - alpha) * target_density w = sparsify_matrix(w, density, block_size) w = w.reshape(k, o, i).permute(2, 1, 0) - conv.weight[:] = w + weight[:] = w if verbose: print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}") diff --git a/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py index 4dfdaf0a..4ccff517 100644 --- a/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py +++ b/dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py @@ -128,12 +128,16 @@ class GRUSparsifier: # input weights for i, key in enumerate(['W_ir', 'W_iz', 'W_in']): if key in params: + if hasattr(gru, 'weight_ih_l0_v'): + weight = gru.weight_ih_l0_v + else: + weight = gru.weight_ih_l0 density = alpha + (1 - alpha) * params[key][0] if verbose: print(f"[{self.step_counter}]: {key} density: {density}") - gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( - gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ], + weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( + weight[i * hidden_size : (i + 1) * hidden_size, : ], density, # density params[key][1], # block_size params[key][2], # keep_diagonal (might want to set this to False) @@ -149,11 +153,15 @@ class GRUSparsifier: # recurrent weights for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']): if key in params: + if hasattr(gru, 'weight_hh_l0_v'): + weight = gru.weight_hh_l0_v + else: + weight = gru.weight_hh_l0 density = alpha + (1 - alpha) * params[key][0] if verbose: print(f"[{self.step_counter}]: {key} density: {density}") - gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( - gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ], + weight[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( + weight[i * hidden_size : (i + 1) * hidden_size, : ], density, params[key][1], # block_size params[key][2], # keep_diagonal (might want to set this to False) diff --git a/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py b/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py index f52a5906..dee7025c 100644 --- a/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py +++ b/dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py @@ -113,9 +113,13 @@ class LinearSparsifier: with torch.no_grad(): for linear, params in self.task_list: + if hasattr(linear, 'weight_v'): + weight = linear.weight_v + else: + weight = linear.weight target_density, block_size = params density = alpha + (1 - alpha) * target_density - linear.weight[:] = sparsify_matrix(linear.weight, density, block_size) + weight[:] = sparsify_matrix(weight, density, block_size) if verbose: print(f"linear_sparsifier[{self.step_counter}]: {density=}") diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py index 39f76ef9..0bec9604 100644 --- a/dnn/torch/osce/export_model_weights.py +++ b/dnn/torch/osce/export_model_weights.py @@ -43,6 +43,7 @@ from models import model_dict from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d from utils.layers.td_shaper import TDShaper +from utils.misc import remove_all_weight_norm from wexchange.torch import dump_torch_weights @@ -58,9 +59,9 @@ schedules = { 'nolace': [ ('pitch_embedding', dict()), ('feature_net.conv1', dict()), - ('feature_net.conv2', dict(quantize=True, scale=None)), - ('feature_net.tconv', dict(quantize=True, scale=None)), - ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None)), + ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)), ('cf1', dict(quantize=True, scale=None)), ('cf2', dict(quantize=True, scale=None)), ('af1', dict(quantize=True, scale=None)), @@ -70,18 +71,18 @@ schedules = { ('af2', dict(quantize=True, scale=None)), ('af3', dict(quantize=True, scale=None)), ('af4', dict(quantize=True, scale=None)), - ('post_cf1', dict(quantize=True, scale=None)), - ('post_cf2', dict(quantize=True, scale=None)), - ('post_af1', dict(quantize=True, scale=None)), - ('post_af2', dict(quantize=True, scale=None)), - ('post_af3', dict(quantize=True, scale=None)) + ('post_cf1', dict(quantize=True, scale=None, sparse=True)), + ('post_cf2', dict(quantize=True, scale=None, sparse=True)), + ('post_af1', dict(quantize=True, scale=None, sparse=True)), + ('post_af2', dict(quantize=True, scale=None, sparse=True)), + ('post_af3', dict(quantize=True, scale=None, sparse=True)) ], 'lace' : [ ('pitch_embedding', dict()), ('feature_net.conv1', dict()), - ('feature_net.conv2', dict(quantize=True, scale=None)), - ('feature_net.tconv', dict(quantize=True, scale=None)), - ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None)), + ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)), + ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)), ('cf1', dict(quantize=True, scale=None)), ('cf2', dict(quantize=True, scale=None)), ('af1', dict(quantize=True, scale=None)) @@ -140,6 +141,7 @@ if __name__ == "__main__": checkpoint = torch.load(checkpoint_path, map_location='cpu') model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs']) model.load_state_dict(checkpoint['state_dict']) + remove_all_weight_norm(model, verbose=True) # CWriter model_name = checkpoint['setup']['model']['name'] diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py index 7e8e739c..51d65c3e 100644 --- a/dnn/torch/osce/models/lace.py +++ b/dnn/torch/osce/models/lace.py @@ -70,7 +70,8 @@ class LACE(NNSBase): softquant=False, sparsify=False, sparsification_schedule=[10000, 30000, 100], - sparsification_density=0.5): + sparsification_density=0.5, + apply_weight_norm=False): super().__init__(skip=skip, preemph=preemph) @@ -95,21 +96,21 @@ class LACE(NNSBase): # feature net if partial_lookahead: - self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density) + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm) else: self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) # comb filters left_pad = self.kernel_size // 2 right_pad = self.kernel_size - 1 - left_pad - self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) - self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) # spectral shaping - self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) if sparsify: - self.sparsify = create_sparsifier(self, *sparsification_schedule) + self.sparsifier = create_sparsifier(self, *sparsification_schedule) def flop_count(self, rate=16000, verbose=False): diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py index 7e021930..19bc6caf 100644 --- a/dnn/torch/osce/models/no_lace.py +++ b/dnn/torch/osce/models/no_lace.py @@ -30,6 +30,8 @@ import torch from torch import nn import torch.nn.functional as F +from torch.nn.utils import weight_norm + import numpy as np @@ -74,11 +76,11 @@ class NoLACE(NNSBase): softquant=False, sparsify=False, sparsification_schedule=[100, 1000, 100], - sparsification_density=0.5): + sparsification_density=0.5, + apply_weight_norm=False): super().__init__(skip=skip, preemph=preemph) - self.num_features = num_features self.cond_dim = cond_dim self.pitch_max = pitch_max @@ -91,6 +93,8 @@ class NoLACE(NNSBase): self.hidden_feature_dim = hidden_feature_dim self.partial_lookahead = partial_lookahead + norm = weight_norm if apply_weight_norm else lambda x, name=None: x + # pitch embedding self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim) @@ -99,35 +103,35 @@ class NoLACE(NNSBase): # feature net if partial_lookahead: - self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density) + self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density, apply_weight_norm=apply_weight_norm) else: self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim) # comb filters left_pad = self.kernel_size // 2 right_pad = self.kernel_size - 1 - left_pad - self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) - self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) # spectral shaping - self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) # non-linear transforms - self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant) - self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant) - self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant) + self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after, softquant=softquant, apply_weight_norm=apply_weight_norm) # combinators - self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) - self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) - self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant) + self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) + self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant, apply_weight_norm=apply_weight_norm) # feature transforms - self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2) - self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2) - self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2) - self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2) - self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2) + self.post_cf1 = norm(nn.Conv1d(cond_dim, cond_dim, 2)) + self.post_cf2 = norm(nn.Conv1d(cond_dim, cond_dim, 2)) + self.post_af1 = norm(nn.Conv1d(cond_dim, cond_dim, 2)) + self.post_af2 = norm(nn.Conv1d(cond_dim, cond_dim, 2)) + self.post_af3 = norm(nn.Conv1d(cond_dim, cond_dim, 2)) if softquant: self.post_cf1 = soft_quant(self.post_cf1) diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py index 799064d2..75421449 100644 --- a/dnn/torch/osce/models/silk_feature_net_pl.py +++ b/dnn/torch/osce/models/silk_feature_net_pl.py @@ -32,6 +32,7 @@ sys.path.append('../dnntools') import torch from torch import nn import torch.nn.functional as F +from torch.nn.utils import weight_norm from utils.complexity import _conv1d_flop_count from utils.softquant import soft_quant @@ -45,7 +46,8 @@ class SilkFeatureNetPL(nn.Module): hidden_feature_dim=64, softquant=False, sparsify=True, - sparsification_density=0.5): + sparsification_density=0.5, + apply_weight_norm=False): super(SilkFeatureNetPL, self).__init__() @@ -53,10 +55,12 @@ class SilkFeatureNetPL(nn.Module): self.num_channels = num_channels self.hidden_feature_dim = hidden_feature_dim - self.conv1 = nn.Conv1d(feature_dim, self.hidden_feature_dim, 1) - self.conv2 = nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2) - self.tconv = nn.ConvTranspose1d(num_channels, num_channels, 4, 4) - self.gru = nn.GRU(num_channels, num_channels, batch_first=True) + norm = weight_norm if apply_weight_norm else lambda x, name=None: x + + self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1)) + self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2)) + self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4)) + self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0') if softquant: self.conv2 = soft_quant(self.conv2) diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py index a116fd72..0d87ca19 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py @@ -49,6 +49,7 @@ class LimitedAdaptiveComb1d(nn.Module): global_gain_limits_db=[-6, 6], norm_p=2, softquant=False, + apply_weight_norm=False, **kwargs): """ @@ -99,20 +100,22 @@ class LimitedAdaptiveComb1d(nn.Module): else: self.name = name + norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x + # network for generating convolution weights - self.conv_kernel = nn.Linear(feature_dim, kernel_size) + self.conv_kernel = norm(nn.Linear(feature_dim, kernel_size)) if softquant: self.conv_kernel = soft_quant(self.conv_kernel) # comb filter gain - self.filter_gain = nn.Linear(feature_dim, 1) + self.filter_gain = norm(nn.Linear(feature_dim, 1)) self.log_gain_limit = gain_limit_db * 0.11512925464970229 with torch.no_grad(): self.filter_gain.bias[:] = max(0.1, 4 + self.log_gain_limit) - self.global_filter_gain = nn.Linear(feature_dim, 1) + self.global_filter_gain = norm(nn.Linear(feature_dim, 1)) log_min, log_max = global_gain_limits_db[0] * 0.11512925464970229, global_gain_limits_db[1] * 0.11512925464970229 self.filter_gain_a = (log_max - log_min) / 2 self.filter_gain_b = (log_max + log_min) / 2 diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py index af6ec0ec..55df8c14 100644 --- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py +++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py @@ -52,6 +52,7 @@ class LimitedAdaptiveConv1d(nn.Module): shape_gain_db=0, norm_p=2, softquant=False, + apply_weight_norm=False, **kwargs): """ @@ -101,14 +102,16 @@ class LimitedAdaptiveConv1d(nn.Module): else: self.name = name + norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x + # network for generating convolution weights - self.conv_kernel = nn.Linear(feature_dim, in_channels * out_channels * kernel_size) + self.conv_kernel = norm(nn.Linear(feature_dim, in_channels * out_channels * kernel_size)) if softquant: self.conv_kernel = soft_quant(self.conv_kernel) self.shape_gain = min(1, 10**(shape_gain_db / 20)) - self.filter_gain = nn.Linear(feature_dim, out_channels) + self.filter_gain = norm(nn.Linear(feature_dim, out_channels)) log_min, log_max = gain_limits_db[0] * 0.11512925464970229, gain_limits_db[1] * 0.11512925464970229 self.filter_gain_a = (log_max - log_min) / 2 self.filter_gain_b = (log_max + log_min) / 2 diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py index 0d2052ee..fa7bf348 100644 --- a/dnn/torch/osce/utils/layers/td_shaper.py +++ b/dnn/torch/osce/utils/layers/td_shaper.py @@ -14,7 +14,8 @@ class TDShaper(nn.Module): avg_pool_k=4, innovate=False, pool_after=False, - softquant=False + softquant=False, + apply_weight_norm=False ): """ @@ -47,20 +48,22 @@ class TDShaper(nn.Module): assert frame_size % avg_pool_k == 0 self.env_dim = frame_size // avg_pool_k + 1 + norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x + # feature transform - self.feature_alpha1_f = nn.Conv1d(self.feature_dim, frame_size, 2) - self.feature_alpha1_t = nn.Conv1d(self.env_dim, frame_size, 2) - self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2) + self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2)) + self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2)) + self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2)) if softquant: self.feature_alpha1_f = soft_quant(self.feature_alpha1_f) if self.innovate: - self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) - self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2) + self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)) + self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)) - self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2) - self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2) + self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2)) + self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2)) def flop_count(self, rate): diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py index c4355b4e..68ee4bfd 100644 --- a/dnn/torch/osce/utils/misc.py +++ b/dnn/torch/osce/utils/misc.py @@ -28,6 +28,7 @@ """ import torch +from torch.nn.utils import remove_weight_norm def count_parameters(model, verbose=False): total = 0 @@ -72,4 +73,23 @@ def create_weights(s_real, s_gen, alpha): weight = torch.exp(alpha * (sr[-1] - sg[-1])) weights.append(weight) - return weights
\ No newline at end of file + return weights + + +def _get_candidates(module: torch.nn.Module): + candidates = [] + for key in module.__dict__.keys(): + if hasattr(module, key + '_v'): + candidates.append(key) + return candidates + +def remove_all_weight_norm(model : torch.nn.Module, verbose=False): + for name, m in model.named_modules(): + candidates = _get_candidates(m) + + for candidate in candidates: + try: + remove_weight_norm(m, name=candidate) + if verbose: print(f'removed weight norm on weight {name}.{candidate}') + except: + pass diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index 524f1cc3..039edd9b 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -282,7 +282,8 @@ def print_conv1d_layer(writer : CWriter, bias : np.ndarray, scale=1/128, format : str = 'torch', - quantize=False): + quantize=False, + sparse=False): if format == "torch": @@ -290,7 +291,7 @@ def print_conv1d_layer(writer : CWriter, weight = np.transpose(weight, (2, 1, 0)) lin_weight = np.reshape(weight, (-1, weight.shape[-1])) - print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize) + print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=sparse, diagonal=False, quantize=quantize) writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n") @@ -369,7 +370,8 @@ def print_tconv1d_layer(writer : CWriter, bias : np.ndarray, stride: int, scale=1/128, - quantize=False): + quantize=False, + sparse=False): in_channels, out_channels, kernel_size = weight.shape @@ -377,7 +379,7 @@ def print_tconv1d_layer(writer : CWriter, linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0) linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten() - print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize) + print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize, sparse=sparse) writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n") writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n") diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py index 7da1a4d8..af5d3e59 100644 --- a/dnn/torch/weight-exchange/wexchange/torch/torch.py +++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py @@ -275,7 +275,7 @@ def load_torch_dense_weights(where, dense): dense.bias.set_(torch.from_numpy(b)) -def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): +def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False): w = conv.weight.detach().cpu().numpy().copy() if conv.bias is None: @@ -285,7 +285,7 @@ def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=Fa if isinstance(where, CWriter): - return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize) + return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize, sparse=sparse) else: os.makedirs(where, exist_ok=True) @@ -305,7 +305,7 @@ def load_torch_conv1d_weights(where, conv): conv.bias.set_(torch.from_numpy(b)) -def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): +def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False): w = conv.weight.detach().cpu().numpy().copy() if conv.bias is None: @@ -315,7 +315,7 @@ def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=F if isinstance(where, CWriter): - return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize) + return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize, sparse=sparse) else: os.makedirs(where, exist_ok=True) |