diff options
author | John Bauer <horatio@gmail.com> | 2022-11-11 03:05:33 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-11 03:05:33 +0300 |
commit | 2230425da15e24733c34e389702dccf6f59dd755 (patch) | |
tree | 8a3cdc3fd97dafe1a0d7f090dd2fe52253152385 | |
parent | d0a729801412372cb553a3328010675f404a1dca (diff) |
Split the retagging operation into chunks. The tqdm is no longer as smooth, but hopefully the peak memory usage of turning the trees into Stanza docs is less, which was especially an issue when processing 1.2M silver trees
-rw-r--r-- | stanza/models/constituency/retagging.py | 3 | ||||
-rw-r--r-- | stanza/models/constituency/utils.py | 53 |
2 files changed, 33 insertions, 23 deletions
diff --git a/stanza/models/constituency/retagging.py b/stanza/models/constituency/retagging.py index 283d83f6..4bf9e69e 100644 --- a/stanza/models/constituency/retagging.py +++ b/stanza/models/constituency/retagging.py @@ -55,8 +55,7 @@ def build_retag_pipeline(args): retag_args = {"lang": lang, "processors": "tokenize, pos", "tokenize_pretokenized": True, - "package": {"pos": package}, - "pos_tqdm": True} + "package": {"pos": package}} if args['retag_model_path'] is not None: retag_args['pos_model_path'] = args['retag_model_path'] retag_pipeline = Pipeline(**retag_args) diff --git a/stanza/models/constituency/utils.py b/stanza/models/constituency/utils.py index 4f46abea..3e94491d 100644 --- a/stanza/models/constituency/utils.py +++ b/stanza/models/constituency/utils.py @@ -10,6 +10,9 @@ import torch.nn as nn from torch import optim from stanza.models.common.doc import TEXT, Document +from stanza.models.common.utils import get_tqdm + +tqdm = get_tqdm() DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007 } DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 } @@ -81,28 +84,36 @@ def retag_trees(trees, pipeline, xpos=True): if len(trees) == 0: return trees - sentences = [] - try: - for idx, tree in enumerate(trees): - tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()] - sentences.append(tokens) - except ValueError as e: - raise ValueError("Unable to process tree %d" % idx) from e - - doc = Document(sentences) - doc = pipeline(doc) - if xpos: - tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences] - else: - tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences] - new_trees = [] - for tree_idx, (tree, tags) in enumerate(zip(trees, tag_lists)): - try: - new_tree = replace_tags(tree, tags) - new_trees.append(new_tree) - except ValueError as e: - raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e + chunk_size = 1000 + with tqdm(total=len(trees)) as pbar: + for chunk_start in range(0, len(trees), chunk_size): + chunk_end = min(chunk_start + chunk_size, len(trees)) + chunk = trees[chunk_start:chunk_end] + sentences = [] + try: + for idx, tree in enumerate(chunk): + tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()] + sentences.append(tokens) + except ValueError as e: + raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e + + doc = Document(sentences) + doc = pipeline(doc) + if xpos: + tag_lists = [[x.xpos for x in sentence.words] for sentence in doc.sentences] + else: + tag_lists = [[x.upos for x in sentence.words] for sentence in doc.sentences] + + for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)): + try: + new_tree = replace_tags(tree, tags) + new_trees.append(new_tree) + pbar.update(1) + except ValueError as e: + raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e + if len(new_trees) != len(trees): + raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees))) return new_trees |