diff options
Diffstat (limited to 'dnn/torch/weight-exchange/wexchange/c_export/c_writer.py')
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/c_writer.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py index 36050881..2745f337 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py @@ -38,7 +38,8 @@ class CWriter: create_state_struct=False, enable_binary_blob=True, model_struct_name="Model", - nnet_header="nnet.h"): + nnet_header="nnet.h", + add_typedef=False): """ Writer class for creating souce and header files for weight exports to C @@ -73,6 +74,7 @@ class CWriter: self.enable_binary_blob = enable_binary_blob self.create_state_struct = create_state_struct self.model_struct_name = model_struct_name + self.add_typedef = add_typedef # for binary blob format, format is key=<layer name>, value=(<layer type>, <init call>) self.layer_dict = OrderedDict() @@ -119,11 +121,17 @@ f""" # create model type if self.enable_binary_blob: - self.header.write(f"\nstruct {self.model_struct_name} {{") + if self.add_typedef: + self.header.write(f"\ntypedef struct {{") + else: + self.header.write(f"\nstruct {self.model_struct_name} {{") for name, data in self.layer_dict.items(): layer_type = data[0] self.header.write(f"\n {layer_type} {name};") - self.header.write(f"\n}};\n") + if self.add_typedef: + self.header.write(f"\n}} {self.model_struct_name};\n") + else: + self.header.write(f"\n}};\n") init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)" self.header.write(f"\n{init_prototype};\n") |