Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/osce/utils/misc.py')
-rw-r--r--dnn/torch/osce/utils/misc.py32
1 files changed, 31 insertions, 1 deletions
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py
index 6fe3dfa8..68ee4bfd 100644
--- a/dnn/torch/osce/utils/misc.py
+++ b/dnn/torch/osce/utils/misc.py
@@ -28,6 +28,7 @@
"""
import torch
+from torch.nn.utils import remove_weight_norm
def count_parameters(model, verbose=False):
total = 0
@@ -41,7 +42,17 @@ def count_parameters(model, verbose=False):
return total
+def count_nonzero_parameters(model, verbose=False):
+ total = 0
+ for name, p in model.named_parameters():
+ count = torch.count_nonzero(p).item()
+
+ if verbose:
+ print(f"{name}: {count} non-zero parameters")
+
+ total += count
+ return total
def retain_grads(module):
for p in module.parameters():
if p.requires_grad:
@@ -62,4 +73,23 @@ def create_weights(s_real, s_gen, alpha):
weight = torch.exp(alpha * (sr[-1] - sg[-1]))
weights.append(weight)
- return weights \ No newline at end of file
+ return weights
+
+
+def _get_candidates(module: torch.nn.Module):
+ candidates = []
+ for key in module.__dict__.keys():
+ if hasattr(module, key + '_v'):
+ candidates.append(key)
+ return candidates
+
+def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
+ for name, m in model.named_modules():
+ candidates = _get_candidates(m)
+
+ for candidate in candidates:
+ try:
+ remove_weight_norm(m, name=candidate)
+ if verbose: print(f'removed weight norm on weight {name}.{candidate}')
+ except:
+ pass