diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-22 21:18:00 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-04-22 21:18:00 +0300 |
commit | 53e5daa9beb03fa86d7422d9807066f7e6a4009c (patch) | |
tree | 99eaa76eacb920e814aa91d06053c4c17b7e2a3f | |
parent | a3fe38c57dd2426f282ef8351e66581a0a96e325 (diff) |
057: Code Refactoring - Siamese Architectures
-rwxr-xr-x | examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py | 18 |
-rw-r--r-- | transquest/algo/sentence_level/siamesetransquest/run_model.py | 186 |
2 files changed, 105 insertions, 99 deletions
diff --git a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py index 74400fe..480f75b 100755 --- a/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py +++ b/examples/sentence_level/wmt_2020/ro_en/siamesetransquest.py @@ -80,15 +80,15 @@ if siamesetransquest_config["evaluate_during_training"]: train_df, eval_df = train_test_split(train, test_size=0.1, random_state=SEED * i) - word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ - 'max_seq_length']) - - pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False) - - model = SiameseTransQuestModel(modules=[word_embedding_model, pooling_model]) + # word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=siamesetransquest_config[ + # 'max_seq_length']) + # + # pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), + # pooling_mode_mean_tokens=True, + # pooling_mode_cls_token=False, + # pooling_mode_max_tokens=False) + + model = SiameseTransQuestModel(MODEL_NAME) train_samples = [] eval_samples = [] diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index 2da72f2..66e7c3b 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -38,94 +38,100 @@ class SiameseTransQuestModel(nn.Sequential): :param modules: This parameter can be used to create custom SentenceTransformer models from scratch. :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. 
""" - def __init__(self, model_name_or_path: str = None, modules: Iterable[nn.Module] = None, device: str = None): + def __init__(self, model_name_or_path: str = None, device: str = None): save_model_to = None - if model_name_or_path is not None and model_name_or_path != "": - logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path)) - model_path = model_name_or_path - - if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'): - logger.info("Did not find folder {}".format(model_path)) - - if '\\' in model_path or model_path.count('/') > 1: - raise AttributeError("Path {} not found".format(model_path)) - - model_path = __DOWNLOAD_SERVER__ + model_path + '.zip' - logger.info("Search model on server: {}".format(model_path)) - - if model_path.startswith('http://') or model_path.startswith('https://'): - model_url = model_path - folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250][0:-4] #remove .zip file end - - cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME') - if cache_folder is None: - try: - from torch.hub import _get_torch_home - torch_cache_home = _get_torch_home() - except ImportError: - torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) - - cache_folder = os.path.join(torch_cache_home, 'sentence_transformers') - - model_path = os.path.join(cache_folder, folder_name) - - if not os.path.exists(model_path) or not os.listdir(model_path): - if os.path.exists(model_path): - os.remove(model_path) - - model_url = model_url.rstrip("/") - logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path)) - - model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part" - try: - zip_save_path = os.path.join(model_path_tmp, 'model.zip') - http_get(model_url, zip_save_path) - with ZipFile(zip_save_path, 'r') as zip: - 
zip.extractall(model_path_tmp) - os.remove(zip_save_path) - os.rename(model_path_tmp, model_path) - except requests.exceptions.HTTPError as e: - shutil.rmtree(model_path_tmp) - if e.response.status_code == 429: - raise Exception("Too many requests were detected from this IP for the model {}. Please contact info@nils-reimers.de for more information.".format(model_name_or_path)) - - if e.response.status_code == 404: - logger.warning('SentenceTransformer-Model {} not found. Try to create it from scratch'.format(model_url)) - logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path)) - - save_model_to = model_path - model_path = None - transformer_model = Transformer(model_name_or_path) - pooling_model = Pooling(transformer_model.get_word_embedding_dimension()) - modules = [transformer_model, pooling_model] - else: - raise e - except Exception as e: - shutil.rmtree(model_path) - raise e - - - #### Load from disk - if model_path is not None: - logger.info("Load SentenceTransformer from folder: {}".format(model_path)) - - if os.path.exists(os.path.join(model_path, 'config.json')): - with open(os.path.join(model_path, 'config.json')) as fIn: - config = json.load(fIn) - if config['__version__'] > __version__: - logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. 
In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__)) - - with open(os.path.join(model_path, 'modules.json')) as fIn: - contained_modules = json.load(fIn) - - modules = OrderedDict() - for module_config in contained_modules: - module_class = import_from_string(module_config['type']) - module = module_class.load(os.path.join(model_path, module_config['path'])) - modules[module_config['name']] = module - - + transformer_model = Transformer(model_name_or_path, max_seq_length=80) + pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True, + pooling_mode_cls_token=False, + pooling_mode_max_tokens=False) + modules = [transformer_model, pooling_model] + + # if model_name_or_path is not None and model_name_or_path != "": + # logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path)) + # model_path = model_name_or_path + # + # if not os.path.isdir(model_path) and not model_path.startswith('http://') and not model_path.startswith('https://'): + # logger.info("Did not find folder {}".format(model_path)) + # + # if '\\' in model_path or model_path.count('/') > 1: + # raise AttributeError("Path {} not found".format(model_path)) + # + # model_path = __DOWNLOAD_SERVER__ + model_path + '.zip' + # logger.info("Search model on server: {}".format(model_path)) + # + # if model_path.startswith('http://') or model_path.startswith('https://'): + # model_url = model_path + # folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250][0:-4] #remove .zip file end + # + # cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME') + # if cache_folder is None: + # try: + # from torch.hub import _get_torch_home + # torch_cache_home = _get_torch_home() + # except ImportError: + # torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) + # + # cache_folder = 
os.path.join(torch_cache_home, 'sentence_transformers') + # + # model_path = os.path.join(cache_folder, folder_name) + # + # if not os.path.exists(model_path) or not os.listdir(model_path): + # if os.path.exists(model_path): + # os.remove(model_path) + # + # model_url = model_url.rstrip("/") + # logger.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path)) + # + # model_path_tmp = model_path.rstrip("/").rstrip("\\")+"_part" + # try: + # zip_save_path = os.path.join(model_path_tmp, 'model.zip') + # http_get(model_url, zip_save_path) + # with ZipFile(zip_save_path, 'r') as zip: + # zip.extractall(model_path_tmp) + # os.remove(zip_save_path) + # os.rename(model_path_tmp, model_path) + # except requests.exceptions.HTTPError as e: + # shutil.rmtree(model_path_tmp) + # if e.response.status_code == 429: + # raise Exception("Too many requests were detected from this IP for the model {}. Please contact info@nils-reimers.de for more information.".format(model_name_or_path)) + # + # if e.response.status_code == 404: + # logger.warning('SentenceTransformer-Model {} not found. 
Try to create it from scratch'.format(model_url)) + # logger.warning('Try to create Transformer Model {} with mean pooling'.format(model_name_or_path)) + # + # save_model_to = model_path + # model_path = None + # transformer_model = Transformer(model_name_or_path) + # pooling_model = Pooling(transformer_model.get_word_embedding_dimension()) + # modules = [transformer_model, pooling_model] + # else: + # raise e + # except Exception as e: + # shutil.rmtree(model_path) + # raise e + # + # + # # #### Load from disk + # if model_path is not None: + # logger.info("Load SentenceTransformer from folder: {}".format(model_path)) + # + # if os.path.exists(os.path.join(model_path, 'config.json')): + # with open(os.path.join(model_path, 'config.json')) as fIn: + # config = json.load(fIn) + # if config['__version__'] > __version__: + # logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(config['__version__'], __version__)) + # + # with open(os.path.join(model_path, 'modules.json')) as fIn: + # contained_modules = json.load(fIn) + # + # modules = OrderedDict() + # for module_config in contained_modules: + # module_class = import_from_string(module_config['type']) + # module = module_class.load(os.path.join(model_path, module_config['path'])) + # modules[module_config['name']] = module + # + # if modules is not None and not isinstance(modules, OrderedDict): modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)]) @@ -387,10 +393,10 @@ class SiameseTransQuestModel(nn.Sequential): for idx, name in enumerate(self._modules): module = self._modules[name] - model_path = os.path.join(path, str(idx)+"_"+type(module).__name__) - os.makedirs(model_path, exist_ok=True) - module.save(model_path) - contained_modules.append({'idx': idx, 'name': name, 'path': os.path.basename(model_path), 'type': 
type(module).__module__}) + # model_path = os.path.join(path, str(idx)+"_"+type(module).__name__) + os.makedirs(path, exist_ok=True) + module.save(path) + contained_modules.append({'idx': idx, 'name': name, 'path': os.path.basename(path), 'type': type(module).__module__}) with open(os.path.join(path, 'modules.json'), 'w') as fOut: json.dump(contained_modules, fOut, indent=2) |