Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/c_export/common.py')
-rw-r--r--dnn/torch/weight-exchange/wexchange/c_export/common.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py
index 524f1cc3..039edd9b 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -282,7 +282,8 @@ def print_conv1d_layer(writer : CWriter,
bias : np.ndarray,
scale=1/128,
format : str = 'torch',
- quantize=False):
+ quantize=False,
+ sparse=False):
if format == "torch":
@@ -290,7 +291,7 @@ def print_conv1d_layer(writer : CWriter,
weight = np.transpose(weight, (2, 1, 0))
lin_weight = np.reshape(weight, (-1, weight.shape[-1]))
- print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize)
+ print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=sparse, diagonal=False, quantize=quantize)
writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n")
@@ -369,7 +370,8 @@ def print_tconv1d_layer(writer : CWriter,
bias : np.ndarray,
stride: int,
scale=1/128,
- quantize=False):
+ quantize=False,
+ sparse=False):
in_channels, out_channels, kernel_size = weight.shape
@@ -377,7 +379,7 @@ def print_tconv1d_layer(writer : CWriter,
linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0)
linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten()
- print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize)
+ print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize, sparse=sparse)
writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")