Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/torch/torch.py')
-rw-r--r--dnn/torch/weight-exchange/wexchange/torch/torch.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index f7e16032..af5d3e59 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -153,7 +153,7 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/1
np.save(where, 'weight_global_gain.npy', w_global_gain)
np.save(where, 'bias_global_gain.npy', b_global_gain)
-def dump_torch_tdshaper(where, shaper, name='tdshaper'):
+def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128):
if isinstance(where, CWriter):
where.header.write(f"""
@@ -165,7 +165,8 @@ def dump_torch_tdshaper(where, shaper, name='tdshaper'):
"""
)
- dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1")
+ dump_torch_conv1d_weights(where, shaper.feature_alpha1_f, name + "_alpha1_f", quantize=quantize, scale=scale)
+ dump_torch_conv1d_weights(where, shaper.feature_alpha1_t, name + "_alpha1_t")
dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2")
if shaper.innovate:
@@ -274,7 +275,7 @@ def load_torch_dense_weights(where, dense):
dense.bias.set_(torch.from_numpy(b))
-def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
+def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
w = conv.weight.detach().cpu().numpy().copy()
if conv.bias is None:
@@ -284,7 +285,7 @@ def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=Fa
if isinstance(where, CWriter):
- return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize)
+ return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize, sparse=sparse)
else:
os.makedirs(where, exist_ok=True)
@@ -304,7 +305,7 @@ def load_torch_conv1d_weights(where, conv):
conv.bias.set_(torch.from_numpy(b))
-def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
+def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
w = conv.weight.detach().cpu().numpy().copy()
if conv.bias is None:
@@ -314,7 +315,7 @@ def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=F
if isinstance(where, CWriter):
- return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize)
+ return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize, sparse=sparse)
else:
os.makedirs(where, exist_ok=True)