
gitlab.xiph.org/xiph/opus.git
author    Jan Buethe <jbuethe@amazon.de>  2024-01-08 14:00:49 +0300
committer Jan Buethe <jbuethe@amazon.de>  2024-01-08 14:05:04 +0300
commit    999ddbe09ca1e521cc8d71005c1ecd8513e47611 (patch)
tree      85022704a9352cb03751863f9a1ae788afcc805b
parent    e968878f06fb96e022f09c0fab60b88ab3f3ac81 (diff)
more sparsification stuff
 dnn/torch/dnntools/dnntools/sparsification/utils.py | 27
 dnn/torch/osce/adv_train_model.py                   |  4
 dnn/torch/osce/engine/engine.py                     |  4
 dnn/torch/osce/models/lace.py                       |  8
 dnn/torch/osce/models/no_lace.py                    | 18
 dnn/torch/osce/models/silk_feature_net_pl.py        | 19
 dnn/torch/osce/train_model.py                       |  4
 dnn/torch/osce/utils/misc.py                        | 10
 dnn/torch/osce/utils/templates.py                   | 15
9 files changed, 79 insertions, 30 deletions
diff --git a/dnn/torch/dnntools/dnntools/sparsification/utils.py b/dnn/torch/dnntools/dnntools/sparsification/utils.py
index da9dc89e..42f22353 100644
--- a/dnn/torch/dnntools/dnntools/sparsification/utils.py
+++ b/dnn/torch/dnntools/dnntools/sparsification/utils.py
@@ -36,8 +36,29 @@ def create_sparsifier(module, start, stop, interval):
     return sparsify
 
-def estimate_parameters(module):
+
+def count_parameters(model, verbose=False):
+    total = 0
+    for name, p in model.named_parameters():
+        count = torch.ones_like(p).sum().item()
+
+        if verbose:
+            print(f"{name}: {count} parameters")
+
+        total += count
+
+    return total
+
+def estimate_nonzero_parameters(module):
     num_zero_parameters = 0
     if hasattr(module, 'sparsify'):
-        if isinstance(module, torch.nn.Conv1d):
-            pass
+        params = module.sparsification_params
+        if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d):
+            num_zero_parameters = torch.ones_like(module.weight).sum().item() * (1 - params[0])
+        elif isinstance(module, torch.nn.GRU):
+            num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0])
+            num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0])
+        elif isinstance(module, torch.nn.Linear):
+            num_zero_parameters = module.in_features * module.out_features * params[0]
+        else:
+            raise ValueError(f'unknown sparsification method for module of type {type(module)}')
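
The per-module `sparsification_params` read above follow the convention `(density, [block_height, block_width])`, with a per-gate dict of such tuples for GRUs (see silk_feature_net_pl.py below); `density` is the fraction of weights kept. A minimal, self-contained sketch of the block-magnitude pruning this implies — the real dnntools sparsifier also ramps the density over a schedule, so this is an illustration, not the actual implementation:

    import torch

    def block_sparsify(w, density, block=(8, 4)):
        # keep roughly the `density` fraction of blocks with the largest
        # L1 norm and zero the rest (sketch of block-magnitude pruning)
        bh, bw = block
        rows, cols = w.shape
        blocks = w.reshape(rows // bh, bh, cols // bw, bw).permute(0, 2, 1, 3)
        norms = blocks.abs().sum(dim=(2, 3))
        k = max(1, int(round(density * norms.numel())))
        threshold = norms.flatten().topk(k).values[-1]
        mask = (norms >= threshold).to(w.dtype)[:, :, None, None]
        return (blocks * mask).permute(0, 2, 1, 3).reshape(rows, cols)

    w = torch.randn(64, 32)
    print(torch.count_nonzero(block_sparsify(w, 0.5)).item())  # ~1024 of 2048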
diff --git a/dnn/torch/osce/adv_train_model.py b/dnn/torch/osce/adv_train_model.py
index 9cd32000..dcfb65f1 100644
--- a/dnn/torch/osce/adv_train_model.py
+++ b/dnn/torch/osce/adv_train_model.py
@@ -408,6 +408,10 @@ for ep in range(1, epochs + 1):
             optimizer.step()
 
+            # sparsification
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
+
             running_model_grad_norm += get_grad_norm(model).detach().cpu().item()
             running_adv_loss += gen_loss.detach().cpu().item()
             running_disc_loss += disc_loss.detach().cpu().item()
 
diff --git a/dnn/torch/osce/engine/engine.py b/dnn/torch/osce/engine/engine.py
index 2ccc0277..0762c898 100644
--- a/dnn/torch/osce/engine/engine.py
+++ b/dnn/torch/osce/engine/engine.py
@@ -47,8 +47,8 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
             scheduler.step()
 
             # sparsification
-            if hasattr(model, 'sparsify'):
-                model.sparsify(True)
+            if hasattr(model, 'sparsifier'):
+                model.sparsifier()
 
             # update running loss
             running_loss += float(loss.cpu())
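
Both loops now call the sparsifier unconditionally once per optimizer step and leave the scheduling to the callable itself. Given the `create_sparsifier(module, start, stop, interval)` signature above, a plausible shape for the returned closure — the body here is an assumption, only the signature appears in this patch:

    def create_sparsifier_sketch(module, start, stop, interval):
        step = [0]  # step counter captured by the closure

        def sparsifier():
            step[0] += 1
            if step[0] < start or step[0] % interval != 0:
                return
            # between `start` and `stop` a ramped density would be applied;
            # past `stop` the target density from sparsification_params holds
            for m in module.modules():
                if hasattr(m, 'sparsification_params'):
                    pass  # prune m's weights according to its params

        return sparsifier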
diff --git a/dnn/torch/osce/models/lace.py b/dnn/torch/osce/models/lace.py
index 78c1a717..7e8e739c 100644
--- a/dnn/torch/osce/models/lace.py
+++ b/dnn/torch/osce/models/lace.py
@@ -68,7 +68,9 @@ class LACE(NNSBase):
                  partial_lookahead=True,
                  norm_p=2,
                  softquant=False,
-                 sparsify=False):
+                 sparsify=False,
+                 sparsification_schedule=[10000, 30000, 100],
+                 sparsification_density=0.5):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -93,7 +95,7 @@
 
         # feature net
         if partial_lookahead:
-            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
         else:
             self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
 
@@ -107,7 +109,7 @@
         self.af1 = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p, softquant=softquant)
 
         if sparsify:
-            self.sparsify = create_sparsifier(self, 500, 2000, 100)
+            self.sparsify = create_sparsifier(self, *sparsification_schedule)
 
     def flop_count(self, rate=16000, verbose=False):
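
The hard-coded `create_sparsifier(self, 500, 2000, 100)` becomes configurable through the new keyword arguments, where the schedule is `[start, stop, interval]` in training steps. Note that lace.py keeps the attribute name `self.sparsify`, while the training loops above now test for `model.sparsifier`; as committed, a LACE model built with `sparsify=True` would therefore not have its sparsifier invoked (no_lace.py below does rename the attribute). A hypothetical construction with the new arguments:

    # hypothetical: values mirror the new defaults above
    model = LACE(sparsify=True,
                 sparsification_schedule=[10000, 30000, 100],  # start, stop, interval
                 sparsification_density=0.5)                   # fraction of weights kept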
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 5654db6c..7e021930 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -72,7 +72,9 @@ class NoLACE(NNSBase):
                  avg_pool_k=4,
                  pool_after=False,
                  softquant=False,
-                 sparsify=False):
+                 sparsify=False,
+                 sparsification_schedule=[100, 1000, 100],
+                 sparsification_density=0.5):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -97,7 +99,7 @@
 
         # feature net
         if partial_lookahead:
-            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant)
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim, softquant=softquant, sparsify=sparsify, sparsification_density=sparsification_density)
         else:
             self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
 
@@ -136,13 +138,13 @@
 
         if sparsify:
-            mark_for_sparsification(self.post_cf1, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_cf2, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af1, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af2, (0.25, [8, 4]))
-            mark_for_sparsification(self.post_af3, (0.25, [8, 4]))
+            mark_for_sparsification(self.post_cf1, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_cf2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af1, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.post_af3, (sparsification_density, [8, 4]))
 
-            self.sparsify = create_sparsifier(self, 500, 1000, 100)
+            self.sparsifier = create_sparsifier(self, *sparsification_schedule)
 
     def flop_count(self, rate=16000, verbose=False):
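
Unlike lace.py, the attribute here is renamed to `self.sparsifier`, matching the updated `hasattr(model, 'sparsifier')` checks in engine.py and adv_train_model.py. The `mark_for_sparsification` helper is not shown in this patch; a plausible reading, consistent with `estimate_nonzero_parameters` looking up `module.sparsification_params` above, is that it simply tags the module:

    def mark_for_sparsification_sketch(module, sparsification_params):
        # assumption: the real helper stores the spec on the module so
        # that the sparsifier and estimate_nonzero_parameters find it
        module.sparsification_params = sparsification_params
        return module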
diff --git a/dnn/torch/osce/models/silk_feature_net_pl.py b/dnn/torch/osce/models/silk_feature_net_pl.py
index 72f8531c..799064d2 100644
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -44,7 +44,8 @@ class SilkFeatureNetPL(nn.Module):
                  num_channels=256,
                  hidden_feature_dim=64,
                  softquant=False,
-                 sparsify=True):
+                 sparsify=True,
+                 sparsification_density=0.5):
 
         super(SilkFeatureNetPL, self).__init__()
 
@@ -64,17 +65,17 @@
 
         if sparsify:
-            mark_for_sparsification(self.conv2, (0.25, [8, 4]))
-            mark_for_sparsification(self.tconv, (0.25, [8, 4]))
+            mark_for_sparsification(self.conv2, (sparsification_density, [8, 4]))
+            mark_for_sparsification(self.tconv, (sparsification_density, [8, 4]))
             mark_for_sparsification(
                 self.gru,
                 {
-                    'W_ir' : (0.25, [8, 4], False),
-                    'W_iz' : (0.25, [8, 4], False),
-                    'W_in' : (0.25, [8, 4], False),
-                    'W_hr' : (0.125, [8, 4], True),
-                    'W_hz' : (0.125, [8, 4], True),
-                    'W_hn' : (0.125, [8, 4], True),
+                    'W_ir' : (sparsification_density, [8, 4], False),
+                    'W_iz' : (sparsification_density, [8, 4], False),
+                    'W_in' : (sparsification_density, [8, 4], False),
+                    'W_hr' : (sparsification_density, [8, 4], True),
+                    'W_hz' : (sparsification_density, [8, 4], True),
+                    'W_hn' : (sparsification_density, [8, 4], True),
                 }
             )
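
The per-gate dict covers the GRU's input weights (W_ir, W_iz, W_in) and recurrent weights (W_hr, W_hz, W_hn). Two things change: the recurrent matrices previously used a lower density (0.125 vs. 0.25) and now share the single `sparsification_density`, and the third tuple element stays False for input weights and True for recurrent ones. In LPCNet-style block sparsification such a flag typically keeps the diagonal of the recurrent matrix dense; that reading is an assumption, since the flag's consumer is not in this patch. A small helper in that spirit (hypothetical, not from the source):

    def gru_sparsification_params(density, block=[8, 4]):
        # third element presumably toggles keeping the diagonal dense
        spec = {name: (density, block, False) for name in ('W_ir', 'W_iz', 'W_in')}
        spec.update({name: (density, block, True) for name in ('W_hr', 'W_hz', 'W_hn')})
        return spec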
diff --git a/dnn/torch/osce/train_model.py b/dnn/torch/osce/train_model.py
index 6e2514b9..34cc638c 100644
--- a/dnn/torch/osce/train_model.py
+++ b/dnn/torch/osce/train_model.py
@@ -54,7 +54,7 @@ from engine.engine import train_one_epoch, evaluate
 from utils.silk_features import load_inference_data
-from utils.misc import count_parameters
+from utils.misc import count_parameters, count_nonzero_parameters
 
 from losses.stft_loss import MRSTFTLoss, MRLogMelLoss
 
@@ -292,6 +292,6 @@ for ep in range(1, epochs + 1):
     torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_last.pth'))
 
-    print()
+    print(f"non-zero parameters: {count_nonzero_parameters(model)}\n")
 
 print('Done')
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py
index 6fe3dfa8..c4355b4e 100644
--- a/dnn/torch/osce/utils/misc.py
+++ b/dnn/torch/osce/utils/misc.py
@@ -41,7 +41,17 @@ def count_parameters(model, verbose=False):
     return total
 
+def count_nonzero_parameters(model, verbose=False):
+    total = 0
+    for name, p in model.named_parameters():
+        count = torch.count_nonzero(p).item()
+
+        if verbose:
+            print(f"{name}: {count} non-zero parameters")
+        total += count
+
+    return total
 
 def retain_grads(module):
     for p in module.parameters():
         if p.requires_grad:
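
Together with the existing `count_parameters`, this makes it easy to log the achieved density during training; a hypothetical check:

    # hypothetical: report overall weight density after sparsification
    total = count_parameters(model)
    nonzero = count_nonzero_parameters(model)
    print(f"density: {nonzero / total:.3f} ({nonzero}/{total} non-zero)")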
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
index 0d731127..5fc84ef1 100644
--- a/dnn/torch/osce/utils/templates.py
+++ b/dnn/torch/osce/utils/templates.py
@@ -51,7 +51,10 @@ lace_setup = {
             'pitch_max': 300,
             'preemph': 0.85,
             'skip': 91,
-            'softquant': True
+            'softquant': True,
+            'sparsify': False,
+            'sparsification_density': 0.4,
+            'sparsification_schedule': [10000, 40000, 200]
         }
     },
     'data': {
@@ -108,7 +111,10 @@ nolace_setup = {
             'pitch_max': 300,
             'preemph': 0.85,
             'skip': 91,
-            'softquant': True
+            'softquant': True,
+            'sparsify': False,
+            'sparsification_density': 0.4,
+            'sparsification_schedule': [10000, 40000, 200]
         }
     },
     'data': {
@@ -163,7 +169,10 @@ nolace_setup_adv = {
             'pitch_max': 300,
             'preemph': 0.85,
             'skip': 91,
-            'softquant': True
+            'softquant': True,
+            'sparsify': False,
+            'sparsification_density': 0.4,
+            'sparsification_schedule': [0, 0, 200]
         }
     },
     'data': {
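
All three templates keep `'sparsify': False`, so nothing changes unless it is switched on. The training templates ramp sparsity between steps 10000 and 40000, while the adversarial template uses `[0, 0, 200]`, i.e. the ramp starts and ends immediately — presumably because adversarial fine-tuning resumes from an already sparsified checkpoint and only needs to keep re-applying the masks. Enabling it at configuration time could look like this (assuming the usual 'model'/'kwargs' nesting of these setup dicts):

    from utils.templates import nolace_setup  # template shown above

    nolace_setup['model']['kwargs']['sparsify'] = True  # density/schedule defaults kept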