diff options
Diffstat (limited to 'stanza/models/lemmatizer.py')
-rw-r--r-- | stanza/models/lemmatizer.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/stanza/models/lemmatizer.py b/stanza/models/lemmatizer.py index f75ce884..c20becca 100644 --- a/stanza/models/lemmatizer.py +++ b/stanza/models/lemmatizer.py @@ -59,6 +59,7 @@ def parse_args(args=None): parser.add_argument('--num_edit', type=int, default=len(edit.EDIT_TO_ID)) parser.add_argument('--alpha', type=float, default=1.0) parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.') + parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. By default copy mechanism is used to improve generalization.') parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.') @@ -69,7 +70,7 @@ def parse_args(args=None): parser.add_argument('--batch_size', type=int, default=50) parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.') parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.') - parser.add_argument('--model_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.') + parser.add_argument('--save_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.') parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) @@ -100,16 +101,16 @@ def main(args=None): def train(args): # load data logger.info("[Loading data with batch size {}...]".format(args['batch_size'])) - train_doc = Document(CoNLL.conll2dict(input_file=args['train_file'])) + train_doc = CoNLL.conll2doc(input_file=args['train_file']) train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False) vocab = train_batch.vocab args['vocab_size'] = vocab['char'].size args['pos_vocab_size'] = vocab['pos'].size - dev_doc = Document(CoNLL.conll2dict(input_file=args['eval_file'])) + dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True) - utils.ensure_dir(args['model_dir']) - model_file = '{}/{}_lemmatizer.pt'.format(args['model_dir'], args['lang']) + utils.ensure_dir(args['save_dir']) + model_file = os.path.join(args['save_dir'], '{}_lemmatizer.pt'.format(args['lang'])) # pred and gold path system_pred_file = args['output_file'] @@ -130,7 +131,7 @@ def train(args): logger.info("Evaluating on dev set...") dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS])) dev_batch.doc.set([LEMMA], dev_preds) - CoNLL.dict2conll(dev_batch.doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(dev_batch.doc, system_pred_file) _, _, dev_f = scorer.score(system_pred_file, gold_file) logger.info("Dev F1 = {:.2f}".format(dev_f * 100)) @@ -177,7 +178,7 @@ def train(args): logger.info("[Ensembling dict with seq2seq model...]") dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds) dev_batch.doc.set([LEMMA], dev_preds) - CoNLL.dict2conll(dev_batch.doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(dev_batch.doc, system_pred_file) _, _, dev_score = scorer.score(system_pred_file, gold_file) train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch @@ -207,7 +208,7 @@ def evaluate(args): # file paths system_pred_file = args['output_file'] gold_file = args['gold_file'] - model_file = '{}/{}_lemmatizer.pt'.format(args['model_dir'], args['lang']) + model_file = os.path.join(args['save_dir'], '{}_lemmatizer.pt'.format(args['lang'])) # load model use_cuda = args['cuda'] and not args['cpu'] @@ -220,7 +221,7 @@ def evaluate(args): # load data logger.info("Loading data with batch size {}...".format(args['batch_size'])) - doc = Document(CoNLL.conll2dict(input_file=args['eval_file'])) + doc = CoNLL.conll2doc(input_file=args['eval_file']) batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True) # skip eval if dev data does not exist @@ -249,7 +250,7 @@ def evaluate(args): # write to file and score batch.doc.set([LEMMA], preds) - CoNLL.dict2conll(batch.doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(batch.doc, system_pred_file) if gold_file is not None: _, _, score = scorer.score(system_pred_file, gold_file) |