diff options
Diffstat (limited to 'dnn/torch/dnntools/dnntools/sparsification/utils.py')
-rw-r--r-- | dnn/torch/dnntools/dnntools/sparsification/utils.py | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/dnn/torch/dnntools/dnntools/sparsification/utils.py b/dnn/torch/dnntools/dnntools/sparsification/utils.py new file mode 100644 index 00000000..42f22353 --- /dev/null +++ b/dnn/torch/dnntools/dnntools/sparsification/utils.py @@ -0,0 +1,64 @@ +import torch + +from dnntools.sparsification import GRUSparsifier, LinearSparsifier, Conv1dSparsifier, ConvTranspose1dSparsifier + +def mark_for_sparsification(module, params): + setattr(module, 'sparsify', True) + setattr(module, 'sparsification_params', params) + return module + +def create_sparsifier(module, start, stop, interval): + sparsifier_list = [] + for m in module.modules(): + if hasattr(m, 'sparsify'): + if isinstance(m, torch.nn.GRU): + sparsifier_list.append( + GRUSparsifier([(m, m.sparsification_params)], start, stop, interval) + ) + elif isinstance(m, torch.nn.Linear): + sparsifier_list.append( + LinearSparsifier([(m, m.sparsification_params)], start, stop, interval) + ) + elif isinstance(m, torch.nn.Conv1d): + sparsifier_list.append( + Conv1dSparsifier([(m, m.sparsification_params)], start, stop, interval) + ) + elif isinstance(m, torch.nn.ConvTranspose1d): + sparsifier_list.append( + ConvTranspose1dSparsifier([(m, m.sparsification_params)], start, stop, interval) + ) + else: + print(f"[create_sparsifier] warning: module {m} marked for sparsification but no suitable sparsifier exists.") + + def sparsify(verbose=False): + for sparsifier in sparsifier_list: + sparsifier.step(verbose) + + return sparsify + + +def count_parameters(model, verbose=False): + total = 0 + for name, p in model.named_parameters(): + count = torch.ones_like(p).sum().item() + + if verbose: + print(f"{name}: {count} parameters") + + total += count + + return total + +def estimate_nonzero_parameters(module): + num_zero_parameters = 0 + if hasattr(module, 'sparsify'): + params = module.sparsification_params + if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.ConvTranspose1d): + num_zero_parameters = torch.ones_like(module.weight).sum().item() * (1 - params[0]) + elif isinstance(module, torch.nn.GRU): + num_zero_parameters = module.input_size * module.hidden_size * (3 - params['W_ir'][0] - params['W_iz'][0] - params['W_in'][0]) + num_zero_parameters += module.hidden_size * module.hidden_size * (3 - params['W_hr'][0] - params['W_hz'][0] - params['W_hn'][0]) + elif isinstance(module, torch.nn.Linear): + num_zero_parameters = module.in_features * module.out_features * params[0] + else: + raise ValueError(f'unknown sparsification method for module of type {type(module)}') |