diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-07-28 02:43:17 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-07-28 02:43:17 +0300 |
commit | d663b0f39ae66dfa5a1c1d8c94ebaddba12a06af (patch) | |
tree | c4ff223ef2520297323bcab2b456c3a93f956d8e | |
parent | 1b81d8a7e252c92f4cdd16710a9600467136160c (diff) |
switched to FIFO ordering in weight-exchange and switched to direct include in export_rdovaeopus-ng-linear
-rw-r--r-- | dnn/torch/rdovae/export_rdovae_weights.py | 3 | ||||
-rw-r--r-- | dnn/torch/rdovae/libs/wexchange-1.5-py3-none-any.whl | bin | 13762 -> 0 bytes | |||
-rw-r--r-- | dnn/torch/rdovae/requirements.txt | 3 | ||||
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/c_writer.py | 4 | ||||
-rw-r--r-- | dnn/torch/weight-exchange/wexchange/c_export/common.py | 2 |
5 files changed, 8 insertions, 4 deletions
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index 8a6d0c03..f9c1db81 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -29,6 +29,9 @@ import os import argparse +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange')) parser = argparse.ArgumentParser() diff --git a/dnn/torch/rdovae/libs/wexchange-1.5-py3-none-any.whl b/dnn/torch/rdovae/libs/wexchange-1.5-py3-none-any.whl Binary files differdeleted file mode 100644 index 44410fad..00000000 --- a/dnn/torch/rdovae/libs/wexchange-1.5-py3-none-any.whl +++ /dev/null diff --git a/dnn/torch/rdovae/requirements.txt b/dnn/torch/rdovae/requirements.txt index b002ff79..9225ea84 100644 --- a/dnn/torch/rdovae/requirements.txt +++ b/dnn/torch/rdovae/requirements.txt @@ -1,5 +1,4 @@ numpy scipy torch -tqdm -libs/wexchange-1.5-py3-none-any.whl
\ No newline at end of file +tqdm
\ No newline at end of file diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py index c3a1e2cf..36050881 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py @@ -78,7 +78,7 @@ class CWriter: self.layer_dict = OrderedDict() # for binary blob format, format is key=<layer name>, value=<layer type> - self.weight_arrays = set() + self.weight_arrays = [] # form model struct, format is key=<layer name>, value=<number of elements> self.state_dict = OrderedDict() @@ -134,6 +134,8 @@ f""" if self.enable_binary_blob: # create weight array + if len(set(self.weight_arrays)) != len(self.weight_arrays): + raise ValueError("error: detected duplicates in weight arrays") self.source.write("\n#ifndef USE_WEIGHTS_FILE\n") self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n") for name in self.weight_arrays: diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index 1e723d33..d8b3f7e7 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -54,7 +54,7 @@ f''' #ifndef USE_WEIGHTS_FILE ''' ) - writer.weight_arrays.add(name) + writer.weight_arrays.append(name) if reshape_8x4: vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8)) |