diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-11-23 20:48:25 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-11-23 20:48:25 +0300 |
commit | 85bddfcba421122526c034d709d1b94862ff3c3a (patch) | |
tree | 596d55cb370abe852696cc38b324b9249c70f090 | |
parent | 96a8e414356b23bf1cb91eab6797abdc663d7967 (diff) |
added add_typedef option to CWriter
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/c_writer.py | 14 | ||||
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/torch/torch.py | 1 |
2 files changed, 11 insertions, 4 deletions
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py index 36050881..2745f337 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py @@ -38,7 +38,8 @@ class CWriter: create_state_struct=False, enable_binary_blob=True, model_struct_name="Model", - nnet_header="nnet.h"): + nnet_header="nnet.h", + add_typedef=False): """ Writer class for creating souce and header files for weight exports to C @@ -73,6 +74,7 @@ class CWriter: self.enable_binary_blob = enable_binary_blob self.create_state_struct = create_state_struct self.model_struct_name = model_struct_name + self.add_typedef = add_typedef # for binary blob format, format is key=<layer name>, value=(<layer type>, <init call>) self.layer_dict = OrderedDict() @@ -119,11 +121,17 @@ f""" # create model type if self.enable_binary_blob: - self.header.write(f"\nstruct {self.model_struct_name} {{") + if self.add_typedef: + self.header.write(f"\ntypedef struct {{") + else: + self.header.write(f"\nstruct {self.model_struct_name} {{") for name, data in self.layer_dict.items(): layer_type = data[0] self.header.write(f"\n {layer_type} {name};") - self.header.write(f"\n}};\n") + if self.add_typedef: + self.header.write(f"\n}} {self.model_struct_name};\n") + else: + self.header.write(f"\n}};\n") init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)" self.header.write(f"\n{init_prototype};\n") diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py index 65a9645b..2dcee1c5 100644 --- a/dnn/torch/weight-exchange/wexchange/torch/torch.py +++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py @@ -333,7 +333,6 @@ def load_torch_conv2d_weights(where, conv): def dump_torch_embedding_weights(where, embed, name='embed', scale=1/128, sparse=False, diagonal=False, quantize=False): - print("quantize = ", quantize) w = embed.weight.detach().cpu().numpy().copy().transpose() b = np.zeros(w.shape[0], dtype=w.dtype) |