diff options
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/c_export/common.py')
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/common.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index 98a45c3f..5dd9f138 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -124,6 +124,7 @@ def extract_diagonal(A): return diag, B def quantize_weight(weight, scale): + scale = scale + 1e-30 Aq = np.round(weight / scale).astype('int') if Aq.max() > 127 or Aq.min() <= -128: raise ValueError("value out of bounds in quantize_weight") @@ -227,7 +228,7 @@ def print_linear_layer(writer : CWriter, nb_inputs, nb_outputs = weight.shape - if scale is None: + if scale is None and quantize: scale = compute_scaling(weight) @@ -359,4 +360,4 @@ def print_gru_layer(writer : CWriter, writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {N}\n") writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n") - return N
\ No newline at end of file + return N |