diff options
author | John Bauer <horatio@gmail.com> | 2022-09-13 06:39:31 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-13 06:39:31 +0300 |
commit | aa88cff80971a48e2ac6592e1dc081ea6bb8d4f7 (patch) | |
tree | df799409f2a805372d3078643785903211af55a9 | |
parent | 7f4bd869ab9776935cdaa80985b58fc33d963de9 (diff) |
Add a tool to evaluate treebanks that are written out by a parser, such as when the constiuency_parser has --predict_file turned on. Allows for easy checking of what happens when multiple models are mixed together.
-rw-r--r-- | stanza/models/constituency/evaluate_treebanks.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/stanza/models/constituency/evaluate_treebanks.py b/stanza/models/constituency/evaluate_treebanks.py new file mode 100644 index 00000000..11f3084b --- /dev/null +++ b/stanza/models/constituency/evaluate_treebanks.py @@ -0,0 +1,36 @@ +""" +Read multiple treebanks, score the results. + +Reports the k-best score if multiple predicted treebanks are given. +""" + +import argparse + +from stanza.models.constituency import tree_reader +from stanza.server.parser_eval import EvaluateParser, ParseResult + + +def main(): + parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold') + parser.add_argument('gold', type=str, help='Which file to load as the gold trees') + parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions. If more than one is given, the evaluation will be "k-best" with the first prediction treated as the canonical') + args = parser.parse_args() + + print("Loading gold treebank: " + args.gold) + gold = tree_reader.read_treebank(args.gold) + print("Loading predicted treebanks: " + args.pred) + pred = [tree_reader.read_treebank(x) for x in args.pred] + + full_results = [ParseResult(parses[0], [*parses[1:]]) + for parses in zip(gold, *pred)] + + if len(pred) <= 1: + kbest = None + else: + kbest = len(pred) + + with EvaluateParser(kbest=kbest) as evaluator: + response = evaluator.process(full_results) + +if __name__ == '__main__': + main() |