Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-01-23 02:40:24 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-01-23 02:40:24 +0300
commit92b03f3bb34023a82f174766e5a6ad4a4a1db6ff (patch)
tree98dbce596ee3dde5258d5089690413a565749209
parent85bec64c0b4c5fd71d6fab0c4b1a62a41a5eb509 (diff)
055: Adding word level examples
-rw-r--r--examples/word_level/wmt_2018/en_cs/microtransquest.py117
-rw-r--r--examples/word_level/wmt_2018/en_cs/microtransquest_config.py98
2 files changed, 215 insertions, 0 deletions
diff --git a/examples/word_level/wmt_2018/en_cs/microtransquest.py b/examples/word_level/wmt_2018/en_cs/microtransquest.py
new file mode 100644
index 0000000..c000e03
--- /dev/null
+++ b/examples/word_level/wmt_2018/en_cs/microtransquest.py
@@ -0,0 +1,117 @@
+import os
+import shutil
+
+from sklearn.model_selection import train_test_split
+
+from examples.word_level.common.util import reader
+from examples.word_level.wmt_2018.en_cs.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
+ TRAIN_SOURCE_TAGS_FILE, \
+ TRAIN_TARGET_FILE, \
+ TRAIN_TARGET_TAGS_FLE, MODEL_TYPE, MODEL_NAME, microtransquest_config, TEST_PATH, TEST_SOURCE_FILE, \
+ TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE
+from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
+from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
+
# Ensure the scratch directory for prediction files exists before anything is written.
if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)

# Load the WMT 2018 en-cs word-level QE data. The training set carries gold
# source/target tag files; the test set has only the source and MT text.
raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
                      TRAIN_TARGET_TAGS_FLE)
raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)

# Flatten the test frame into per-sentence prediction inputs.
# NOTE(review): the writer loop at the bottom splits these on "[SEP]", so this
# presumably joins source and target with that token -- confirm against
# transquest.algo.word_level.microtransquest.format.prepare_testdata.
test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)
+
fold_sources_tags = []  # one entry per fold: that fold's source-side tag predictions
fold_targets_tags = []  # one entry per fold: that fold's target-side tag predictions

# Train n_fold independent models and collect each fold's test-set
# predictions; they are combined by majority vote further down.
for i in range(microtransquest_config["n_fold"]):

    # Start each fold from a clean slate: remove checkpoints written by the
    # previous fold so 'output_dir' only ever holds the current fold's model.
    if os.path.exists(microtransquest_config['output_dir']) and os.path.isdir(microtransquest_config['output_dir']):
        shutil.rmtree(microtransquest_config['output_dir'])

    if microtransquest_config["evaluate_during_training"]:
        # Hold out 10% for evaluation/early stopping; the split varies per
        # fold via random_state=SEED * i (fold 0 uses random_state=0).
        raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
        train_df = prepare_data(raw_train, args=microtransquest_config)
        eval_df = prepare_data(raw_eval, args=microtransquest_config)
        tags = train_df['labels'].unique().tolist()
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
        model.train_model(train_df, eval_df=eval_df)
        # Reload the checkpoint that scored best on the eval split.
        model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
                                     args=microtransquest_config)

    else:
        # No eval split: train on everything and keep the final model.
        train_df = prepare_data(raw_train_df, args=microtransquest_config)
        tags = train_df['labels'].unique().tolist()
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
        model.train_model(train_df)

    # Predict tags for the test sentences and split them into source-side and
    # target-side tags. raw_predictions is unused here.
    predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
    sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
    fold_sources_tags.append(sources_tags)
    fold_targets_tags.append(targets_tags)
+
# Majority-vote the source-side tags across folds.
#
# fold_sources_tags[f][s] is fold f's prediction for test sentence s; for each
# word position we keep the tag most folds agree on. Ties are broken
# deterministically in favour of the earliest fold. (The original
# max(set(votes), key=votes.count) tie-broke on set iteration order, which is
# hash-randomised for strings across interpreter runs, so ensemble output was
# not reproducible run-to-run.)
source_predictions = []
for sentence_id in range(len(test_sentences)):
    # Gather every fold's prediction for this sentence, then vote per word.
    fold_votes = [fold_prediction[sentence_id] for fold_prediction in fold_sources_tags]
    majority_prediction = [
        max(word_votes, key=word_votes.count)  # first-seen tag wins ties
        for word_votes in zip(*fold_votes)
    ]
    source_predictions.append(majority_prediction)
+
# Majority-vote the target-side tags across folds (mirror of the source vote).
#
# fold_targets_tags[f][s] is fold f's prediction for test sentence s. Ties are
# broken deterministically in favour of the earliest fold; the original
# max(set(votes), key=votes.count) depended on hash-randomised set iteration
# order and was therefore not reproducible across runs.
target_predictions = []
for sentence_id in range(len(test_sentences)):
    # Gather every fold's prediction for this sentence, then vote per word.
    fold_votes = [fold_prediction[sentence_id] for fold_prediction in fold_targets_tags]
    majority_prediction = [
        max(word_votes, key=word_votes.count)  # first-seen tag wins ties
        for word_votes in zip(*fold_votes)
    ]
    target_predictions.append(majority_prediction)
+
# Raw source/target text, needed to pair each word with its predicted tag.
test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()

# Emit source-side predictions in the WMT QE submission layout:
# <system>\t<side>\t<sentence id>\t<word id>\t<word>\t<tag>, one word per line.
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
    for sentence_id, (test_source_sentence, source_prediction) in enumerate(
            zip(test_source_sentences, source_predictions)):
        words = test_source_sentence.split()
        # NOTE(review): the majority-vote loop above builds source_prediction
        # as a *list* of per-word tags, and list has no .split() -- this line
        # looks like it raises AttributeError unless post_process actually
        # returns whole-sentence strings. Verify post_process's return type.
        word_predictions = source_prediction.split()
        for word_id, (word, word_prediction) in enumerate(zip(words, word_predictions)):
            f.write("MicroTransQuest" + "\t" + "source" + "\t" +
                    str(sentence_id) + "\t" + str(word_id) + "\t"
                    + word + "\t" + word_prediction + '\n')
+
# Emit target-side predictions, routing gap placeholders to the gaps file and
# real MT words to the mt file; each stream keeps its own word counter.
with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
        os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
    for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
        # Target text is assumed to follow a "[SEP]" separator inserted by
        # prepare_testdata -- TODO confirm the separator against format.py.
        target_sentence = test_sentence.split("[SEP]")[1]
        words = target_sentence.split()
        # NOTE(review): the majority-vote loop above builds target_prediction
        # as a *list* of per-word tags, and list has no .split() -- verify
        # post_process's return type before relying on this script.
        word_predictions = target_prediction.split()
        gap_index = 0
        word_index = 0
        for word, word_prediction in zip(words, word_predictions):
            # Words equal to the configured placeholder (default "_") denote
            # gaps between MT tokens and go to the gaps file.
            if word == microtransquest_config["tag"]:
                gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
                            str(sentence_id) + "\t" + str(gap_index) + "\t"
                            + "gap" + "\t" + word_prediction + '\n')
                gap_index += 1
            else:
                target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
                               str(sentence_id) + "\t" + str(word_index) + "\t"
                               + word + "\t" + word_prediction + '\n')
                word_index += 1
+
diff --git a/examples/word_level/wmt_2018/en_cs/microtransquest_config.py b/examples/word_level/wmt_2018/en_cs/microtransquest_config.py
new file mode 100644
index 0000000..8941451
--- /dev/null
+++ b/examples/word_level/wmt_2018/en_cs/microtransquest_config.py
@@ -0,0 +1,98 @@
+from multiprocessing import cpu_count
+
# Locations of the WMT 2018 en-cs SMT word-level QE data and its file names.
TRAIN_PATH = "examples/word_level/wmt_2018/en_cs/data/en_cs.smt"
TRAIN_SOURCE_FILE = "train.src"
TRAIN_SOURCE_TAGS_FILE = "train.source_tags"
TRAIN_TARGET_FILE = "train.mt"
# NOTE(review): "FLE" looks like a typo for "FILE", but microtransquest.py
# imports this exact name -- any rename must change both files together.
TRAIN_TARGET_TAGS_FLE = "train.tags"

TEST_PATH = "examples/word_level/wmt_2018/en_cs/data/en_cs.smt"
TEST_SOURCE_FILE = "test.src"
TEST_TARGET_FILE = "test.mt"

# Output file names for the predictions written under TEMP_DIRECTORY.
TEST_SOURCE_TAGS_FILE = "predictions_src.txt"
TEST_TARGET_TAGS_FILE = "predictions_mt.txt"
TEST_TARGET_GAPS_FILE = "predictions_gaps.txt"


SEED = 777                # base random seed; fold i splits with random_state=SEED * i
TEMP_DIRECTORY = "temp/data"  # scratch directory for the prediction files
GOOGLE_DRIVE = False      # not used by microtransquest.py in this example
DRIVE_FILE_ID = None
MODEL_TYPE = "xlmroberta"           # transformer architecture identifier
MODEL_NAME = "xlm-roberta-large"    # pretrained checkpoint to fine-tune from

# Arguments consumed by MicroTransQuestModel and the data-prep helpers.
microtransquest_config = {
    # Checkpoint / cache locations (wiped per fold by microtransquest.py).
    'output_dir': 'temp/outputs/',
    "best_model_dir": "temp/outputs/best_model",
    'cache_dir': 'temp/cache_dir/',

    # Core training hyper-parameters.
    'fp16': False,
    'fp16_opt_level': 'O1',
    'max_seq_length': 200,
    'train_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'eval_batch_size': 8,
    'num_train_epochs': 3,
    'weight_decay': 0,
    'learning_rate': 2e-5,
    'adam_epsilon': 1e-8,
    'warmup_ratio': 0.1,
    'warmup_steps': 0,
    'max_grad_norm': 1.0,
    'do_lower_case': False,

    # Logging / checkpointing / evaluation cadence.
    'logging_steps': 300,
    'save_steps': 300,
    "no_cache": False,
    "no_save": False,
    "save_recent_only": True,
    'save_model_every_epoch': False,
    'n_fold': 3,  # number of independently trained models in the ensemble
    'evaluate_during_training': True,
    "evaluate_during_training_silent": True,
    'evaluate_during_training_steps': 300,
    "evaluate_during_training_verbose": True,
    'use_cached_eval_features': False,
    "save_best_model": True,
    'save_eval_checkpoints': True,
    'tensorboard_dir': None,
    "save_optimizer_and_scheduler": True,

    # NOTE(review): regression=True for a word-level tag classifier looks
    # suspicious -- confirm how MicroTransQuestModel interprets this flag.
    'regression': True,

    'overwrite_output_dir': True,
    'reprocess_input_data': True,

    # Hardware / parallelism settings.
    'process_count': cpu_count() - 2 if cpu_count() > 2 else 1,
    'n_gpu': 1,
    'use_multiprocessing': True,
    "multiprocessing_chunksize": 500,
    'silent': False,

    # Experiment tracking (disabled).
    'wandb_project': None,
    'wandb_kwargs': {},

    # Early stopping on the held-out eval split.
    "use_early_stopping": True,
    "early_stopping_patience": 10,
    "early_stopping_delta": 0,
    "early_stopping_metric": "eval_loss",
    "early_stopping_metric_minimize": True,
    "early_stopping_consider_epochs": False,

    "manual_seed": SEED,

    # Gap handling: "tag" is the placeholder token that marks gaps between
    # MT words; microtransquest.py routes words equal to it to the gaps file.
    "add_tag": False,
    "tag": "_",

    # Label assigned when no gold tag is available.
    "default_quality": "OK",

    "config": {},
    "local_rank": -1,
    "encoding": None,

    # Column names expected by reader()/prepare_data() in the dataframes.
    "source_column": "source",
    "target_column": "target",
    "source_tags_column": "source_tags",
    "target_tags_column": "target_tags",
}