Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-08-08 11:46:11 +0300
committerJan Buethe <jbuethe@amazon.de>2023-08-08 11:46:11 +0300
commit5160d7fdfada61e48d4c3df110cd773ec5e9ce7e (patch)
treeefde2bce288143ee612cb0a692bd013f68f5cc2b
parent6cba42f999866b614c0d70b09b603c175fbdbc4d (diff)
improved auto-scaling in wexchange
-rw-r--r--dnn/torch/weight-exchange/wexchange/c_export/common.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py
index e5263bf2..a8986816 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -174,14 +174,15 @@ def print_sparse_weight(writer, A, name, scale=1/128, have_diag=True, quantize=F
def compute_scaling(weight):
""" computes optimal scaling vector for weight of shape (features_in, features_out) """
- n_in, _ = weight.shape
- n_in2 = 2 * (n_in // 2)
+ n_in, n_out = weight.shape
+ assert n_in % 4 == 0 and n_out % 8 == 0
- weight_sums = np.abs(weight[: n_in2 : 2]) + np.abs(weight[1 : n_in : 2])
- weight_max = weight_sums.max(axis=0)
- if n_in % 2: weight_max = np.maximum(weight_max, np.abs(weight[-1]))
+ weight_max_abs = np.max(np.abs(weight), axis=0)
+ weight_max_sum = np.max(np.abs(weight[: n_in : 2] + weight[1 : n_in : 2]), axis=0)
+ scale_max = weight_max_abs / 127
+ scale_sum = weight_max_sum / 129
- scale = weight_max / 127
+ scale = np.maximum(scale_max, scale_sum)
return scale