diff options
author | John Bauer <horatio@gmail.com> | 2022-10-29 05:20:00 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-29 05:20:00 +0300 |
commit | 39dc3a187b6cee71677f731410ed7947bb9095fd (patch) | |
tree | e6239fd71a28ebe71991912bae861bfc035af317 | |
parent | 3a313c9f9052d9b4dfce1e9412e6d071e47d3bbe (diff) |
Refactor the --predict_dir, --predict_file, and --predict_format args into a shared helper (add_predict_output_args) so they can be reused elsewhere if needed
-rw-r--r-- | stanza/models/constituency/utils.py | 8 | ||||
-rw-r--r-- | stanza/models/constituency_parser.py | 6 |
2 files changed, 10 insertions, 4 deletions
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py index 8b6e7df8..01dc3522 100644 --- a/stanza/models/constituency/utils.py +++ b/stanza/models/constituency/utils.py @@ -199,3 +199,11 @@ def initialize_linear(linear, nonlinearity, bias): if nonlinearity in ('relu', 'leaky_relu'): nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity) nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5) + +def add_predict_output_args(parser): + """ + Args specifically for the output location of data + """ + parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging') + parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions') + parser.add_argument('--predict_format', type=str, default="{:_O}", help='Format to use when writing predictions') diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py index bc4a8331..33e3cd05 100644 --- a/stanza/models/constituency_parser.py +++ b/stanza/models/constituency_parser.py @@ -137,7 +137,7 @@ from stanza.models.constituency import retagging from stanza.models.constituency import trainer from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory from stanza.models.constituency.parse_transitions import TransitionScheme -from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, NONLINEARITY +from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, NONLINEARITY, add_predict_output_args logger = logging.getLogger('stanza') @@ -181,9 +181,7 @@ def parse_args(args=None): 
parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.') parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer']) parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one') - parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging') - parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions') - parser.add_argument('--predict_format', type=str, default="{:_O}", help='Format to use when writing predictions') + add_predict_output_args(parser) parser.add_argument('--lang', type=str, help='Language') parser.add_argument('--shorthand', type=str, help="Treebank shorthand") |