diff options
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/torch/torch.py')
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/torch/torch.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py index f7e16032..af5d3e59 100644 --- a/dnn/torch/weight-exchange/wexchange/torch/torch.py +++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py @@ -153,7 +153,7 @@ def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/1 np.save(where, 'weight_global_gain.npy', w_global_gain) np.save(where, 'bias_global_gain.npy', b_global_gain) -def dump_torch_tdshaper(where, shaper, name='tdshaper'): +def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128): if isinstance(where, CWriter): where.header.write(f""" @@ -165,7 +165,8 @@ def dump_torch_tdshaper(where, shaper, name='tdshaper'): """ ) - dump_torch_conv1d_weights(where, shaper.feature_alpha1, name + "_alpha1") + dump_torch_conv1d_weights(where, shaper.feature_alpha1_f, name + "_alpha1_f", quantize=quantize, scale=scale) + dump_torch_conv1d_weights(where, shaper.feature_alpha1_t, name + "_alpha1_t") dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2") if shaper.innovate: @@ -274,7 +275,7 @@ def load_torch_dense_weights(where, dense): dense.bias.set_(torch.from_numpy(b)) -def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): +def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False): w = conv.weight.detach().cpu().numpy().copy() if conv.bias is None: @@ -284,7 +285,7 @@ def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=Fa if isinstance(where, CWriter): - return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize) + return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize, sparse=sparse) else: os.makedirs(where, exist_ok=True) @@ -304,7 +305,7 @@ def load_torch_conv1d_weights(where, conv): conv.bias.set_(torch.from_numpy(b)) -def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): +def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False): w = conv.weight.detach().cpu().numpy().copy() if conv.bias is None: @@ -314,7 +315,7 @@ def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=F if isinstance(where, CWriter): - return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize) + return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize, sparse=sparse) else: os.makedirs(where, exist_ok=True) |