github.com/stanfordnlp/stanza.git
Diffstat (limited to 'stanza/models/mwt_expander.py')
-rw-r--r-- stanza/models/mwt_expander.py | 23
1 file changed, 12 insertions(+), 11 deletions(-)
diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py
index 6ce83250..33b1d5a9 100644
--- a/stanza/models/mwt_expander.py
+++ b/stanza/models/mwt_expander.py
@@ -55,6 +55,7 @@ def parse_args(args=None):
parser.add_argument('--max_dec_len', type=int, default=50)
parser.add_argument('--beam_size', type=int, default=1)
parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
+ parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')
parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
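Note on the new --no_copy option: with argparse's store_false action the destination copy defaults to True, so the copy mechanism stays on unless the flag is given. A minimal standalone sketch of that behavior (the parser here is illustrative, not the one in mwt_expander.py):

    import argparse

    parser = argparse.ArgumentParser()
    # store_false gives args.copy a default of True; passing --no_copy flips it to False
    parser.add_argument('--no_copy', dest='copy', action='store_false',
                        help='Do not use copy mechanism in MWT expansion.')

    print(parser.parse_args([]).copy)             # True  -> copy mechanism enabled
    print(parser.parse_args(['--no_copy']).copy)  # False -> copy mechanism disabled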
@@ -98,16 +99,16 @@ def train(args):
# load data
logger.debug('max_dec_len: %d' % args['max_dec_len'])
logger.debug("Loading data with batch size {}...".format(args['batch_size']))
- train_doc = Document(CoNLL.conll2dict(input_file=args['train_file']))
+ train_doc = CoNLL.conll2doc(input_file=args['train_file'])
train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
vocab = train_batch.vocab
args['vocab_size'] = vocab.size
- dev_doc = Document(CoNLL.conll2dict(input_file=args['eval_file']))
+ dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
utils.ensure_dir(args['save_dir'])
- model_file = args['save_dir'] + '/' + args['save_name'] if args['save_name'] is not None \
- else '{}/{}_mwt_expander.pt'.format(args['save_dir'], args['shorthand'])
+ save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
+ model_file = os.path.join(args['save_dir'], save_name)
# pred and gold path
system_pred_file = args['output_file']
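Note on the data-loading change above: the training and dev documents are now built with CoNLL.conll2doc in a single call instead of wrapping CoNLL.conll2dict in a Document, and the model path is assembled with os.path.join rather than manual '/' concatenation. A rough before/after sketch based only on the calls visible in this diff (file names and import paths are assumptions from stanza's usual layout):

    import os
    from stanza.utils.conll import CoNLL
    from stanza.models.common.doc import Document

    # old style: parse the CoNLL-U file to a dict, then wrap it in a Document
    train_doc = Document(CoNLL.conll2dict(input_file='train.conllu'))

    # new style: get the Document directly
    train_doc = CoNLL.conll2doc(input_file='train.conllu')

    # model path: os.path.join instead of string concatenation with '/'
    save_name = 'en_ewt_mwt_expander.pt'  # placeholder default name
    model_file = os.path.join('saved_models/mwt', save_name)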
@@ -126,7 +127,7 @@ def train(args):
dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
doc = copy.deepcopy(dev_batch.doc)
doc.set_mwt_expansions(dev_preds)
- CoNLL.dict2conll(doc.to_dict(), system_pred_file)
+ CoNLL.write_doc2conll(doc, system_pred_file)
_, _, dev_f = scorer.score(system_pred_file, gold_file)
logger.info("Dev F1 = {:.2f}".format(dev_f * 100))
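Note on the output change: dev predictions are now written with CoNLL.write_doc2conll, which takes the Document itself, instead of converting to a dict first with doc.to_dict() and calling CoNLL.dict2conll. The same substitution appears in the later hunks. A small sketch, assuming a Document named doc and a placeholder output path:

    from stanza.utils.conll import CoNLL

    # old: serialize the Document to a dict, then write CoNLL-U
    CoNLL.dict2conll(doc.to_dict(), 'dev.pred.conllu')

    # new: write the Document directly
    CoNLL.write_doc2conll(doc, 'dev.pred.conllu')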
@@ -168,7 +169,7 @@ def train(args):
dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
doc = copy.deepcopy(dev_batch.doc)
doc.set_mwt_expansions(dev_preds)
- CoNLL.dict2conll(doc.to_dict(), system_pred_file)
+ CoNLL.write_doc2conll(doc, system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, gold_file)
train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
@@ -198,7 +199,7 @@ def train(args):
dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
doc = copy.deepcopy(dev_batch.doc)
doc.set_mwt_expansions(dev_preds)
- CoNLL.dict2conll(doc.to_dict(), system_pred_file)
+ CoNLL.write_doc2conll(doc, system_pred_file)
_, _, dev_score = scorer.score(system_pred_file, gold_file)
logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
best_f = max(best_f, dev_score)
@@ -207,8 +208,8 @@ def evaluate(args):
# file paths
system_pred_file = args['output_file']
gold_file = args['gold_file']
- model_file = args['save_dir'] + '/' + args['save_name'] if args['save_name'] is not None \
- else '{}/{}_mwt_expander.pt'.format(args['save_dir'], args['shorthand'])
+ save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
+ model_file = os.path.join(args['save_dir'], save_name)
# load model
use_cuda = args['cuda'] and not args['cpu']
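Note: evaluate() now derives the model filename exactly as train() does, falling back to '<shorthand>_mwt_expander.pt' when no --save_name is given. A hypothetical helper showing the shared pattern (the function name and example values are illustrative only):

    import os

    def default_model_file(save_dir, save_name, shorthand):
        # fall back to '<shorthand>_mwt_expander.pt' when no explicit save_name is given
        name = save_name if save_name else '{}_mwt_expander.pt'.format(shorthand)
        return os.path.join(save_dir, name)

    print(default_model_file('saved_models/mwt', None, 'en_ewt'))
    # -> saved_models/mwt/en_ewt_mwt_expander.pt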
@@ -222,7 +223,7 @@ def evaluate(args):
# load data
logger.debug("Loading data with batch size {}...".format(args['batch_size']))
- doc = Document(CoNLL.conll2dict(input_file=args['eval_file']))
+ doc = CoNLL.conll2doc(input_file=args['eval_file'])
batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)
if len(batch) > 0:
@@ -245,7 +246,7 @@ def evaluate(args):
# write to file and score
doc = copy.deepcopy(batch.doc)
doc.set_mwt_expansions(preds)
- CoNLL.dict2conll(doc.to_dict(), system_pred_file)
+ CoNLL.write_doc2conll(doc, system_pred_file)
if gold_file is not None:
_, _, score = scorer.score(system_pred_file, gold_file)