Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2024-01-04 14:03:50 +0300
committerJan Buethe <jbuethe@amazon.de>2024-01-19 18:44:25 +0300
commit4cc6f28d65abe8ebebe44192a2a3f546efaff2cd (patch)
treea78bce8e92dfcc1ba1938b90df2c0d266241512c
parentcb8afe5a2a59f3c736939cd26c935029fc9779ca (diff)
fixed softquant module
-rw-r--r--dnn/torch/osce/utils/softquant.py28
1 files changed, 20 insertions, 8 deletions
diff --git a/dnn/torch/osce/utils/softquant.py b/dnn/torch/osce/utils/softquant.py
index 3917488b..527350e2 100644
--- a/dnn/torch/osce/utils/softquant.py
+++ b/dnn/torch/osce/utils/softquant.py
@@ -9,10 +9,13 @@ class SoftQuant:
self.quantization_noise = None
self.scale = scale
- def __call__(self, module, inputs, *args):
- if self.quantization_noise is None:
+ def __call__(self, module, inputs, *args, before=True):
+ if not module.training: return
+
+ if before:
self.quantization_noise = dict()
for name in self.names:
+ print(f"adding noise to {module}.{name}")
weight = getattr(module, name)
self.quantization_noise[name] = \
self.scale * weight.abs().max() * 2 * (torch.rand_like(weight) - 0.5)
@@ -20,6 +23,7 @@ class SoftQuant:
weight.data[:] = weight + self.quantization_noise[name]
else:
for name in self.names:
+ print(f"removing noise from {module}.{name}")
weight = getattr(module, name)
with torch.no_grad():
weight.data[:] = weight - self.quantization_noise[name]
@@ -32,8 +36,14 @@ class SoftQuant:
if not hasattr(module, name):
raise ValueError("")
- module.register_forward_pre_hook(fn)
- module.register_forward_hook(fn)
+ fn_before = lambda *x : fn(*x, before=True)
+ fn_after = lambda *x : fn(*x, before=False)
+ setattr(fn_before, 'sqm', fn)
+ setattr(fn_after, 'sqm', fn)
+
+
+ module.register_forward_pre_hook(fn_before)
+ module.register_forward_hook(fn_after)
module
@@ -46,10 +56,12 @@ def soft_quant(module, names=['weight'], scale=0.5/127):
def remove_soft_quant(module, names=['weight']):
for k, hook in module._forward_pre_hooks.items():
- if isinstance(hook, SoftQuant) and hook.names == names:
- del module._forward_pre_hooks[k]
+ if hasattr(hook, 'sqm'):
+ if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
+ del module._forward_pre_hooks[k]
for k, hook in module._forward_hooks.items():
- if isinstance(hook, SoftQuant) and hook.names == names:
- del module._forward_hooks[k]
+ if hasattr(hook, 'sqm'):
+ if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
+ del module._forward_hooks[k]
return module \ No newline at end of file