diff options
Diffstat (limited to 'stanza/models/mwt_expander.py')
-rw-r--r-- | stanza/models/mwt_expander.py | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py index 6ce83250..33b1d5a9 100644 --- a/stanza/models/mwt_expander.py +++ b/stanza/models/mwt_expander.py @@ -55,6 +55,7 @@ def parse_args(args=None): parser.add_argument('--max_dec_len', type=int, default=50) parser.add_argument('--beam_size', type=int, default=1) parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type') + parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.') parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.') parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.') @@ -98,16 +99,16 @@ def train(args): # load data logger.debug('max_dec_len: %d' % args['max_dec_len']) logger.debug("Loading data with batch size {}...".format(args['batch_size'])) - train_doc = Document(CoNLL.conll2dict(input_file=args['train_file'])) + train_doc = CoNLL.conll2doc(input_file=args['train_file']) train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False) vocab = train_batch.vocab args['vocab_size'] = vocab.size - dev_doc = Document(CoNLL.conll2dict(input_file=args['eval_file'])) + dev_doc = CoNLL.conll2doc(input_file=args['eval_file']) dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True) utils.ensure_dir(args['save_dir']) - model_file = args['save_dir'] + '/' + args['save_name'] if args['save_name'] is not None \ - else '{}/{}_mwt_expander.pt'.format(args['save_dir'], args['shorthand']) + save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand']) + model_file = os.path.join(args['save_dir'], save_name) # pred and gold path system_pred_file = args['output_file'] @@ -126,7 +127,7 @@ def train(args): dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True)) doc = copy.deepcopy(dev_batch.doc) doc.set_mwt_expansions(dev_preds) - CoNLL.dict2conll(doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(doc, system_pred_file) _, _, dev_f = scorer.score(system_pred_file, gold_file) logger.info("Dev F1 = {:.2f}".format(dev_f * 100)) @@ -168,7 +169,7 @@ def train(args): dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds) doc = copy.deepcopy(dev_batch.doc) doc.set_mwt_expansions(dev_preds) - CoNLL.dict2conll(doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(doc, system_pred_file) _, _, dev_score = scorer.score(system_pred_file, gold_file) train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch @@ -198,7 +199,7 @@ def train(args): dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds) doc = copy.deepcopy(dev_batch.doc) doc.set_mwt_expansions(dev_preds) - CoNLL.dict2conll(doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(doc, system_pred_file) _, _, dev_score = scorer.score(system_pred_file, gold_file) logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100)) best_f = max(best_f, dev_score) @@ -207,8 +208,8 @@ def evaluate(args): # file paths system_pred_file = args['output_file'] gold_file = args['gold_file'] - model_file = args['save_dir'] + '/' + args['save_name'] if args['save_name'] is not None \ - else '{}/{}_mwt_expander.pt'.format(args['save_dir'], args['shorthand']) + save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand']) + model_file = os.path.join(args['save_dir'], save_name) # load model use_cuda = args['cuda'] and not args['cpu'] @@ -222,7 +223,7 @@ def evaluate(args): # load data logger.debug("Loading data with batch size {}...".format(args['batch_size'])) - doc = Document(CoNLL.conll2dict(input_file=args['eval_file'])) + doc = CoNLL.conll2doc(input_file=args['eval_file']) batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True) if len(batch) > 0: @@ -245,7 +246,7 @@ def evaluate(args): # write to file and score doc = copy.deepcopy(batch.doc) doc.set_mwt_expansions(preds) - CoNLL.dict2conll(doc.to_dict(), system_pred_file) + CoNLL.write_doc2conll(doc, system_pred_file) if gold_file is not None: _, _, score = scorer.score(system_pred_file, gold_file) |