Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-11-23 20:48:25 +0300
committerJan Buethe <jbuethe@amazon.de>2023-11-23 20:48:25 +0300
commit85bddfcba421122526c034d709d1b94862ff3c3a (patch)
tree596d55cb370abe852696cc38b324b9249c70f090
parent96a8e414356b23bf1cb91eab6797abdc663d7967 (diff)
added add_typedef option to CWriter
-rw-r--r--dnn/torch/weight-exchange/wexchange/c_export/c_writer.py14
-rw-r--r--dnn/torch/weight-exchange/wexchange/torch/torch.py1
2 files changed, 11 insertions, 4 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
index 36050881..2745f337 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
@@ -38,7 +38,8 @@ class CWriter:
create_state_struct=False,
enable_binary_blob=True,
model_struct_name="Model",
- nnet_header="nnet.h"):
+ nnet_header="nnet.h",
+ add_typedef=False):
"""
Writer class for creating souce and header files for weight exports to C
@@ -73,6 +74,7 @@ class CWriter:
self.enable_binary_blob = enable_binary_blob
self.create_state_struct = create_state_struct
self.model_struct_name = model_struct_name
+ self.add_typedef = add_typedef
# for binary blob format, format is key=<layer name>, value=(<layer type>, <init call>)
self.layer_dict = OrderedDict()
@@ -119,11 +121,17 @@ f"""
# create model type
if self.enable_binary_blob:
- self.header.write(f"\nstruct {self.model_struct_name} {{")
+ if self.add_typedef:
+ self.header.write(f"\ntypedef struct {{")
+ else:
+ self.header.write(f"\nstruct {self.model_struct_name} {{")
for name, data in self.layer_dict.items():
layer_type = data[0]
self.header.write(f"\n {layer_type} {name};")
- self.header.write(f"\n}};\n")
+ if self.add_typedef:
+ self.header.write(f"\n}} {self.model_struct_name};\n")
+ else:
+ self.header.write(f"\n}};\n")
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.header.write(f"\n{init_prototype};\n")
diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py
index 65a9645b..2dcee1c5 100644
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -333,7 +333,6 @@ def load_torch_conv2d_weights(where, conv):
def dump_torch_embedding_weights(where, embed, name='embed', scale=1/128, sparse=False, diagonal=False, quantize=False):
- print("quantize = ", quantize)
w = embed.weight.detach().cpu().numpy().copy().transpose()
b = np.zeros(w.shape[0], dtype=w.dtype)