diff options
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/c_export/common.py')
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/common.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index 524f1cc3..039edd9b 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -282,7 +282,8 @@ def print_conv1d_layer(writer : CWriter, bias : np.ndarray, scale=1/128, format : str = 'torch', - quantize=False): + quantize=False, + sparse=False): if format == "torch": @@ -290,7 +291,7 @@ def print_conv1d_layer(writer : CWriter, weight = np.transpose(weight, (2, 1, 0)) lin_weight = np.reshape(weight, (-1, weight.shape[-1])) - print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize) + print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=sparse, diagonal=False, quantize=quantize) writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n") @@ -369,7 +370,8 @@ def print_tconv1d_layer(writer : CWriter, bias : np.ndarray, stride: int, scale=1/128, - quantize=False): + quantize=False, + sparse=False): in_channels, out_channels, kernel_size = weight.shape @@ -377,7 +379,7 @@ def print_tconv1d_layer(writer : CWriter, linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0) linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten() - print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize) + print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize, sparse=sparse) writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n") writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n") |