diff options
author | John Bauer <horatio@gmail.com> | 2022-09-08 21:08:18 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-08 21:08:18 +0300 |
commit | 6027ab101ec2a6aa113a6182833c69020466bdd8 (patch) | |
tree | 376da338dbff076b3be2789f81f3b7d2e820d7e2 | |
parent | 6c146e285e18b82763ba3d47d584680c8fd68123 (diff) |
Use the same foundation cache as the retag_pipeline to avoid reloading the same pretrains multiple times in the constituency
-rw-r--r-- | stanza/models/constituency/trainer.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index e71986fe..797d9231 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -209,14 +209,16 @@ def evaluate(args, model_file, retag_pipeline): kbest = args['num_generate'] + 1 else: kbest = None + with EvaluateParser(kbest=kbest) as evaluator: + foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() load_args = { 'wordvec_pretrain_file': args['wordvec_pretrain_file'], 'charlm_forward_file': args['charlm_forward_file'], 'charlm_backward_file': args['charlm_backward_file'], 'cuda': args['cuda'], } - trainer = Trainer.load(model_file, args=load_args) + trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache) treebank = tree_reader.read_treebank(args['eval_file']) logger.info("Read %d trees for evaluation", len(treebank)) @@ -474,7 +476,7 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, model_ dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos']) logger.info("Retagging finished") - foundation_cache = FoundationCache() + foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache() trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file) iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator) |