diff options
Diffstat (limited to 'stanza/models/tokenizer.py')
-rw-r--r-- | stanza/models/tokenizer.py | 15 |
1 files changed, 6 insertions, 9 deletions
diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index d663bcb3..54bc729f 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -1,7 +1,7 @@ """ Entry point for training and evaluating a neural tokenizer. -This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of +This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of recurrent and convolutional architectures. For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf. """ @@ -11,6 +11,7 @@ from copy import copy import logging import random import numpy as np +import os import torch from stanza.models.common import utils @@ -28,12 +29,10 @@ def parse_args(args=None): parser = argparse.ArgumentParser() parser.add_argument('--txt_file', type=str, help="Input plaintext file") parser.add_argument('--label_file', type=str, default=None, help="Character-level label file") - parser.add_argument('--json_file', type=str, default=None, help="JSON file with pre-chunked units") parser.add_argument('--mwt_json_file', type=str, default=None, help="JSON file for MWT expansions") parser.add_argument('--conll_file', type=str, default=None, help="CoNLL file for output") parser.add_argument('--dev_txt_file', type=str, help="(Train only) Input plaintext file for the dev set") parser.add_argument('--dev_label_file', type=str, default=None, help="(Train only) Character-level label file for the dev set") - parser.add_argument('--dev_json_file', type=str, default=None, help="(Train only) JSON file with pre-chunked units for the dev set") parser.add_argument('--dev_conll_gold', type=str, default=None, help="(Train only) CoNLL-U file for the dev set for early stopping") parser.add_argument('--lang', type=str, help="Language") parser.add_argument('--shorthand', type=str, help="UD treebank shorthand") @@ -58,6 +57,7 @@ def parse_args(args=None): parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability") parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability") parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability to induce noise to the input of the higher RNN") + parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random. Idea is to fake paragraph endings.") parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay") parser.add_argument('--max_seqlen', type=int, default=100, help="Maximum sequence length to consider at a time") parser.add_argument('--batch_size', type=int, default=32, help="Batch size to use") @@ -92,8 +92,8 @@ def main(args=None): args['feat_funcs'] = ['space_before', 'capitalized', 'all_caps', 'numeric'] args['feat_dim'] = len(args['feat_funcs']) - args['save_name'] = "{}/{}".format(args['save_dir'], args['save_name']) if args['save_name'] is not None \ - else '{}/{}_tokenizer.pt'.format(args['save_dir'], args['shorthand']) + save_name = args['save_name'] if args['save_name'] else '{}_tokenizer.pt'.format(args['shorthand']) + args['save_name'] = os.path.join(args['save_dir'], save_name) utils.ensure_dir(args['save_dir']) if args['mode'] == 'train': @@ -105,7 +105,6 @@ def train(args): mwt_dict = load_mwt_dict(args['mwt_json_file']) train_input_files = { - 'json': args['json_file'], 'txt': args['txt_file'], 'label': args['label_file'] } @@ -114,7 +113,6 @@ def train(args): args['vocab_size'] = len(vocab) dev_input_files = { - 'json': args['dev_json_file'], 'txt': args['dev_txt_file'], 'label': args['dev_label_file'] } @@ -127,7 +125,7 @@ def train(args): trainer = Trainer(args=args, vocab=vocab, use_cuda=args['cuda']) if args['load_name'] is not None: - load_name = "{}/{}".format(args['save_dir'], args['load_name']) + load_name = os.path.join(args['save_dir'], args['load_name']) trainer.load(load_name) trainer.change_lr(args['lr0']) @@ -187,7 +185,6 @@ def evaluate(args): args[k] = loaded_args[k] eval_input_files = { - 'json': args['json_file'], 'txt': args['txt_file'], 'label': args['label_file'] } |