diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-02-12 15:10:37 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-02-12 15:10:37 +0300 |
commit | 5f531cb8792862fb809cfcc7529e7438b80f6e5f (patch) | |
tree | a6b7414bd788f57c6fcffcb1cf3604a1c408d4e8 | |
parent | 2b2f832085cb55e4ae4a09eca243e03a5c255c68 (diff) |
055: Adding word level examples
21 files changed, 93 insertions, 311 deletions
diff --git a/examples/app/__init__.py b/examples/app/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/examples/app/__init__.py +++ /dev/null diff --git a/examples/app/monotransquest_app_test.py b/examples/app/monotransquest_app_test.py deleted file mode 100644 index 433bfce..0000000 --- a/examples/app/monotransquest_app_test.py +++ /dev/null @@ -1,17 +0,0 @@ -import logging - -from transquest.app.monotransquest_app import MonoTransQuestApp - - -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) - -test_sentences = [ - [ - "Jocurile de oferă noi provocări pentru IA în domeniul teoriei jocurilor.", - "Games provide new challenges for IA in the area of gambling theory" - ] -] - -app = MonoTransQuestApp("monotransquest-da-ro_en", use_cuda=False, force_download=True) -print(app.predict_quality(test_sentences)) diff --git a/transquest/algo/sentence_level/monotransquest/models/bert_model.py b/transquest/algo/sentence_level/monotransquest/models/bert_model.py index ca20dd6..b8ec4a4 100755 --- a/transquest/algo/sentence_level/monotransquest/models/bert_model.py +++ b/transquest/algo/sentence_level/monotransquest/models/bert_model.py @@ -1,4 +1,3 @@ -import torch import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel @@ -44,14 +43,14 @@ class BertForSequenceClassification(BertPreTrainedModel): self.init_weights() def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, ): outputs = self.bert( diff --git a/transquest/algo/sentence_level/monotransquest/models/distilbert_model.py b/transquest/algo/sentence_level/monotransquest/models/distilbert_model.py index 4de0b38..2710618 100755 --- a/transquest/algo/sentence_level/monotransquest/models/distilbert_model.py +++ b/transquest/algo/sentence_level/monotransquest/models/distilbert_model.py @@ -1,5 +1,5 @@ import torch.nn as nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.models.distilbert.modeling_distilbert import DistilBertModel, DistilBertPreTrainedModel @@ -44,7 +44,8 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): self.init_weights() def forward( - self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, class_weights=None, + self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, + class_weights=None, ): distilbert_output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask) hidden_state = distilbert_output[0] # (bs, seq_len, dim) diff --git a/transquest/algo/sentence_level/monotransquest/models/roberta_model.py b/transquest/algo/sentence_level/monotransquest/models/roberta_model.py index 6d10790..41b190a 100755 --- a/transquest/algo/sentence_level/monotransquest/models/roberta_model.py +++ b/transquest/algo/sentence_level/monotransquest/models/roberta_model.py @@ -1,6 +1,5 @@ -import torch -import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss +from transformers import BertPreTrainedModel from transformers.models.roberta.modeling_roberta import ( ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, RobertaClassificationHead, @@ -8,8 +7,6 @@ from transformers.models.roberta.modeling_roberta import ( RobertaModel, ) -from transformers import BertPreTrainedModel - class RobertaForSequenceClassification(BertPreTrainedModel): r""" @@ -51,14 +48,14 @@ class RobertaForSequenceClassification(BertPreTrainedModel): self.weight = weight def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, ): outputs = self.roberta( input_ids, diff --git a/transquest/algo/sentence_level/monotransquest/models/xlm_model.py b/transquest/algo/sentence_level/monotransquest/models/xlm_model.py index 6bb9fca..c7c8380 100755 --- a/transquest/algo/sentence_level/monotransquest/models/xlm_model.py +++ b/transquest/algo/sentence_level/monotransquest/models/xlm_model.py @@ -1,5 +1,3 @@ -import torch -import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss from transformers.models.xlm.modeling_xlm import SequenceSummary, XLMModel, XLMPreTrainedModel @@ -43,17 +41,17 @@ class XLMForSequenceClassification(XLMPreTrainedModel): self.init_weights() def forward( - self, - input_ids=None, - attention_mask=None, - langs=None, - token_type_ids=None, - position_ids=None, - lengths=None, - cache=None, - head_mask=None, - inputs_embeds=None, - labels=None, + self, + input_ids=None, + attention_mask=None, + langs=None, + token_type_ids=None, + position_ids=None, + lengths=None, + cache=None, + head_mask=None, + inputs_embeds=None, + labels=None, ): transformer_outputs = self.transformer( input_ids, diff --git a/transquest/algo/sentence_level/siamesetransquest/datasets/sentence_label_dataset.py b/transquest/algo/sentence_level/siamesetransquest/datasets/sentence_label_dataset.py index 09cb6aa..66b7ba7 100644 --- a/transquest/algo/sentence_level/siamesetransquest/datasets/sentence_label_dataset.py +++ b/transquest/algo/sentence_level/siamesetransquest/datasets/sentence_label_dataset.py @@ -7,8 +7,8 @@ import torch from torch.utils.data import Dataset from tqdm import tqdm -from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel class SentenceLabelDataset(Dataset): @@ -84,7 +84,7 @@ class SentenceLabelDataset(Dataset): if hasattr(model, 'max_seq_length') and model.max_seq_length is not None and model.max_seq_length > 0 and len( - tokenized_text) >= model.max_seq_length: + tokenized_text) >= model.max_seq_length: too_long += 1 if example.label in label_sent_mapping: label_sent_mapping[example.label].append(ex_index) diff --git a/transquest/algo/sentence_level/siamesetransquest/datasets/sentences_dataset.py b/transquest/algo/sentence_level/siamesetransquest/datasets/sentences_dataset.py index 4c4b693..d964ea5 100644 --- a/transquest/algo/sentence_level/siamesetransquest/datasets/sentences_dataset.py +++ b/transquest/algo/sentence_level/siamesetransquest/datasets/sentences_dataset.py @@ -5,8 +5,8 @@ import torch from torch.utils.data import Dataset from tqdm import tqdm -from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel from transquest.algo.sentence_level.siamesetransquest.readers.input_example import InputExample +from transquest.algo.sentence_level.siamesetransquest.run_model import SiameseTransQuestModel class SentencesDataset(Dataset): @@ -23,7 +23,7 @@ class SentencesDataset(Dataset): """ if show_progress_bar is None: show_progress_bar = ( - logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) + logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) self.show_progress_bar = show_progress_bar self.convert_input_examples(examples, model) diff --git a/transquest/algo/sentence_level/siamesetransquest/evaluation/binary_embedding_similarity_evaluator.py b/transquest/algo/sentence_level/siamesetransquest/evaluation/binary_embedding_similarity_evaluator.py index cb773fe..770f1e0 100644 --- a/transquest/algo/sentence_level/siamesetransquest/evaluation/binary_embedding_similarity_evaluator.py +++ b/transquest/algo/sentence_level/siamesetransquest/evaluation/binary_embedding_similarity_evaluator.py @@ -8,8 +8,8 @@ from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_d from torch.utils.data import DataLoader from tqdm import tqdm -from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.evaluation import SentenceEvaluator +from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.util import batch_to_device diff --git a/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py b/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py index 8d6d7d6..b63a95b 100644 --- a/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py +++ b/transquest/algo/sentence_level/siamesetransquest/evaluation/embedding_similarity_evaluator.py @@ -9,8 +9,8 @@ from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_d from torch.utils.data import DataLoader from tqdm import tqdm -from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.evaluation import SentenceEvaluator +from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.util import batch_to_device @@ -44,7 +44,7 @@ class EmbeddingSimilarityEvaluator(SentenceEvaluator): if show_progress_bar is None: show_progress_bar = ( - logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) + logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) self.show_progress_bar = show_progress_bar self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/transquest/algo/sentence_level/siamesetransquest/evaluation/translation_evaluator.py b/transquest/algo/sentence_level/siamesetransquest/evaluation/translation_evaluator.py index 36382f5..09f2b49 100644 --- a/transquest/algo/sentence_level/siamesetransquest/evaluation/translation_evaluator.py +++ b/transquest/algo/sentence_level/siamesetransquest/evaluation/translation_evaluator.py @@ -8,8 +8,8 @@ import torch from torch.utils.data import DataLoader from tqdm import tqdm -from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.evaluation import SentenceEvaluator +from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.util import batch_to_device @@ -40,7 +40,7 @@ class TranslationEvaluator(SentenceEvaluator): if show_progress_bar is None: show_progress_bar = ( - logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) + logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) self.show_progress_bar = show_progress_bar self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/transquest/algo/sentence_level/siamesetransquest/evaluation/triplet_evaluator.py b/transquest/algo/sentence_level/siamesetransquest/evaluation/triplet_evaluator.py index 2be2c69..95590f1 100644 --- a/transquest/algo/sentence_level/siamesetransquest/evaluation/triplet_evaluator.py +++ b/transquest/algo/sentence_level/siamesetransquest/evaluation/triplet_evaluator.py @@ -7,8 +7,8 @@ from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_d from torch.utils.data import DataLoader from tqdm import tqdm -from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction from transquest.algo.sentence_level.siamesetransquest.evaluation import SentenceEvaluator +from transquest.algo.sentence_level.siamesetransquest.evaluation import SimilarityFunction class TripletEvaluator(SentenceEvaluator): diff --git a/transquest/algo/sentence_level/siamesetransquest/run_model.py b/transquest/algo/sentence_level/siamesetransquest/run_model.py index d9bac73..d485465 100644 --- a/transquest/algo/sentence_level/siamesetransquest/run_model.py +++ b/transquest/algo/sentence_level/siamesetransquest/run_model.py @@ -116,7 +116,7 @@ class SiameseTransQuestModel(nn.Sequential): self.eval() if show_progress_bar is None: show_progress_bar = ( - logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) + logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG) all_embeddings = [] length_sorted_idx = np.argsort([len(sen) for sen in sentences]) diff --git a/transquest/algo/word_level/microtransquest/format.py b/transquest/algo/word_level/microtransquest/format.py index 1543dd6..607c2c6 100644 --- a/transquest/algo/word_level/microtransquest/format.py +++ b/transquest/algo/word_level/microtransquest/format.py @@ -9,7 +9,8 @@ def prepare_data(raw_df, args): sentence_id = 0 data = [] - for source_sentence, source_tag_line, target_sentence, target_tag_lind in zip(source_sentences, source_tags, target_sentences, target_tags): + for source_sentence, source_tag_line, target_sentence, target_tag_lind in zip(source_sentences, source_tags, + target_sentences, target_tags): for word, tag in zip(source_sentence.split(), source_tag_line.split()): data.append([sentence_id, word, tag]) @@ -18,7 +19,6 @@ def prepare_data(raw_df, args): target_words = target_sentence.split() target_tags = target_tag_lind.split() - data.append([sentence_id, args["tag"], target_tags.pop(0)]) for word in target_words: @@ -31,7 +31,6 @@ def prepare_data(raw_df, args): def prepare_testdata(raw_df, args): - source_sentences = raw_df[args["source_column"]].tolist() target_sentences = raw_df[args["target_column"]].tolist() @@ -77,7 +76,8 @@ def post_process(predicted_sentences, test_sentences, args): assert len(source_tags) == len(source_sentence.split()) if len(target_sentence.split()) > len(target_tags): - target_tags = target_tags + [args["default_quality"] for x in range(len(target_sentence.split()) - len(target_tags))] + target_tags = target_tags + [args["default_quality"] for x in + range(len(target_sentence.split()) - len(target_tags))] assert len(target_tags) == len(target_sentence.split()) sources_tags.append(source_tags) @@ -85,7 +85,6 @@ def post_process(predicted_sentences, test_sentences, args): return sources_tags, targets_tags - # def post_process(predicted_sentences, test_sentences): # sources_tags = [] # targets_tags = [] @@ -108,7 +107,3 @@ def post_process(predicted_sentences, test_sentences, args): # targets_tags.append(target_tags) # # return sources_tags, targets_tags - - - - diff --git a/transquest/algo/word_level/microtransquest/model_args.py b/transquest/algo/word_level/microtransquest/model_args.py index 3f81826..4a5ca66 100644 --- a/transquest/algo/word_level/microtransquest/model_args.py +++ b/transquest/algo/word_level/microtransquest/model_args.py @@ -19,4 +19,3 @@ class MicroTransQuestArgs(TransQuestArgs): add_tag: bool = False tag: str = "<gap>" default_quality: str = "OK" - diff --git a/transquest/algo/word_level/microtransquest/run_model.py b/transquest/algo/word_level/microtransquest/run_model.py index 57c5612..d75b184 100755 --- a/transquest/algo/word_level/microtransquest/run_model.py +++ b/transquest/algo/word_level/microtransquest/run_model.py @@ -1,72 +1,48 @@ from __future__ import absolute_import, division, print_function import glob -import json import logging import math import os import random import shutil +import tempfile import warnings from dataclasses import asdict -from multiprocessing import cpu_count -import tempfile from pathlib import Path import numpy as np import pandas as pd import torch -from scipy.stats import pearsonr from seqeval.metrics import classification_report, f1_score, precision_score, recall_score from tensorboardX import SummaryWriter from torch.nn import CrossEntropyLoss from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset from tqdm.auto import tqdm, trange -from transformers.optimization import ( - get_constant_schedule, - get_constant_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, -) -from transformers.optimization import AdamW, Adafactor from transformers import ( - WEIGHTS_NAME, - AutoConfig, - AutoModelForTokenClassification, - AutoTokenizer, BertConfig, BertForTokenClassification, BertTokenizer, - BertweetTokenizer, - CamembertConfig, - CamembertForTokenClassification, - CamembertTokenizer, DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer, - ElectraConfig, - ElectraForTokenClassification, - ElectraTokenizer, - LongformerConfig, - LongformerForTokenClassification, - LongformerTokenizer, - MobileBertConfig, - MobileBertForTokenClassification, - MobileBertTokenizer, RobertaConfig, RobertaForTokenClassification, RobertaTokenizer, XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer, - LayoutLMConfig, - LayoutLMForTokenClassification, - LayoutLMTokenizer, ) -from wandb import config from transformers.convert_graph_to_onnx import convert, quantize +from transformers.optimization import AdamW, Adafactor +from transformers.optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, +) from transquest.algo.word_level.microtransquest.model_args import MicroTransQuestArgs from transquest.algo.word_level.microtransquest.utils import sweep_config_to_sweep_values, InputExample, \ @@ -84,15 +60,15 @@ logger = logging.getLogger(__name__) class MicroTransQuestModel: def __init__( - self, - model_type, - model_name, - labels=None, - args=None, - use_cuda=True, - cuda_device=-1, - onnx_execution_provider=None, - **kwargs, + self, + model_type, + model_name, + labels=None, + args=None, + use_cuda=True, + cuda_device=-1, + onnx_execution_provider=None, + **kwargs, ): """ Initializes a NERModel @@ -251,7 +227,7 @@ class MicroTransQuestModel: self.args.wandb_project = None def train_model( - self, train_data, output_dir=None, show_running_loss=True, args=None, eval_data=None, verbose=True, **kwargs + self, train_data, output_dir=None, show_running_loss=True, args=None, eval_data=None, verbose=True, **kwargs ): """ Trains the model using 'train_data' @@ -483,7 +459,7 @@ class MicroTransQuestModel: global_step = int(checkpoint_suffix) epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( - len(train_dataloader) // args.gradient_accumulation_steps + len(train_dataloader) // args.gradient_accumulation_steps ) logger.info(" Continuing training from checkpoint, will skip to saved global_step") @@ -596,8 +572,8 @@ class MicroTransQuestModel: self.save_model(output_dir_current, optimizer, scheduler, model=model) if args.evaluate_during_training and ( - args.evaluate_during_training_steps > 0 - and global_step % args.evaluate_during_training_steps == 0 + args.evaluate_during_training_steps > 0 + and global_step % args.evaluate_during_training_steps == 0 ): output_dir_current = os.path.join(output_dir, "checkpoint-{}".format(global_step)) @@ -1031,7 +1007,7 @@ class MicroTransQuestModel: else: preds = np.append(preds, output[0], axis=0) out_input_ids = np.append(out_input_ids, inputs_onnx["input_ids"], axis=0) - out_attention_mask = np.append(out_attention_mask, inputs_onnx["attention_mask"], axis=0,) + out_attention_mask = np.append(out_attention_mask, inputs_onnx["attention_mask"], axis=0, ) out_label_ids = np.zeros_like(out_input_ids) for index in range(len(out_label_ids)): out_label_ids[index][0] = -100 @@ -1209,8 +1185,8 @@ class MicroTransQuestModel: os.makedirs(self.args.cache_dir, exist_ok=True) if os.path.exists(cached_features_file) and ( - (not args.reprocess_input_data and not no_cache) - or (mode == "dev" and args.use_cached_eval_features and not no_cache) + (not args.reprocess_input_data and not no_cache) + or (mode == "dev" and args.use_cached_eval_features and not no_cache) ): features = torch.load(cached_features_file) logger.info(f" Features loaded from cache at {cached_features_file}") diff --git a/transquest/algo/word_level/microtransquest/utils.py b/transquest/algo/word_level/microtransquest/utils.py index 774079b..372a18a 100755 --- a/transquest/algo/word_level/microtransquest/utils.py +++ b/transquest/algo/word_level/microtransquest/utils.py @@ -18,12 +18,9 @@ from __future__ import absolute_import, division, print_function import linecache -import logging -import os from io import open from multiprocessing import Pool, cpu_count -import pandas as pd import torch from torch.functional import split from torch.nn import CrossEntropyLoss @@ -161,7 +158,7 @@ def get_examples_from_df(data, bbox=False): ] else: return [ - InputExample(guid=sentence_id, words=sentence_df["words"].tolist(), labels=sentence_df["labels"].tolist(),) + InputExample(guid=sentence_id, words=sentence_df["words"].tolist(), labels=sentence_df["labels"].tolist(), ) for sentence_id, sentence_df in data.groupby(["sentence_id"]) ] @@ -290,29 +287,29 @@ def convert_example_to_feature(example_row): input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids, bboxes=bboxes ) else: - return InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids,) + return InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids, ) def convert_examples_to_features( - examples, - label_list, - max_seq_length, - tokenizer, - cls_token_at_end=False, - cls_token="[CLS]", - cls_token_segment_id=1, - sep_token="[SEP]", - sep_token_extra=False, - pad_on_left=False, - pad_token=0, - pad_token_segment_id=0, - pad_token_label_id=-1, - sequence_a_segment_id=0, - mask_padding_with_zero=True, - process_count=cpu_count() - 2, - chunksize=500, - silent=False, - use_multiprocessing=True, + examples, + label_list, + max_seq_length, + tokenizer, + cls_token_at_end=False, + cls_token="[CLS]", + cls_token_segment_id=1, + sep_token="[SEP]", + sep_token_extra=False, + pad_on_left=False, + pad_token=0, + pad_token_segment_id=0, + pad_token_label_id=-1, + sequence_a_segment_id=0, + mask_padding_with_zero=True, + process_count=cpu_count() - 2, + chunksize=500, + silent=False, + use_multiprocessing=True, ): """ Loads a data file into a list of `InputBatch`s `cls_token_at_end` define the location of the CLS token: diff --git a/transquest/app/__init__.py b/transquest/app/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/transquest/app/__init__.py +++ /dev/null diff --git a/transquest/app/monotransquest_app.py b/transquest/app/monotransquest_app.py deleted file mode 100644 index 8187259..0000000 --- a/transquest/app/monotransquest_app.py +++ /dev/null @@ -1,59 +0,0 @@ -import logging -import os - -from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel -from transquest.app.util.model_downloader import GoogleDriveDownloader as gdd - -logger = logging.getLogger(__name__) - - -class MonoTransQuestApp: - def __init__(self, model_name_or_path, model_type=None, use_cuda=True, force_download=False, cuda_device=-1): - - self.model_name_or_path = model_name_or_path - self.model_type = model_type - self.use_cuda = use_cuda - self.cuda_device = cuda_device - - - MODEL_CONFIG = { - "monotransquest-da-si_en": ("xlmroberta", "1-UXvna_RGnb6_TTRr4vSGCqA5yl0SYn9", 3.8), - "monotransquest-da-ro_en": ("xlmroberta", "1-aeDbR_ftqsTslFJbNybebj5MAhPfIw8", 3.8) - } - - if model_name_or_path in MODEL_CONFIG: - self.trained_model_type, self.drive_id, self.size = MODEL_CONFIG[model_name_or_path] - - try: - from torch.hub import _get_torch_home - torch_cache_home = _get_torch_home() - except ImportError: - torch_cache_home = os.path.expanduser( - os.getenv('TORCH_HOME', os.path.join( - os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) - default_cache_path = os.path.join(torch_cache_home, 'transquest') - self.model_path = os.path.join(default_cache_path, self.model_name_or_path) - if force_download or (not os.path.exists(self.model_path) or not os.listdir(self.model_path)): - logger.info( - "Downloading a MonoTransQuest model and saving it at {}".format(self.model_path)) - - gdd.download_file_from_google_drive(file_id=self.drive_id, - dest_path=os.path.join(self.model_path, "model.zip"), - showsize=True, unzip=True, overwrite=True) - - self.model = MonoTransQuestModel(self.trained_model_type, self.model_path, use_cuda=self.use_cuda, - cuda_device=self.cuda_device) - - else: - self.model = MonoTransQuestModel(model_type, self.model_name_or_path, use_cuda=self.use_cuda, - cuda_device=self.cuda_device) - - @staticmethod - def _download(drive_id, model_name): - gdd.download_file_from_google_drive(file_id=drive_id, - dest_path=os.path.join(".transquest", model_name, "model.zip"), - unzip=True) - - def predict_quality(self, test_sentence_pairs): - predictions, raw_outputs = self.model.predict(test_sentence_pairs) - return predictions diff --git a/transquest/app/util/__init__.py b/transquest/app/util/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/transquest/app/util/__init__.py +++ /dev/null diff --git a/transquest/app/util/model_downloader.py b/transquest/app/util/model_downloader.py deleted file mode 100644 index 9999db2..0000000 --- a/transquest/app/util/model_downloader.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import print_function -import requests -import zipfile -import warnings -from sys import stdout -from os import makedirs -from os.path import dirname -from os.path import exists - - -class GoogleDriveDownloader: - """ - Minimal class to download shared files from Google Drive. - """ - - CHUNK_SIZE = 32768 - DOWNLOAD_URL = 'https://docs.google.com/uc?export=download' - - @staticmethod - def download_file_from_google_drive(file_id, dest_path, overwrite=False, unzip=False, showsize=False): - """ - Downloads a shared file from google drive into a given folder. - Optionally unzips it. - Parameters - ---------- - file_id: str - the file identifier. - You can obtain it from the sharable link. - dest_path: str - the destination where to save the downloaded file. - Must be a path (for example: './downloaded_file.txt') - overwrite: bool - optional, if True forces re-download and overwrite. - unzip: bool - optional, if True unzips a file. - If the file is not a zip file, ignores it. - showsize: bool - optional, if True print the current download size. - Returns - ------- - None - """ - - destination_directory = dirname(dest_path) - if not exists(destination_directory): - makedirs(destination_directory) - - if not exists(dest_path) or overwrite: - - session = requests.Session() - - print('Downloading {} into {}... '.format(file_id, dest_path), end='') - stdout.flush() - - response = session.get(GoogleDriveDownloader.DOWNLOAD_URL, params={'id': file_id}, stream=True) - - token = GoogleDriveDownloader._get_confirm_token(response) - if token: - params = {'id': file_id, 'confirm': token} - response = session.get(GoogleDriveDownloader.DOWNLOAD_URL, params=params, stream=True) - - if showsize: - print() # Skip to the next line - - current_download_size = [0] - GoogleDriveDownloader._save_response_content(response, dest_path, showsize, current_download_size) - print('Done.') - - if unzip: - try: - print('Unzipping...', end='') - stdout.flush() - with zipfile.ZipFile(dest_path, 'r') as z: - z.extractall(destination_directory) - print('Done.') - except zipfile.BadZipfile: - warnings.warn('Ignoring `unzip` since "{}" does not look like a valid zip file'.format(file_id)) - - @staticmethod - def _get_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith('download_warning'): - return value - return None - - @staticmethod - def _save_response_content(response, destination, showsize, current_size): - with open(destination, 'wb') as f: - for chunk in response.iter_content(GoogleDriveDownloader.CHUNK_SIZE): - if chunk: # filter out keep-alive new chunks - f.write(chunk) - if showsize: - print('\r' + GoogleDriveDownloader.sizeof_fmt(current_size[0]), end=' ') - stdout.flush() - current_size[0] += GoogleDriveDownloader.CHUNK_SIZE - - # From https://stackoverflow.com/questions/1094841/reusable-library-to-get-human-readable-version-of-file-size - @staticmethod - def sizeof_fmt(num, suffix='B'): - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: - if abs(num) < 1024.0: - return '{:.1f} {}{}'.format(num, unit, suffix) - num /= 1024.0 - return '{:.1f} {}{}'.format(num, 'Yi', suffix)
\ No newline at end of file |