Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-10-29 08:25:15 +0300
committerJohn Bauer <horatio@gmail.com>2022-10-29 08:25:15 +0300
commit0a527352cd0d61d6385ed54e6d454c14b4593e5b (patch)
tree1e0e28911789fa85ae5ec0bcfb6fdef83e6bdc73
parentbcb874e0cba2d950e1bd716c635095f680742d0e (diff)
Move uses_xpos() to the model itself, add it Ensemble. Will make it easier to generalize selftrain.py to use Ensemble as well
-rw-r--r--stanza/models/constituency/ensemble.py5
-rw-r--r--stanza/models/constituency/lstm_model.py3
-rw-r--r--stanza/models/constituency/trainer.py5
-rw-r--r--stanza/pipeline/constituency_processor.py2
4 files changed, 10 insertions, 5 deletions
diff --git a/stanza/models/constituency/ensemble.py b/stanza/models/constituency/ensemble.py
index c53db07f..0ab75316 100644
--- a/stanza/models/constituency/ensemble.py
+++ b/stanza/models/constituency/ensemble.py
@@ -55,11 +55,16 @@ class Ensemble:
raise ValueError("Models %s and %s are incompatible: different constituents" % (filenames[0], filenames[model_idx]))
if self.models[0].root_labels != model.root_labels:
raise ValueError("Models %s and %s are incompatible: different root_labels" % (filenames[0], filenames[model_idx]))
+ if self.models[0].uses_xpos() != model.uses_xpos():
+ raise ValueError("Models %s and %s are incompatible: different uses_xpos" % (filenames[0], filenames[model_idx]))
def eval(self):
for model in self.models:
model.eval()
+ def uses_xpos(self):
+ return self.models[0].uses_xpos()
+
def build_batch_from_tagged_words(self, batch_size, data_iterator):
"""
Read from the data_iterator batch_size tagged sentences and turn them into new parsing states
diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py
index 8419a3a9..42e8fd83 100644
--- a/stanza/models/constituency/lstm_model.py
+++ b/stanza/models/constituency/lstm_model.py
@@ -573,6 +573,9 @@ class LSTMModel(BaseModel, nn.Module):
def num_words_known(self, words):
return sum(word in self.vocab_map or word.lower() in self.vocab_map for word in words)
+ def uses_xpos(self):
+ return self.args['retag_package'] is not None and self.args['retag_method'] == 'xpos'
+
def add_unsaved_module(self, name, module):
"""
Adds a module which will not be saved to disk
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 2d35f44c..000af875 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -57,9 +57,6 @@ class Trainer:
self.best_f1 = best_f1
self.best_epoch = best_epoch
- def uses_xpos(self):
- return self.model.args['retag_package'] is not None and self.model.args['retag_method'] == 'xpos'
-
def save(self, filename, save_optimizer=True):
"""
Save the model (and by default the optimizer) to the given path
@@ -253,7 +250,7 @@ def parse_text(args, model, retag_pipeline):
logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
doc = retag_pipeline(chunk)
logger.info("Retagging finished. Parsing tagged text")
- if args['retag_method'] == 'xpos':
+ if model.uses_xpos():
words = [[(w.text, w.xpos) for w in s.words] for s in doc.sentences]
else:
words = [[(w.text, w.upos) for w in s.words] for s in doc.sentences]
diff --git a/stanza/pipeline/constituency_processor.py b/stanza/pipeline/constituency_processor.py
index 0bac762f..820b4527 100644
--- a/stanza/pipeline/constituency_processor.py
+++ b/stanza/pipeline/constituency_processor.py
@@ -56,7 +56,7 @@ class ConstituencyProcessor(UDProcessor):
def process(self, document):
sentences = document.sentences
- if self._model.uses_xpos():
+ if self._model.model.uses_xpos():
words = [[(w.text, w.xpos) for w in s.words] for s in sentences]
else:
words = [[(w.text, w.upos) for w in s.words] for s in sentences]