diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 01:02:14 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-23 01:02:14 +0300 |
commit | 30ccc5ddedcaab2d2bc08457eef093a471e0197b (patch) | |
tree | f6d69754e4c176e07aff1f99e7d78efec5bf26ed | |
parent | 6d7f101d962980238d19a32b7219636526e45e1f (diff) |
057: Code Refactoring - Siamese Architectures
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/models/Pooling.py | 2 | ||||
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/run_model.py | 4 |
2 files changed, 3 insertions, 3 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py b/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py index efafce4..0ecdf20 100644 --- a/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py +++ b/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py @@ -85,7 +85,7 @@ class Pooling(nn.Module): @staticmethod def load(input_path): - with open(os.path.join(input_path, 'config.json')) as fIn: + with open(os.path.join(input_path, 'pooling_config.json')) as fIn: config = json.load(fIn) return Pooling(**config) diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index 8148894..a271420 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -401,8 +401,8 @@ class SiameseTransQuestModel(nn.Sequential): with open(os.path.join(path, 'modules.json'), 'w') as fOut: json.dump(contained_modules, fOut, indent=2) - # with open(os.path.join(path, 'config.json'), 'w') as fOut: - # json.dump({'__version__': __version__}, fOut, indent=2) + with open(os.path.join(path, 'siamese_config.json'), 'w') as fOut: + json.dump({'__version__': __version__}, fOut, indent=2) def smart_batching_collate(self, batch): """ |