Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-23 01:02:14 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-23 01:02:14 +0300
commit30ccc5ddedcaab2d2bc08457eef093a471e0197b (patch)
treef6d69754e4c176e07aff1f99e7d78efec5bf26ed
parent6d7f101d962980238d19a32b7219636526e45e1f (diff)
057: Code Refactoring - Siamese Architectures
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/models/Pooling.py2
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py4
2 files changed, 3 insertions, 3 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py b/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py
index efafce4..0ecdf20 100644
--- a/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py
+++ b/transquest/algo/sentence_level/siamesetransquest/models/Pooling.py
@@ -85,7 +85,7 @@ class Pooling(nn.Module):
@staticmethod
def load(input_path):
- with open(os.path.join(input_path, 'config.json')) as fIn:
+ with open(os.path.join(input_path, 'pooling_config.json')) as fIn:
config = json.load(fIn)
return Pooling(**config)
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 8148894..a271420 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -401,8 +401,8 @@ class SiameseTransQuestModel(nn.Sequential):
with open(os.path.join(path, 'modules.json'), 'w') as fOut:
json.dump(contained_modules, fOut, indent=2)
- # with open(os.path.join(path, 'config.json'), 'w') as fOut:
- # json.dump({'__version__': __version__}, fOut, indent=2)
+ with open(os.path.join(path, 'siamese_config.json'), 'w') as fOut:
+ json.dump({'__version__': __version__}, fOut, indent=2)
def smart_batching_collate(self, batch):
"""