Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-09-13 06:39:31 +0300
committerJohn Bauer <horatio@gmail.com>2022-09-13 06:39:31 +0300
commitaa88cff80971a48e2ac6592e1dc081ea6bb8d4f7 (patch)
treedf799409f2a805372d3078643785903211af55a9
parent7f4bd869ab9776935cdaa80985b58fc33d963de9 (diff)
Add a tool to evaluate treebanks that are written out by a parser, such as when the constituency_parser has --predict_file turned on. Allows for easy checking of what happens when multiple models are mixed together.
-rw-r--r--stanza/models/constituency/evaluate_treebanks.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/stanza/models/constituency/evaluate_treebanks.py b/stanza/models/constituency/evaluate_treebanks.py
new file mode 100644
index 00000000..11f3084b
--- /dev/null
+++ b/stanza/models/constituency/evaluate_treebanks.py
@@ -0,0 +1,36 @@
+"""
+Read multiple treebanks, score the results.
+
+Reports the k-best score if multiple predicted treebanks are given.
+"""
+
+import argparse
+
+from stanza.models.constituency import tree_reader
+from stanza.server.parser_eval import EvaluateParser, ParseResult
+
+
+def main():
+    """Read gold and predicted treebank paths from the command line, load them, and send the aligned trees to EvaluateParser."""
+    parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold')
+    parser.add_argument('gold', type=str, help='Which file to load as the gold trees')
+    parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions. If more than one is given, the evaluation will be "k-best" with the first prediction treated as the canonical')
+    args = parser.parse_args()
+
+    print("Loading gold treebank: " + args.gold)
+    gold = tree_reader.read_treebank(args.gold)
+    # args.pred is a list (nargs='+'); join it first, since str + list raises TypeError.
+    print("Loading predicted treebanks: " + ", ".join(args.pred))
+    pred = [tree_reader.read_treebank(x) for x in args.pred]
+
+    # Pair the i-th gold tree with the i-th tree of each prediction file. NOTE: zip stops at the shortest treebank.
+    full_results = [ParseResult(parses[0], [*parses[1:]])
+                    for parses in zip(gold, *pred)]
+
+    # Only request k-best scoring when more than one prediction file was given.
+    kbest = len(pred) if len(pred) > 1 else None
+    with EvaluateParser(kbest=kbest) as evaluator:
+        evaluator.process(full_results)
+
+if __name__ == '__main__':  # run the evaluation only when executed as a script, not when imported
+ main()