diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-01-23 02:40:24 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-01-23 02:40:24 +0300 |
commit | 92b03f3bb34023a82f174766e5a6ad4a4a1db6ff (patch) | |
tree | 98dbce596ee3dde5258d5089690413a565749209 | |
parent | 85bec64c0b4c5fd71d6fab0c4b1a62a41a5eb509 (diff) |
055: Adding word level examples
-rw-r--r-- | examples/word_level/wmt_2018/en_cs/microtransquest.py | 117 | ||||
-rw-r--r-- | examples/word_level/wmt_2018/en_cs/microtransquest_config.py | 98 |
2 files changed, 215 insertions, 0 deletions
# =============================================================================
# examples/word_level/wmt_2018/en_cs/microtransquest.py
#
# Train MicroTransQuest word-level QE models on the WMT 2018 en-cs SMT data
# with n-fold ensembling, then write WMT-format prediction files for source
# tags, MT (target) word tags, and gap tags.
# =============================================================================
import os
import shutil
from collections import Counter

from sklearn.model_selection import train_test_split

from examples.word_level.common.util import reader
from examples.word_level.wmt_2018.en_cs.microtransquest_config import TRAIN_PATH, TRAIN_SOURCE_FILE, \
    TRAIN_SOURCE_TAGS_FILE, \
    TRAIN_TARGET_FILE, \
    TRAIN_TARGET_TAGS_FLE, MODEL_TYPE, MODEL_NAME, microtransquest_config, TEST_PATH, TEST_SOURCE_FILE, \
    TEST_TARGET_FILE, TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE, SEED, TEST_TARGET_TAGS_FILE, TEST_TARGET_GAPS_FILE
from transquest.algo.word_level.microtransquest.format import prepare_data, prepare_testdata, post_process
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel


def _as_tag_list(prediction):
    """Normalise one sentence's predicted tags to a list of tag strings.

    Defensive: ``post_process`` may yield either a space-joined string or a
    list of tags per sentence. The original code unconditionally called
    ``.split()`` on the voted prediction, which raises ``AttributeError``
    when the prediction is a list.
    """
    return prediction.split() if isinstance(prediction, str) else list(prediction)


def _majority_vote(fold_predictions, num_sentences):
    """Combine per-fold tag sequences into one sequence per sentence.

    :param fold_predictions: list with one entry per fold; each entry holds
        one tag sequence per test sentence.
    :param num_sentences: number of test sentences.
    :return: per-sentence lists of tags chosen by per-word majority vote.
        Ties break deterministically on the tag seen first (Counter order).
    """
    voted = []
    for sentence_id in range(num_sentences):
        per_fold = [_as_tag_list(fold[sentence_id]) for fold in fold_predictions]
        sentence_vote = []
        for word_id in range(len(per_fold[0])):
            word_tags = [prediction[word_id] for prediction in per_fold]
            sentence_vote.append(Counter(word_tags).most_common(1)[0][0])
        voted.append(sentence_vote)
    return voted


if not os.path.exists(TEMP_DIRECTORY):
    os.makedirs(TEMP_DIRECTORY)

raw_train_df = reader(TRAIN_PATH, microtransquest_config, TRAIN_SOURCE_FILE, TRAIN_TARGET_FILE, TRAIN_SOURCE_TAGS_FILE,
                      TRAIN_TARGET_TAGS_FLE)
raw_test_df = reader(TEST_PATH, microtransquest_config, TEST_SOURCE_FILE, TEST_TARGET_FILE)

test_sentences = prepare_testdata(raw_test_df, args=microtransquest_config)

fold_sources_tags = []
fold_targets_tags = []

for i in range(microtransquest_config["n_fold"]):

    # Each fold trains from scratch; clear the previous fold's checkpoints.
    if os.path.exists(microtransquest_config['output_dir']) and os.path.isdir(microtransquest_config['output_dir']):
        shutil.rmtree(microtransquest_config['output_dir'])

    if microtransquest_config["evaluate_during_training"]:
        # Hold out 10% for eval; vary the split per fold via the seed.
        raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1, random_state=SEED * i)
        train_df = prepare_data(raw_train, args=microtransquest_config)
        eval_df = prepare_data(raw_eval, args=microtransquest_config)
        tags = train_df['labels'].unique().tolist()
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
        model.train_model(train_df, eval_df=eval_df)
        # Reload the best checkpoint selected on the held-out eval split.
        model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"], labels=tags,
                                     args=microtransquest_config)
    else:
        train_df = prepare_data(raw_train_df, args=microtransquest_config)
        tags = train_df['labels'].unique().tolist()
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=tags, args=microtransquest_config)
        model.train_model(train_df)

    predicted_labels, raw_predictions = model.predict(test_sentences, split_on_space=True)
    sources_tags, targets_tags = post_process(predicted_labels, test_sentences, args=microtransquest_config)
    fold_sources_tags.append(sources_tags)
    fold_targets_tags.append(targets_tags)

# Ensemble the folds per word (previously two duplicated vote loops).
source_predictions = _majority_vote(fold_sources_tags, len(test_sentences))
target_predictions = _majority_vote(fold_targets_tags, len(test_sentences))

test_source_sentences = raw_test_df[microtransquest_config["source_column"]].tolist()
test_target_sentences = raw_test_df[microtransquest_config["target_column"]].tolist()

# WMT word-level submission format, one word per line:
#   <method>\t<segment>\t<sentence-id>\t<word-id>\t<word>\t<tag>
with open(os.path.join(TEMP_DIRECTORY, TEST_SOURCE_TAGS_FILE), 'w') as f:
    for sentence_id, (test_source_sentence, source_prediction) in enumerate(
            zip(test_source_sentences, source_predictions)):
        words = test_source_sentence.split()
        # BUG FIX: source_prediction is a list of tags; the original called
        # .split() on it, which raised AttributeError at write time.
        for word_id, (word, word_prediction) in enumerate(zip(words, source_prediction)):
            f.write("MicroTransQuest" + "\t" + "source" + "\t" +
                    str(sentence_id) + "\t" + str(word_id) + "\t" +
                    word + "\t" + word_prediction + '\n')

with open(os.path.join(TEMP_DIRECTORY, TEST_TARGET_TAGS_FILE), 'w') as target_f, open(
        os.path.join(TEMP_DIRECTORY, TEST_TARGET_GAPS_FILE), 'w') as gap_f:
    for sentence_id, (test_sentence, target_prediction) in enumerate(zip(test_sentences, target_predictions)):
        # Test sentences look like "<source> [SEP] <target>"; tags cover the
        # target side, which is interleaved with gap-placeholder tokens.
        target_sentence = test_sentence.split("[SEP]")[1]
        words = target_sentence.split()
        gap_index = 0
        word_index = 0
        # BUG FIX (as above): iterate the tag list directly, no .split().
        for word, word_prediction in zip(words, target_prediction):
            if word == microtransquest_config["tag"]:
                # Gap placeholder token -> gap-tag file, separate index.
                gap_f.write("MicroTransQuest" + "\t" + "gap" + "\t" +
                            str(sentence_id) + "\t" + str(gap_index) + "\t" +
                            "gap" + "\t" + word_prediction + '\n')
                gap_index += 1
            else:
                target_f.write("MicroTransQuest" + "\t" + "mt" + "\t" +
                               str(sentence_id) + "\t" + str(word_index) + "\t" +
                               word + "\t" + word_prediction + '\n')
                word_index += 1


# =============================================================================
# examples/word_level/wmt_2018/en_cs/microtransquest_config.py
#
# Paths, output file names, and the MicroTransQuest training configuration
# for the WMT 2018 en-cs SMT word-level task.
# =============================================================================
from multiprocessing import cpu_count

TRAIN_PATH = "examples/word_level/wmt_2018/en_cs/data/en_cs.smt"
TRAIN_SOURCE_FILE = "train.src"
TRAIN_SOURCE_TAGS_FILE = "train.source_tags"
TRAIN_TARGET_FILE = "train.mt"
# NOTE: name is misspelled ("FLE"); kept as-is because it is imported under
# this exact name by microtransquest.py.
TRAIN_TARGET_TAGS_FLE = "train.tags"

TEST_PATH = "examples/word_level/wmt_2018/en_cs/data/en_cs.smt"
TEST_SOURCE_FILE = "test.src"
TEST_TARGET_FILE = "test.mt"

TEST_SOURCE_TAGS_FILE = "predictions_src.txt"
TEST_TARGET_TAGS_FILE = "predictions_mt.txt"
TEST_TARGET_GAPS_FILE = "predictions_gaps.txt"


SEED = 777
TEMP_DIRECTORY = "temp/data"
GOOGLE_DRIVE = False
DRIVE_FILE_ID = None
MODEL_TYPE = "xlmroberta"
MODEL_NAME = "xlm-roberta-large"

microtransquest_config = {
    'output_dir': 'temp/outputs/',
    "best_model_dir": "temp/outputs/best_model",
    'cache_dir': 'temp/cache_dir/',

    'fp16': False,
    'fp16_opt_level': 'O1',
    'max_seq_length': 200,
    'train_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'eval_batch_size': 8,
    'num_train_epochs': 3,
    'weight_decay': 0,
    'learning_rate': 2e-5,
    'adam_epsilon': 1e-8,
    'warmup_ratio': 0.1,
    'warmup_steps': 0,
    'max_grad_norm': 1.0,
    'do_lower_case': False,

    'logging_steps': 300,
    'save_steps': 300,
    "no_cache": False,
    "no_save": False,
    "save_recent_only": True,
    'save_model_every_epoch': False,
    'n_fold': 3,
    'evaluate_during_training': True,
    "evaluate_during_training_silent": True,
    'evaluate_during_training_steps': 300,
    "evaluate_during_training_verbose": True,
    'use_cached_eval_features': False,
    "save_best_model": True,
    'save_eval_checkpoints': True,
    'tensorboard_dir': None,
    "save_optimizer_and_scheduler": True,

    # NOTE(review): 'regression' is True for a word-level tagging task —
    # looks suspicious; confirm against MicroTransQuestModel's expectations.
    'regression': True,

    'overwrite_output_dir': True,
    'reprocess_input_data': True,

    # Leave two cores free for the main process when possible.
    'process_count': cpu_count() - 2 if cpu_count() > 2 else 1,
    'n_gpu': 1,
    'use_multiprocessing': True,
    "multiprocessing_chunksize": 500,
    'silent': False,

    'wandb_project': None,
    'wandb_kwargs': {},

    "use_early_stopping": True,
    "early_stopping_patience": 10,
    "early_stopping_delta": 0,
    "early_stopping_metric": "eval_loss",
    "early_stopping_metric_minimize": True,
    "early_stopping_consider_epochs": False,

    "manual_seed": SEED,

    "add_tag": False,
    # Gap-placeholder token inserted between target words.
    "tag": "_",

    "default_quality": "OK",

    "config": {},
    "local_rank": -1,
    "encoding": None,

    "source_column": "source",
    "target_column": "target",
    "source_tags_column": "source_tags",
    "target_tags_column": "target_tags",
}