diff options
author | Jan Buethe <jbuethe@amazon.de> | 2024-01-04 14:03:50 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2024-01-19 18:44:25 +0300 |
commit | 4cc6f28d65abe8ebebe44192a2a3f546efaff2cd (patch) | |
tree | a78bce8e92dfcc1ba1938b90df2c0d266241512c | |
parent | cb8afe5a2a59f3c736939cd26c935029fc9779ca (diff) |
fixed softquant module
-rw-r--r-- | dnn/torch/osce/utils/softquant.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/dnn/torch/osce/utils/softquant.py b/dnn/torch/osce/utils/softquant.py index 3917488b..527350e2 100644 --- a/dnn/torch/osce/utils/softquant.py +++ b/dnn/torch/osce/utils/softquant.py @@ -9,10 +9,13 @@ class SoftQuant: self.quantization_noise = None self.scale = scale - def __call__(self, module, inputs, *args): - if self.quantization_noise is None: + def __call__(self, module, inputs, *args, before=True): + if not module.training: return + + if before: self.quantization_noise = dict() for name in self.names: + print(f"adding noise to {module}.{name}") weight = getattr(module, name) self.quantization_noise[name] = \ self.scale * weight.abs().max() * 2 * (torch.rand_like(weight) - 0.5) @@ -20,6 +23,7 @@ class SoftQuant: weight.data[:] = weight + self.quantization_noise[name] else: for name in self.names: + print(f"removing noise from {module}.{name}") weight = getattr(module, name) with torch.no_grad(): weight.data[:] = weight - self.quantization_noise[name] @@ -32,8 +36,14 @@ class SoftQuant: if not hasattr(module, name): raise ValueError("") - module.register_forward_pre_hook(fn) - module.register_forward_hook(fn) + fn_before = lambda *x : fn(*x, before=True) + fn_after = lambda *x : fn(*x, before=False) + setattr(fn_before, 'sqm', fn) + setattr(fn_after, 'sqm', fn) + + + module.register_forward_pre_hook(fn_before) + module.register_forward_hook(fn_after) module @@ -46,10 +56,12 @@ def soft_quant(module, names=['weight'], scale=0.5/127): def remove_soft_quant(module, names=['weight']): for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, SoftQuant) and hook.names == names: - del module._forward_pre_hooks[k] + if hasattr(hook, 'sqm'): + if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names: + del module._forward_pre_hooks[k] for k, hook in module._forward_hooks.items(): - if isinstance(hook, SoftQuant) and hook.names == names: - del module._forward_hooks[k] + if hasattr(hook, 'sqm'): + if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names: + del module._forward_hooks[k] return module
\ No newline at end of file |