Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'transquest/algo/sentence_level/siamesetransquest/models/Transformer.py')
-rw-r--r-- transquest/algo/sentence_level/siamesetransquest/models/Transformer.py | 35
1 files changed, 17 insertions, 18 deletions
diff --git a/transquest/algo/sentence_level/siamesetransquest/models/Transformer.py b/transquest/algo/sentence_level/siamesetransquest/models/Transformer.py
index f17d382..aac9aa0 100644
--- a/transquest/algo/sentence_level/siamesetransquest/models/Transformer.py
+++ b/transquest/algo/sentence_level/siamesetransquest/models/Transformer.py
@@ -1,8 +1,9 @@
-from torch import nn
-from transformers import AutoModel, AutoTokenizer, AutoConfig
import json
-from typing import List, Dict, Optional, Union, Tuple
import os
+from typing import List, Dict, Optional, Union, Tuple
+
+from torch import nn
+from transformers import AutoModel, AutoTokenizer, AutoConfig
class Transformer(nn.Module):
@@ -16,6 +17,7 @@ class Transformer(nn.Module):
:param tokenizer_args: Arguments (key, value pairs) passed to the Huggingface Tokenizer model
:param do_lower_case: If true, lowercases the input (independent of whether the model is cased or not)
"""
+
def __init__(self, model_name_or_path: str, max_seq_length: Optional[int] = None,
model_args: Dict = {}, cache_dir: Optional[str] = None,
tokenizer_args: Dict = {}, do_lower_case: bool = False):
@@ -38,11 +40,12 @@ class Transformer(nn.Module):
output_tokens = output_states[0]
cls_tokens = output_tokens[:, 0, :] # CLS token is first token
- features.update({'token_embeddings': output_tokens, 'cls_token_embeddings': cls_tokens, 'attention_mask': features['attention_mask']})
+ features.update({'token_embeddings': output_tokens, 'cls_token_embeddings': cls_tokens,
+ 'attention_mask': features['attention_mask']})
if self.auto_model.config.output_hidden_states:
all_layer_idx = 2
- if len(output_states) < 3: #Some models only output last_hidden_states and all_hidden_states
+ if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states
all_layer_idx = 1
hidden_states = output_states[all_layer_idx]
@@ -75,18 +78,17 @@ class Transformer(nn.Module):
batch2.append(text_tuple[1])
to_tokenize = [batch1, batch2]
- #strip
+ # strip
to_tokenize = [[s.strip() for s in col] for col in to_tokenize]
- #Lowercase
+ # Lowercase
if self.do_lower_case:
to_tokenize = [[s.lower() for s in col] for col in to_tokenize]
-
- output.update(self.tokenizer(*to_tokenize, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_seq_length))
+ output.update(self.tokenizer(*to_tokenize, padding=True, truncation='longest_first', return_tensors="pt",
+ max_length=self.max_seq_length))
return output
-
def get_config_dict(self):
return {key: self.__dict__[key] for key in self.config_keys}
@@ -99,8 +101,11 @@ class Transformer(nn.Module):
@staticmethod
def load(input_path: str):
- #Old classes used other config names than 'sentence_bert_config.json'
- for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json', 'sentence_distilbert_config.json', 'sentence_camembert_config.json', 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json', 'sentence_xlnet_config.json']:
+ # Old classes used other config names than 'sentence_bert_config.json'
+ for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json',
+ 'sentence_distilbert_config.json', 'sentence_camembert_config.json',
+ 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json',
+ 'sentence_xlnet_config.json']:
sbert_config_path = os.path.join(input_path, config_name)
if os.path.exists(sbert_config_path):
break
@@ -108,9 +113,3 @@ class Transformer(nn.Module):
with open(sbert_config_path) as fIn:
config = json.load(fIn)
return Transformer(model_name_or_path=input_path, **config)
-
-
-
-
-
-