diff options
Diffstat (limited to 'dnn/torch/osce/utils/misc.py')
-rw-r--r-- | dnn/torch/osce/utils/misc.py | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/dnn/torch/osce/utils/misc.py b/dnn/torch/osce/utils/misc.py index 6fe3dfa8..68ee4bfd 100644 --- a/dnn/torch/osce/utils/misc.py +++ b/dnn/torch/osce/utils/misc.py @@ -28,6 +28,7 @@ """ import torch +from torch.nn.utils import remove_weight_norm def count_parameters(model, verbose=False): total = 0 @@ -41,7 +42,17 @@ def count_parameters(model, verbose=False): return total +def count_nonzero_parameters(model, verbose=False): + total = 0 + for name, p in model.named_parameters(): + count = torch.count_nonzero(p).item() + + if verbose: + print(f"{name}: {count} non-zero parameters") + + total += count + return total def retain_grads(module): for p in module.parameters(): if p.requires_grad: @@ -62,4 +73,23 @@ def create_weights(s_real, s_gen, alpha): weight = torch.exp(alpha * (sr[-1] - sg[-1])) weights.append(weight) - return weights
\ No newline at end of file + return weights + + +def _get_candidates(module: torch.nn.Module): + candidates = [] + for key in module.__dict__.keys(): + if hasattr(module, key + '_v'): + candidates.append(key) + return candidates + +def remove_all_weight_norm(model : torch.nn.Module, verbose=False): + for name, m in model.named_modules(): + candidates = _get_candidates(m) + + for candidate in candidates: + try: + remove_weight_norm(m, name=candidate) + if verbose: print(f'removed weight norm on weight {name}.{candidate}') + except: + pass |