Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-04-22 21:18:00 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-04-22 21:18:00 +0300
commit53e5daa9beb03fa86d7422d9807066f7e6a4009c (patch)
tree99eaa76eacb920e814aa91d06053c4c17b7e2a3f
parenta3fe38c57dd2426f282ef8351e66581a0a96e325 (diff)
057: Code Refactoring - Siamese Architectures
-rwxr-xr-xexamples/sentence_level/wmt_2020/ro_en/siamesetransquest.py18
-rw-r--r--transquest/algo/sentence_level/siamesetransquest/run_model.py186
2 files changed, 105 insertions, 99 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
index 74400fe..480f75b 100755
--- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
+++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py
@@ -80,15 +80,15 @@ if siamesetransquest_config["evaluate_during_training"]:
train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i)
- word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[
- 'max_seq_length'])
-
- pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
- pooling_mode_mean_tokens=True,
- pooling_mode_cls_token=False,
- pooling_mode_max_tokens=False)
-
- model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model])
+ # word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[
+ # 'max_seq_length'])
+ #
+ # pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
+ # pooling_mode_mean_tokens=True,
+ # pooling_mode_cls_token=False,
+ # pooling_mode_max_tokens=False)
+
+ model = SiameseTransQuestModel(MODEL_NAME)
train_samples = []
eval_samples = []
diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py
index 2da72f2..66e7c3b 100644
--- a/transquest/algo/sentence_level/siamesetransquest/run_model.py
+++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py
@@ -38,94 +38,100 @@ class SiameseTransQuestModel(nn.Sequential):
:param modules: This parameter can be used to create custom SentenceTransformer models from scratch.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
"""
- def __init__(self, model_name_or_path: str = None, modules: Iterable[nn.Module] = None, device: str = None):
+ def __init__(self, model_name_or_path: str = None, device: str = None):
save_model_to = None
- if model_name_or_path is not None and model_name_or_path != "":
- logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))
- model_path = model_name_or_path
-
- if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'):
- logger.info("Did not find folder {}".format(model_path))
-
- if '\\' in model_path or model_path.count('/') > 1:
- raise AttributeError("Path {} not found".format(model_path))
-
- model_path = __DOWNLOAD_SERVER__ + model_path + '.zip'
- logger.info("Search model on server: {}".format(model_path))
-
- if model_path.startswith('http://') or model_path.startswith('https://'):
- model_url = model_path
- folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250][0:-4] #remove .zip file end
-
- cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
- if cache_folder is None:
- try:
- from torch.hub import _get_torch_home
- torch_cache_home = _get_torch_home()
- except ImportError:
- torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
-
- cache_folder = os.path.join(torch_cache_home, 'sentence_transformers')
-
- model_path = os.path.join(cache_folder, folder_name)
-
- if not os.path.exists(model_path) or not os.listdir(model_path):
- if os.path.exists(model_path):
- os.remove(model_path)
-
- model_url = model_url.rstrip("/")
- logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path))
-
- model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part"
- try:
- zip_save_path = os.path.join(model_path_tmp, 'model.zip')
- http_get(model_url, zip_save_path)
- with ZipFile(zip_save_path, 'r') as zip:
- zip.extractall(model_path_tmp)
- os.remove(zip_save_path)
- os.rename(model_path_tmp, model_path)
- except requests.exceptions.HTTPError as e:
- shutil.rmtree(model_path_tmp)
- if e.response.status_code == 429:
- raise Exception("Too many requests were detected from this IP for the model {}. Please contact info@nils-reimers.de for more information.".format(model_name_or_path))
-
- if e.response.status_code == 404:
- logger.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url))
- logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path))
-
- save_model_to = model_path
- model_path = None
- transformer_model = Transformer(model_name_or_path)
- pooling_model = Pooling(transformer_model.get_word_embedding_dimension())
- modules = [transformer_model, pooling_model]
- else:
- raise e
- except Exception as e:
- shutil.rmtree(model_path)
- raise e
-
-
- #### Load from disk
- if model_path is not None:
- logger.info("Load SentenceTransformer from folder: {}".format(model_path))
-
- if os.path.exists(os.path.join(model_path, 'config.json')):
- with open(os.path.join(model_path, 'config.json')) as fIn:
- config = json.load(fIn)
- if config['__version__'] > __version__:
- logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__))
-
- with open(os.path.join(model_path, 'modules.json')) as fIn:
- contained_modules = json.load(fIn)
-
- modules = OrderedDict()
- for module_config in contained_modules:
- module_class = import_from_string(module_config['type'])
- module = module_class.load(os.path.join(model_path, module_config['path']))
- modules[module_config['name']] = module
-
-
+ transformer_model = Transformer(model_name_or_path, max_seq_length=80)
+ pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True,
+ pooling_mode_cls_token=False,
+ pooling_mode_max_tokens=False)
+ modules = [transformer_model, pooling_model]
+
+ # if model_name_or_path is not None and model_name_or_path != "":
+ # logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))
+ # model_path = model_name_or_path
+ #
+ # if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'):
+ # logger.info("Did not find folder {}".format(model_path))
+ #
+ # if '\\' in model_path or model_path.count('/') > 1:
+ # raise AttributeError("Path {} not found".format(model_path))
+ #
+ # model_path = __DOWNLOAD_SERVER__ + model_path + '.zip'
+ # logger.info("Search model on server: {}".format(model_path))
+ #
+ # if model_path.startswith('http://') or model_path.startswith('https://'):
+ # model_url = model_path
+ # folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250][0:-4] #remove .zip file end
+ #
+ # cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
+ # if cache_folder is None:
+ # try:
+ # from torch.hub import _get_torch_home
+ # torch_cache_home = _get_torch_home()
+ # except ImportError:
+ # torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
+ #
+ # cache_folder = os.path.join(torch_cache_home, 'sentence_transformers')
+ #
+ # model_path = os.path.join(cache_folder, folder_name)
+ #
+ # if not os.path.exists(model_path) or not os.listdir(model_path):
+ # if os.path.exists(model_path):
+ # os.remove(model_path)
+ #
+ # model_url = model_url.rstrip("/")
+ # logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path))
+ #
+ # model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part"
+ # try:
+ # zip_save_path = os.path.join(model_path_tmp, 'model.zip')
+ # http_get(model_url, zip_save_path)
+ # with ZipFile(zip_save_path, 'r') as zip:
+ # zip.extractall(model_path_tmp)
+ # os.remove(zip_save_path)
+ # os.rename(model_path_tmp, model_path)
+ # except requests.exceptions.HTTPError as e:
+ # shutil.rmtree(model_path_tmp)
+ # if e.response.status_code == 429:
+ # raise Exception("Too many requests were detected from this IP for the model {}. Please contact info@nils-reimers.de for more information.".format(model_name_or_path))
+ #
+ # if e.response.status_code == 404:
+ # logger.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url))
+ # logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path))
+ #
+ # save_model_to = model_path
+ # model_path = None
+ # transformer_model = Transformer(model_name_or_path)
+ # pooling_model = Pooling(transformer_model.get_word_embedding_dimension())
+ # modules = [transformer_model, pooling_model]
+ # else:
+ # raise e
+ # except Exception as e:
+ # shutil.rmtree(model_path)
+ # raise e
+ #
+ #
+ # # #### Load from disk
+ # if model_path is not None:
+ # logger.info("Load SentenceTransformer from folder: {}".format(model_path))
+ #
+ # if os.path.exists(os.path.join(model_path, 'config.json')):
+ # with open(os.path.join(model_path, 'config.json')) as fIn:
+ # config = json.load(fIn)
+ # if config['__version__'] > __version__:
+ # logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__))
+ #
+ # with open(os.path.join(model_path, 'modules.json')) as fIn:
+ # contained_modules = json.load(fIn)
+ #
+ # modules = OrderedDict()
+ # for module_config in contained_modules:
+ # module_class = import_from_string(module_config['type'])
+ # module = module_class.load(os.path.join(model_path, module_config['path']))
+ # modules[module_config['name']] = module
+ #
+ #
if modules is not None and not isinstance(modules, OrderedDict):
modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
@@ -387,10 +393,10 @@ class SiameseTransQuestModel(nn.Sequential):
for idx, name in enumerate(self._modules):
module = self._modules[name]
- model_path = os.path.join(path, str(idx)+"_"+type(module).__name__)
- os.makedirs(model_path, exist_ok=True)
- module.save(model_path)
- contained_modules.append({'idx': idx, 'name': name, 'path': os.path.basename(model_path), 'type': type(module).__module__})
+ # model_path = os.path.join(path, str(idx)+"_"+type(module).__name__)
+ os.makedirs(path, exist_ok=True)
+ module.save(path)
+ contained_modules.append({'idx': idx, 'name': name, 'path': os.path.basename(path), 'type': type(module).__module__})
with open(os.path.join(path, 'modules.json'), 'w') as fOut:
json.dump(contained_modules, fOut, indent=2)