diff options
author | John Bauer <horatio@gmail.com> | 2022-10-29 08:25:15 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-29 08:25:15 +0300 |
commit | 0a527352cd0d61d6385ed54e6d454c14b4593e5b (patch) | |
tree | 1e0e28911789fa85ae5ec0bcfb6fdef83e6bdc73 | |
parent | bcb874e0cba2d950e1bd716c635095f680742d0e (diff) |
Move uses_xpos() to the model itself, and add it to Ensemble. This will make it easier to generalize selftrain.py to use Ensemble as well
-rw-r--r-- | stanza/models/constituency/ensemble.py | 5 | ||||
-rw-r--r-- | stanza/models/constituency/lstm_model.py | 3 | ||||
-rw-r--r-- | stanza/models/constituency/trainer.py | 5 | ||||
-rw-r--r-- | stanza/pipeline/constituency_processor.py | 2 |
4 files changed, 10 insertions, 5 deletions
diff --git a/stanza/models/constituency/ensemble.py b/stanza/models/constituency/ensemble.py index c53db07f..0ab75316 100644 --- a/stanza/models/constituency/ensemble.py +++ b/stanza/models/constituency/ensemble.py @@ -55,11 +55,16 @@ class Ensemble: raise ValueError("Models %s and %s are incompatible: different constituents" % (filenames[0], filenames[model_idx])) if self.models[0].root_labels != model.root_labels: raise ValueError("Models %s and %s are incompatible: different root_labels" % (filenames[0], filenames[model_idx])) + if self.models[0].uses_xpos() != model.uses_xpos(): + raise ValueError("Models %s and %s are incompatible: different uses_xpos" % (filenames[0], filenames[model_idx])) def eval(self): for model in self.models: model.eval() + def uses_xpos(self): + return self.models[0].uses_xpos() + def build_batch_from_tagged_words(self, batch_size, data_iterator): """ Read from the data_iterator batch_size tagged sentences and turn them into new parsing states diff --git a/stanza/models/constituency/lstm_model.py b/stanza/models/constituency/lstm_model.py index 8419a3a9..42e8fd83 100644 --- a/stanza/models/constituency/lstm_model.py +++ b/stanza/models/constituency/lstm_model.py @@ -573,6 +573,9 @@ class LSTMModel(BaseModel, nn.Module): def num_words_known(self, words): return sum(word in self.vocab_map or word.lower() in self.vocab_map for word in words) + def uses_xpos(self): + return self.args['retag_package'] is not None and self.args['retag_method'] == 'xpos' + def add_unsaved_module(self, name, module): """ Adds a module which will not be saved to disk diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index 2d35f44c..000af875 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -57,9 +57,6 @@ class Trainer: self.best_f1 = best_f1 self.best_epoch = best_epoch - def uses_xpos(self): - return self.model.args['retag_package'] is not None and 
self.model.args['retag_method'] == 'xpos' - def save(self, filename, save_optimizer=True): """ Save the model (and by default the optimizer) to the given path @@ -253,7 +250,7 @@ def parse_text(args, model, retag_pipeline): logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk)) doc = retag_pipeline(chunk) logger.info("Retagging finished. Parsing tagged text") - if args['retag_method'] == 'xpos': + if model.uses_xpos(): words = [[(w.text, w.xpos) for w in s.words] for s in doc.sentences] else: words = [[(w.text, w.upos) for w in s.words] for s in doc.sentences] diff --git a/stanza/pipeline/constituency_processor.py b/stanza/pipeline/constituency_processor.py index 0bac762f..820b4527 100644 --- a/stanza/pipeline/constituency_processor.py +++ b/stanza/pipeline/constituency_processor.py @@ -56,7 +56,7 @@ class ConstituencyProcessor(UDProcessor): def process(self, document): sentences = document.sentences - if self._model.uses_xpos(): + if self._model.model.uses_xpos(): words = [[(w.text, w.xpos) for w in s.words] for s in sentences] else: words = [[(w.text, w.upos) for w in s.words] for s in sentences] |