diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-08 14:00:49 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-08 14:05:04 +0300 |
commit | 999ddbe09ca1e521cc8d71005c1ecd8513e47611 (patch) | |
tree | 85022704a9352cb03751863f9a1ae788afcc805b /dnn | |
parent | e968878f06fb96e022f09c0fab60b88ab3f3ac81 (diff) |
more sparsification stuff
Diffstat (limited to 'dnn')
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/utils.py | 27 | ||||
-rw-r--r-- | dnn/torch/osce/adv_train_model.py | 4 | ||||
-rw-r--r-- | dnn/torch/osce/engine/engine.py | 4 | ||||
-rw-r--r-- | dnn/torch/osce/models/lace.py | 8 | ||||
-rw-r--r-- | dnn/torch/osce/models/no_lace.py | 18 | ||||
-rw-r--r-- | dnn/torch/osce/models/silk_feature_net_pl.py | 19 | ||||
-rw-r--r-- | dnn/torch/osce/train_model.py | 4 | ||||
-rw-r--r-- | dnn/torch/osce/utils/misc.py | 10 | ||||
-rw-r--r-- | dnn/torch/osce/utils/templates.py | 15 |
9 files changed, 79 insertions, 30 deletions
diff --git a/dnn/torch/dnntools/dnntools/sparsification/utils.py b/dnn/torch/dnntools/dnntools/sparsification/utils.py
index da9dc89e..42f22353 100644
--- a/dnn/torch/dnntools/dnntools/sparsification/utils.py
+++ b/dnn/torch/dnntools/dnntools/sparsification/utils.py
@@ -36,8 +36,29 @@ def create_sparsifier(module, start, stop, interval):
     return sparsify
 
 
-def estimate_parameters(module):
+
+def count_parameters(model, verbose=False):
+    total = 0
+    for name, p in model.named_parameters():
+        count = torch.ones_like(p).sum().item()
+
+        if verbose:
+            print(f"{name}: {count} parameters")
+
+        total += count
+
+    return total
+
+def estimate_nonzero_parameters(module):
     num_zero_parameters = 0
     if hasattr(module, 'sparsify'):
-        if isinstance(module, torch.nn.Conv1d):
-            pass
+        params = module.sparsification_params
+        if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d):
+            num_zero_parameters = torch.ones_like(module.weight).sum().item() * (1 - params[0])
+        elif isinstance(module, torch.nn.GRU):
+            num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
+            num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
+        elif isinstance(module, torch.nn.Linear):
+            num_zero_parameters = module.in_features * module.out_features * params[0]
+        else:
+            raise ValueError(f'unknown sparsification method for module of type {type(module)}')
diff --git a/dnn/torch/osce/adv_train_model.py b/dnn/torch/osce/adv_train_model.py
index 9cd32000..dcfb65f1 100644
--- a/dnn/torch/osce/adv_train_model.py
+++ b/dnn/torch/osce/adv_train_model.py
@@ -408,6 +408,10 @@ for ep in range(1, epochs + 1):
 
             optimizer.step()
 
+            # sparsification
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
+
             running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
             running_adv_loss += gen_loss.detach().cpu().item()
             running_disc_loss +=
disc_loss.detach().cpu().item()
diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py
index 2ccc0277..0762c898 100644
--- a/dnn/torch/osce/engine/engine.py
+++ b/dnn/torch/osce/engine/engine.py
@@ -47,8 +47,8 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
             scheduler.step()
 
             # sparsification
-            if hasattr(model, 'sparsify'):
-                model.sparsify(True)
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
 
             # update running loss
             running_loss += float(loss.cpu())
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 78c1a717..7e8e739c 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -68,7 +68,9 @@ class LACE(NNSBase):
                  partial_lookahead=True,
                  norm_p=2,
                  softquant=False,
-                 sparsify=False):
+                 sparsify=False,
+                 sparsification_schedule=[10000, 30000, 100],
+                 sparsification_density=0.5):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -93,7 +95,7 @@ class LACE(NNSBase):
 
         # feature net
         if partial_lookahead:
-            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
         else:
             self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
 
@@ -107,7 +109,7 @@ class LACE(NNSBase):
         self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant)
 
         if sparsify:
-            self.sparsify = create_sparsifier(self, 500, 2000, 100)
+            self.sparsify = create_sparsifier(self, *sparsification_schedule)
 
     def flop_count(self, rate=16000, verbose=False):
 
diff --git a/dnn/torch/osce/models/no_lace.py
b/dnn/torch/osce/models/no_lace.py
index 5654db6c..7e021930 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -72,7 +72,9 @@ class NoLACE(NNSBase):
                  avg_pool_k=4,
                  pool_after=False,
                  softquant=False,
-                 sparsify=False):
+                 sparsify=False,
+                 sparsification_schedule=[100, 1000, 100],
+                 sparsification_density=0.5):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -97,7 +99,7 @@ class NoLACE(NNSBase):
 
         # feature net
         if partial_lookahead:
-            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
         else:
             self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
 
@@ -136,13 +138,13 @@ class NoLACE(NNSBase):
 
 
         if sparsify:
-            mark_for_sparsification(self.post_cf1, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_cf2, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af1, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af2, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af3, (0.25, [8, 4]))
+            mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4]))
 
-            self.sparsify = create_sparsifier(self, 500, 1000, 100)
+            self.sparsifier = create_sparsifier(self, *sparsification_schedule)
 
     def flop_count(self, rate=16000, verbose=False):
 
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index 72f8531c..799064d2 100644
---
a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -44,7 +44,8 @@ class SilkFeatureNetPL(nn.Module):
                  num_channels=256,
                  hidden_feature_dim=64,
                  softquant=False,
-                 sparsify=True):
+                 sparsify=True,
+                 sparsification_density=0.5):
 
         super(SilkFeatureNetPL, self).__init__()
 
@@ -64,17 +65,17 @@ class SilkFeatureNetPL(nn.Module):
 
 
         if sparsify:
-            mark_for_sparsification(self.conv2, (0.25, [8, 4]))
-            mark_for_sparsification(self.tconv, (0.25, [8, 4]))
+            mark_for_sparsification(self.conv2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.tconv, (sparsification_density, [8, 4]))
             mark_for_sparsification(
                 self.gru,
                 {
-                    'W_ir' : (0.25, [8, 4], False),
-                    'W_iz' : (0.25, [8, 4], False),
-                    'W_in' : (0.25, [8, 4], False),
-                    'W_hr' : (0.125, [8, 4], True),
-                    'W_hz' : (0.125, [8, 4], True),
-                    'W_hn' : (0.125, [8, 4], True),
+                    'W_ir' : (sparsification_density, [8, 4], False),
+                    'W_iz' : (sparsification_density, [8, 4], False),
+                    'W_in' : (sparsification_density, [8, 4], False),
+                    'W_hr' : (sparsification_density, [8, 4], True),
+                    'W_hz' : (sparsification_density, [8, 4], True),
+                    'W_hn' : (sparsification_density, [8, 4], True),
                 }
             )
 
diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py
index 6e2514b9..34cc638c 100644
--- a/dnn/torch/osce/train_model.py
+++ b/dnn/torch/osce/train_model.py
@@ -54,7 +54,7 @@ from engine.engine import train_one_epoch, evaluate
 
 
 from utils.silk_features import load_inference_data
-from utils.misc import count_parameters
+from utils.misc import count_parameters, count_nonzero_parameters
 
 from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
 
@@ -292,6 +292,6 @@ for ep in range(1, epochs + 1):
     torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
 
 
-    print()
+    print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
 
 print('Done')
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py
index 6fe3dfa8..c4355b4e 100644
---
a/dnn/torch/osce/utils/misc.py
+++ b/dnn/torch/osce/utils/misc.py
@@ -41,7 +41,17 @@ def count_parameters(model, verbose=False):
 
     return total
 
+def count_nonzero_parameters(model, verbose=False):
+    total = 0
+    for name, p in model.named_parameters():
+        count = torch.count_nonzero(p).item()
+
+        if verbose:
+            print(f"{name}: {count} non-zero parameters")
 
+        total += count
+
+    return total
 def retain_grads(module):
     for p in module.parameters():
         if p.requires_grad:
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
index 0d731127..5fc84ef1 100644
--- a/dnn/torch/osce/utils/templates.py
+++ b/dnn/torch/osce/utils/templates.py
@@ -51,7 +51,10 @@ lace_setup = {
             'pitch_max': 300,
             'preemph': 0.85,
             'skip': 91,
-            'softquant': True
+            'softquant': True,
+            'sparsify': False,
+            'sparsification_density': 0.4,
+            'sparsification_schedule': [10000, 40000, 200]
         }
     },
     'data': {
@@ -108,7 +111,10 @@ nolace_setup = {
             'pitch_max': 300,
             'preemph': 0.85,
             'skip': 91,
-            'softquant': True
+            'softquant': True,
+            'sparsify': False,
+            'sparsification_density': 0.4,
+            'sparsification_schedule': [10000, 40000, 200]
         }
     },
     'data': {
@@ -163,7 +169,10 @@ nolace_setup_adv = {
             'pitch_max': 300,
             'preemph': 0.85,
             'skip': 91,
-            'softquant': True
+            'softquant': True,
+            'sparsify': False,
+            'sparsification_density': 0.4,
+            'sparsification_schedule': [0, 0, 200]
         }
     },
     'data': { |