diff options
author | John Bauer <horatio@gmail.com> | 2022-09-13 06:21:35 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-13 06:21:35 +0300 |
commit | 7f4bd869ab9776935cdaa80985b58fc33d963de9 (patch) | |
tree | 1814bf55a85e0e4833e80b497903a72c28d0d6f0 | |
parent | b8ba4a7ada33cd1d9ee2c0af98458554c75ded48 (diff) |
Refactor a little bit. Make the scoring interface handle either scored trees or trees with no score. (Another option would be to attach the score directly to a tree.)
-rw-r--r-- | stanza/models/constituency/trainer.py | 7 | ||||
-rw-r--r-- | stanza/server/parser_eval.py | 15 | ||||
-rw-r--r-- | stanza/tests/server/test_parser_eval.py | 33 |
3 files changed, 41 insertions, 14 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py index 8e270959..e6b38a3e 100644 --- a/stanza/models/constituency/trainer.py +++ b/stanza/models/constituency/trainer.py @@ -34,7 +34,7 @@ from stanza.models.constituency.parse_transitions import State, TransitionScheme from stanza.models.constituency.parse_tree import Tree from stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY -from stanza.server.parser_eval import EvaluateParser +from stanza.server.parser_eval import EvaluateParser, ParseResult, ScoredTree tqdm = utils.get_tqdm() @@ -823,9 +823,6 @@ def build_batch_from_tagged_words(batch_size, data_iterator, model): tree_batch = parse_transitions.initial_state_from_words(tree_batch, model) return tree_batch -ParseResult = namedtuple("ParseResult", ['gold', 'predictions']) -ParsePrediction = namedtuple("ParsePrediction", ['tree', 'score']) - @torch.no_grad() def parse_sentences(data_iterator, build_batch_fn, batch_size, model, best=True): """ @@ -864,7 +861,7 @@ def parse_sentences(data_iterator, build_batch_fn, batch_size, model, best=True) predicted_tree = tree.get_tree(model) gold_tree = tree.gold_tree # TODO: put an actual score here? 
- treebank.append(ParseResult(gold_tree, [ParsePrediction(predicted_tree, 1.0)])) + treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, 1.0)])) treebank_indices.append(batch_indices[idx]) remove.add(idx) diff --git a/stanza/server/parser_eval.py b/stanza/server/parser_eval.py index 5f0eaaf0..93e8ce1f 100644 --- a/stanza/server/parser_eval.py +++ b/stanza/server/parser_eval.py @@ -1,5 +1,8 @@ +""" +This class runs a Java process to evaluate a treebank prediction using CoreNLP +""" - +from collections import namedtuple import stanza from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse @@ -8,6 +11,9 @@ from stanza.server.java_protobuf_requests import send_request, build_tree, JavaP EVALUATE_JAVA = "edu.stanford.nlp.parser.metrics.EvaluateExternalParser" +ParseResult = namedtuple("ParseResult", ['gold', 'predictions']) +ScoredTree = namedtuple("ScoredTree", ['tree', 'score']) + def build_request(treebank): """ treebank should be a list of pairs: [gold, predictions] @@ -19,7 +25,12 @@ def build_request(treebank): for gold, predictions in treebank: parse_result = request.treebank.add() parse_result.gold.CopyFrom(build_tree(gold, None)) - for prediction, score in predictions: + for pred in predictions: + if isinstance(pred, tuple): + prediction, score = pred + else: + prediction = pred + score = None parse_result.predicted.append(build_tree(prediction, score)) return request diff --git a/stanza/tests/server/test_parser_eval.py b/stanza/tests/server/test_parser_eval.py index c75f5319..aacda49f 100644 --- a/stanza/tests/server/test_parser_eval.py +++ b/stanza/tests/server/test_parser_eval.py @@ -13,27 +13,46 @@ from stanza.tests import * pytestmark = [pytest.mark.travis, pytest.mark.client] -def build_one_tree_treebank(): +def build_one_tree_treebank(fake_scores=True): text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))" trees = tree_reader.read_trees(text) assert len(trees) == 1 gold = trees[0] - prediction = (gold, 1.0) + if 
fake_scores: + prediction = (gold, 1.0) + else: + prediction = gold treebank = [(gold, [prediction])] return treebank -def test_build_request_one_tree(): - treebank = build_one_tree_treebank() +def check_build(fake_scores=True): + treebank = build_one_tree_treebank(fake_scores) request = build_request(treebank) assert len(request.treebank) == 1 check_tree(request.treebank[0].gold, treebank[0][0], None) assert len(request.treebank[0].predicted) == 1 - check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1]) + if fake_scores: + check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1]) + else: + check_tree(request.treebank[0].predicted[0], treebank[0][1][0], None) -def test_score_one_tree(): - treebank = build_one_tree_treebank() +def test_build_tuple_request(): + check_build(True) + +def test_build_notuple_request(): + check_build(False) + +def test_score_one_tree_tuples(): + treebank = build_one_tree_treebank(True) + + with EvaluateParser(classpath="$CLASSPATH") as ep: + response = ep.process(treebank) + assert response.f1 == pytest.approx(1.0) + +def test_score_one_tree_notuples(): + treebank = build_one_tree_treebank(False) with EvaluateParser(classpath="$CLASSPATH") as ep: response = ep.process(treebank) |