Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/c_export/c_writer.py')
-rw-r--r--dnn/torch/weight-exchange/wexchange/c_export/c_writer.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
index 36050881..2745f337 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
@@ -38,7 +38,8 @@ class CWriter:
create_state_struct=False,
enable_binary_blob=True,
model_struct_name="Model",
- nnet_header="nnet.h"):
+ nnet_header="nnet.h",
+ add_typedef=False):
"""
Writer class for creating souce and header files for weight exports to C
@@ -73,6 +74,7 @@ class CWriter:
self.enable_binary_blob = enable_binary_blob
self.create_state_struct = create_state_struct
self.model_struct_name = model_struct_name
+ self.add_typedef = add_typedef
# for binary blob format, format is key=<layer name>, value=(<layer type>, <init call>)
self.layer_dict = OrderedDict()
@@ -119,11 +121,17 @@ f"""
# create model type
if self.enable_binary_blob:
- self.header.write(f"\nstruct {self.model_struct_name} {{")
+ if self.add_typedef:
+ self.header.write(f"\ntypedef struct {{")
+ else:
+ self.header.write(f"\nstruct {self.model_struct_name} {{")
for name, data in self.layer_dict.items():
layer_type = data[0]
self.header.write(f"\n {layer_type} {name};")
- self.header.write(f"\n}};\n")
+ if self.add_typedef:
+ self.header.write(f"\n}} {self.model_struct_name};\n")
+ else:
+ self.header.write(f"\n}};\n")
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.header.write(f"\n{init_prototype};\n")