Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-10-29 05:20:00 +0300
committerJohn Bauer <horatio@gmail.com>2022-10-29 05:20:00 +0300
commit39dc3a187b6cee71677f731410ed7947bb9095fd (patch)
treee6239fd71a28ebe71991912bae861bfc035af317
parent3a313c9f9052d9b4dfce1e9412e6d071e47d3bbe (diff)
refactor predict dir,file,format args so they can be used elsewhere if needed
-rw-r--r--stanza/models/constituency/utils.py8
-rw-r--r--stanza/models/constituency_parser.py6
2 files changed, 10 insertions, 4 deletions
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index 8b6e7df8..01dc3522 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -199,3 +199,11 @@ def initialize_linear(linear, nonlinearity, bias):
if nonlinearity in ('relu', 'leaky_relu'):
nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)
nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5)
+
+def add_predict_output_args(parser):
+ """
+ Args specifically for the output location of data
+ """
+ parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
+ parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions')
+ parser.add_argument('--predict_format', type=str, default="{:_O}", help='Format to use when writing predictions')
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index bc4a8331..33e3cd05 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -137,7 +137,7 @@ from stanza.models.constituency import retagging
from stanza.models.constituency import trainer
from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory
from stanza.models.constituency.parse_transitions import TransitionScheme
-from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, NONLINEARITY
+from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, NONLINEARITY, add_predict_output_args
logger = logging.getLogger('stanza')
@@ -181,9 +181,7 @@ def parse_args(args=None):
parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one')
- parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
- parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions')
- parser.add_argument('--predict_format', type=str, default="{:_O}", help='Format to use when writing predictions')
+ add_predict_output_args(parser)
parser.add_argument('--lang', type=str, help='Language')
parser.add_argument('--shorthand', type=str, help="Treebank shorthand")