Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-09-08 21:08:18 +0300
committerJohn Bauer <horatio@gmail.com>2022-09-08 21:08:18 +0300
commit6027ab101ec2a6aa113a6182833c69020466bdd8 (patch)
tree376da338dbff076b3be2789f81f3b7d2e820d7e2
parent6c146e285e18b82763ba3d47d584680c8fd68123 (diff)
Use the same foundation cache as the retag_pipeline to avoid reloading the same pretrains multiple times in the constituency
-rw-r--r--stanza/models/constituency/trainer.py6
1 file changed, 4 insertions, 2 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index e71986fe..797d9231 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -209,14 +209,16 @@ def evaluate(args, model_file, retag_pipeline):
kbest = args['num_generate'] + 1
else:
kbest = None
+
with EvaluateParser(kbest=kbest) as evaluator:
+ foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
load_args = {
'wordvec_pretrain_file': args['wordvec_pretrain_file'],
'charlm_forward_file': args['charlm_forward_file'],
'charlm_backward_file': args['charlm_backward_file'],
'cuda': args['cuda'],
}
- trainer = Trainer.load(model_file, args=load_args)
+ trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
treebank = tree_reader.read_treebank(args['eval_file'])
logger.info("Read %d trees for evaluation", len(treebank))
@@ -474,7 +476,7 @@ def train(args, model_save_file, model_load_file, model_save_latest_file, model_
dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])
logger.info("Retagging finished")
- foundation_cache = FoundationCache()
+ foundation_cache = retag_pipeline.foundation_cache if retag_pipeline else FoundationCache()
trainer, train_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, foundation_cache, model_load_file)
iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, foundation_cache, model_save_file, model_save_latest_file, model_save_each_file, evaluator)