Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2022-11-11 03:05:33 +0300
committerJohn Bauer <horatio@gmail.com>2022-11-11 03:05:33 +0300
commit2230425da15e24733c34e389702dccf6f59dd755 (patch)
tree8a3cdc3fd97dafe1a0d7f090dd2fe52253152385
parentd0a729801412372cb553a3328010675f404a1dca (diff)
Split the retagging operation into chunks. The tqdm progress bar is no longer as smooth, but hopefully the peak memory usage from turning the trees into Stanza docs is lower, which was especially an issue when processing 1.2M silver trees
-rw-r--r--stanza/models/constituency/retagging.py3
-rw-r--r--stanza/models/constituency/utils.py53
2 files changed, 33 insertions, 23 deletions
diff --git a/stanza/models/constituency/retagging.py b/stanza/models/constituency/retagging.py
index 283d83f6..4bf9e69e 100644
--- a/stanza/models/constituency/retagging.py
+++ b/stanza/models/constituency/retagging.py
@@ -55,8 +55,7 @@ def build_retag_pipeline(args):
retag_args = {"lang": lang,
"processors": "tokenize, pos",
"tokenize_pretokenized": True,
- "package": {"pos": package},
- "pos_tqdm": True}
+ "package": {"pos": package}}
if args['retag_model_path'] is not None:
retag_args['pos_model_path'] = args['retag_model_path']
retag_pipeline = Pipeline(**retag_args)
diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py
index 4f46abea..3e94491d 100644
--- a/stanza/models/constituency/utils.py
+++ b/stanza/models/constituency/utils.py
@@ -10,6 +10,9 @@ import torch.nn as nn
from torch import optim
from stanza.models.common.doc import TEXT, Document
+from stanza.models.common.utils import get_tqdm
+
+tqdm = get_tqdm()
DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007 }
DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
@@ -81,28 +84,36 @@ def retag_trees(trees, pipeline, xpos=True):
if len(trees) == 0:
return trees
- sentences = []
- try:
- for idx, tree in enumerate(trees):
- tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
- sentences.append(tokens)
- except ValueError as e:
- raise ValueError("Unable to process tree %d" % idx) from e
-
- doc = Document(sentences)
- doc = pipeline(doc)
- if xpos:
- tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences]
- else:
- tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences]
-
new_trees = []
- for tree_idx, (tree, tags) in enumerate(zip(trees, tag_lists)):
- try:
- new_tree = replace_tags(tree, tags)
- new_trees.append(new_tree)
- except ValueError as e:
- raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e
+ chunk_size = 1000
+ with tqdm(total=len(trees)) as pbar:
+ for chunk_start in range(0, len(trees), chunk_size):
+ chunk_end = min(chunk_start + chunk_size, len(trees))
+ chunk = trees[chunk_start:chunk_end]
+ sentences = []
+ try:
+ for idx, tree in enumerate(chunk):
+ tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
+ sentences.append(tokens)
+ except ValueError as e:
+ raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e
+
+ doc = Document(sentences)
+ doc = pipeline(doc)
+ if xpos:
+ tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences]
+ else:
+ tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences]
+
+ for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
+ try:
+ new_tree = replace_tags(tree, tags)
+ new_trees.append(new_tree)
+ pbar.update(1)
+ except ValueError as e:
+ raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e
+ if len(new_trees) != len(trees):
+ raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees)))
return new_trees