Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-09-13 06:21:35 +0300
committerJohn Bauer <horatio@gmail.com>2022-09-13 06:21:35 +0300
commit7f4bd869ab9776935cdaa80985b58fc33d963de9 (patch)
tree1814bf55a85e0e4833e80b497903a72c28d0d6f0
parentb8ba4a7ada33cd1d9ee2c0af98458554c75ded48 (diff)
Refactor a little bit. Make it so the scoring interface can handle either scored trees or trees with no score (another option would be to attach the score directly to a tree)
-rw-r--r--stanza/models/constituency/trainer.py7
-rw-r--r--stanza/server/parser_eval.py15
-rw-r--r--stanza/tests/server/test_parser_eval.py33
3 files changed, 41 insertions, 14 deletions
diff --git a/stanza/models/constituency/trainer.py b/stanza/models/constituency/trainer.py
index 8e270959..e6b38a3e 100644
--- a/stanza/models/constituency/trainer.py
+++ b/stanza/models/constituency/trainer.py
@@ -34,7 +34,7 @@ from stanza.models.constituency.parse_transitions import State, TransitionScheme
from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler
from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY
-from stanza.server.parser_eval import EvaluateParser
+from stanza.server.parser_eval import EvaluateParser, ParseResult, ScoredTree
tqdm = utils.get_tqdm()
@@ -823,9 +823,6 @@ def build_batch_from_tagged_words(batch_size, data_iterator, model):
tree_batch = parse_transitions.initial_state_from_words(tree_batch, model)
return tree_batch
-ParseResult = namedtuple("ParseResult", ['gold', 'predictions'])
-ParsePrediction = namedtuple("ParsePrediction", ['tree', 'score'])
-
@torch.no_grad()
def parse_sentences(data_iterator, build_batch_fn, batch_size, model, best=True):
"""
@@ -864,7 +861,7 @@ def parse_sentences(data_iterator, build_batch_fn, batch_size, model, best=True)
predicted_tree = tree.get_tree(model)
gold_tree = tree.gold_tree
# TODO: put an actual score here?
- treebank.append(ParseResult(gold_tree, [ParsePrediction(predicted_tree, 1.0)]))
+ treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, 1.0)]))
treebank_indices.append(batch_indices[idx])
remove.add(idx)
diff --git a/stanza/server/parser_eval.py b/stanza/server/parser_eval.py
index 5f0eaaf0..93e8ce1f 100644
--- a/stanza/server/parser_eval.py
+++ b/stanza/server/parser_eval.py
@@ -1,5 +1,8 @@
+"""
+This class runs a Java process to evaluate a treebank prediction using CoreNLP
+"""
-
+from collections import namedtuple
import stanza
from stanza.protobuf import EvaluateParserRequest, EvaluateParserResponse
@@ -8,6 +11,9 @@ from stanza.server.java_protobuf_requests import send_request, build_tree, JavaP
EVALUATE_JAVA = "edu.stanford.nlp.parser.metrics.EvaluateExternalParser"
+ParseResult = namedtuple("ParseResult", ['gold', 'predictions'])
+ScoredTree = namedtuple("ScoredTree", ['tree', 'score'])
+
def build_request(treebank):
"""
treebank should be a list of pairs: [gold, predictions]
@@ -19,7 +25,12 @@ def build_request(treebank):
for gold, predictions in treebank:
parse_result = request.treebank.add()
parse_result.gold.CopyFrom(build_tree(gold, None))
- for prediction, score in predictions:
+ for pred in predictions:
+ if isinstance(pred, tuple):
+ prediction, score = pred
+ else:
+ prediction = pred
+ score = None
parse_result.predicted.append(build_tree(prediction, score))
return request
diff --git a/stanza/tests/server/test_parser_eval.py b/stanza/tests/server/test_parser_eval.py
index c75f5319..aacda49f 100644
--- a/stanza/tests/server/test_parser_eval.py
+++ b/stanza/tests/server/test_parser_eval.py
@@ -13,27 +13,46 @@ from stanza.tests import *
pytestmark = [pytest.mark.travis, pytest.mark.client]
-def build_one_tree_treebank():
+def build_one_tree_treebank(fake_scores=True):
text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
trees = tree_reader.read_trees(text)
assert len(trees) == 1
gold = trees[0]
- prediction = (gold, 1.0)
+ if fake_scores:
+ prediction = (gold, 1.0)
+ else:
+ prediction = gold
treebank = [(gold, [prediction])]
return treebank
-def test_build_request_one_tree():
- treebank = build_one_tree_treebank()
+def check_build(fake_scores=True):
+ treebank = build_one_tree_treebank(fake_scores)
request = build_request(treebank)
assert len(request.treebank) == 1
check_tree(request.treebank[0].gold, treebank[0][0], None)
assert len(request.treebank[0].predicted) == 1
- check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1])
+ if fake_scores:
+ check_tree(request.treebank[0].predicted[0], treebank[0][1][0][0], treebank[0][1][0][1])
+ else:
+ check_tree(request.treebank[0].predicted[0], treebank[0][1][0], None)
-def test_score_one_tree():
- treebank = build_one_tree_treebank()
+def test_build_tuple_request():
+ check_build(True)
+
+def test_build_notuple_request():
+ check_build(False)
+
+def test_score_one_tree_tuples():
+ treebank = build_one_tree_treebank(True)
+
+ with EvaluateParser(classpath="$CLASSPATH") as ep:
+ response = ep.process(treebank)
+ assert response.f1 == pytest.approx(1.0)
+
+def test_score_one_tree_notuples():
+ treebank = build_one_tree_treebank(False)
with EvaluateParser(classpath="$CLASSPATH") as ep:
response = ep.process(treebank)