author | John Bauer <horatio@gmail.com> | 2022-08-31 02:17:00 +0300
committer | John Bauer <horatio@gmail.com> | 2022-08-31 03:44:10 +0300
commit | d30b7de0eed2230f4efe7a7ee3f1f7dbb8f2eb0f
tree | 479a3c520ea68a7d4d43a92e0f143b9a3281bf9e
parent | 7b476eb85cb7881aa4c33b79c5e419c11116a592
Refactor a bunch of data manipulation methods to data.py
-rw-r--r-- | stanza/models/classifier.py | 80
-rw-r--r-- | stanza/models/classifiers/data.py | 66
-rw-r--r-- | stanza/tests/classifiers/test_classifier.py | 6
3 files changed, 78 insertions(+), 74 deletions(-)
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index a1db15ac..a2f85966 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -1,6 +1,5 @@
 import argparse
 import ast
-import collections
 import logging
 import os
 import random
@@ -15,7 +14,6 @@
 from stanza.models.common import loss
 from stanza.models.common import utils
 from stanza.models.common.foundation_cache import load_bert, load_charlm
 from stanza.models.common.pretrain import Pretrain
-from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
 from stanza.models.pos.vocab import CharVocab
 from stanza.models.classifiers.classifier_args import WVType, ExtraVectors
@@ -243,55 +241,6 @@ def parse_args(args=None):
 
     return args
 
-def dataset_labels(dataset):
-    """
-    Returns a sorted list of label name
-    """
-    labels = set([x[0] for x in dataset])
-    if all(re.match("^[0-9]+$", label) for label in labels):
-        # if all of the labels are integers, sort numerically
-        # maybe not super important, but it would be nicer than having
-        # 10 before 2
-        labels = [str(x) for x in sorted(map(int, list(labels)))]
-    else:
-        labels = sorted(list(labels))
-    return labels
-
-def dataset_vocab(dataset):
-    vocab = set()
-    for line in dataset:
-        for word in line[1]:
-            vocab.add(word)
-    vocab = [PAD, UNK] + list(vocab)
-    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
-        raise ValueError("Unexpected values for PAD and UNK!")
-    return vocab
-
-def sort_dataset_by_len(dataset):
-    """
-    returns a dict mapping length -> list of items of that length
-    an OrderedDict is used to that the mapping is sorted from smallest to largest
-    """
-    sorted_dataset = collections.OrderedDict()
-    lengths = sorted(list(set(len(x[1]) for x in dataset)))
-    for l in lengths:
-        sorted_dataset[l] = []
-    for item in dataset:
-        sorted_dataset[len(item[1])].append(item)
-    return sorted_dataset
-
-def shuffle_dataset(sorted_dataset):
-    """
-    Given a dataset sorted by len, sorts within each length to make
-    chunks of roughly the same size. Returns all items as a single list.
-    """
-    dataset = []
-    for l in sorted_dataset.keys():
-        items = list(sorted_dataset[l])
-        random.shuffle(items)
-        dataset.extend(items)
-    return dataset
-
 def confusion_dataset(model, dataset):
     """
     Returns a confusion matrix
@@ -303,7 +252,7 @@
     model.eval()
     index_label_map = {x: y for (x, y) in enumerate(model.labels)}
 
-    dataset_lengths = sort_dataset_by_len(dataset)
+    dataset_lengths = data.sort_dataset_by_len(dataset)
 
     confusion_matrix = {}
     for label in model.labels:
@@ -340,7 +289,7 @@ def score_dataset(model, dataset, label_map=None,
     if label_map is None:
         label_map = {x: y for (y, x) in enumerate(model.labels)}
     correct = 0
-    dataset_lengths = sort_dataset_by_len(dataset)
+    dataset_lengths = data.sort_dataset_by_len(dataset)
 
     for length in dataset_lengths.keys():
         # TODO: possibly break this up into smaller batches
@@ -395,17 +344,6 @@ def score_dev_set(model, dev_set, dev_eval_scoring):
     else:
         raise ValueError("Unknown scoring method {}".format(dev_eval_scoring))
 
-def check_labels(labels, dataset):
-    """
-    Check that all of the labels in the dataset are in the known labels.
-    Actually, unknown labels could be acceptable if we just treat the model as always wrong.
-    However, this is a good sanity check to make sure the datasets match
-    """
-    new_labels = dataset_labels(dataset)
-    not_found = [i for i in new_labels if i not in labels]
-    if not_found:
-        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(not_found))
-
 def checkpoint_name(filename, epoch, dev_scoring, score):
     """
     Build an informative checkpoint name from a base name, epoch #, and accuracy
@@ -465,7 +403,7 @@ def train_model(model, model_file, args, train_set, dev_set, labels):
         raise ValueError("Unknown loss function {}".format(args.loss))
     loss_function.to(device)
 
-    train_set_by_len = sort_dataset_by_len(train_set)
+    train_set_by_len = data.sort_dataset_by_len(train_set)
 
     best_score = 0
     if args.load_name:
@@ -491,7 +429,7 @@ def train_model(model, model_file, args, train_set, dev_set, labels):
     for epoch in range(args.max_epochs):
         running_loss = 0.0
         epoch_loss = 0.0
-        shuffled = shuffle_dataset(train_set_by_len)
+        shuffled = data.shuffle_dataset(train_set_by_len)
         model.train()
         random.shuffle(batch_starts)
         for batch_num, start_batch in enumerate(batch_starts):
@@ -614,8 +552,8 @@ def build_new_model(args, train_set):
     charmodel_forward = load_charlm(args.charlm_forward_file)
     charmodel_backward = load_charlm(args.charlm_backward_file)
 
-    labels = dataset_labels(train_set)
-    extra_vocab = dataset_vocab(train_set)
+    labels = data.dataset_labels(train_set)
+    extra_vocab = data.dataset_vocab(train_set)
 
     bert_model, bert_tokenizer = load_bert(args.bert_model)
 
@@ -647,7 +585,7 @@ def main(args=None):
     if args.train:
         train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
         logger.info("Using training set: %s" % args.train_file)
-        logger.info("Training set has %d labels" % len(dataset_labels(train_set)))
+        logger.info("Training set has %d labels" % len(data.dataset_labels(train_set)))
         tlogger.setLevel(logging.DEBUG)
     elif not args.load_name:
         if args.save_name:
@@ -683,13 +621,13 @@ def main(args=None):
         logger.info("Using dev set: %s", args.dev_file)
         logger.info("Training set has %d items", len(train_set))
         logger.info("Dev set has %d items", len(dev_set))
-        check_labels(model.labels, dev_set)
+        data.check_labels(model.labels, dev_set)
 
         train_model(model, model_file, args, train_set, dev_set, model.labels)
 
     test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
     logger.info("Using test set: %s" % args.test_file)
-    check_labels(model.labels, test_set)
+    data.check_labels(model.labels, test_set)
 
     if args.test_remap_labels is None:
         confusion_matrix = confusion_dataset(model, test_set)
diff --git a/stanza/models/classifiers/data.py b/stanza/models/classifiers/data.py
index 85d44703..3dfdf283 100644
--- a/stanza/models/classifiers/data.py
+++ b/stanza/models/classifiers/data.py
@@ -1,11 +1,14 @@
 """Stanza models classifier data functions."""
 
+import collections
 import logging
 import json
+import random
 import re
 
 from typing import List
 
 import stanza.models.classifiers.classifier_args as classifier_args
+from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
 
 logger = logging.getLogger('stanza')
@@ -61,3 +64,66 @@ def read_dataset(dataset, wordvec_type: classifier_args.WVType, min_len: int) ->
     if min_len:
         lines = [x for x in lines if len(x[1]) >= min_len]
     return lines
+
+def dataset_labels(dataset):
+    """
+    Returns a sorted list of label name
+    """
+    labels = set([x[0] for x in dataset])
+    if all(re.match("^[0-9]+$", label) for label in labels):
+        # if all of the labels are integers, sort numerically
+        # maybe not super important, but it would be nicer than having
+        # 10 before 2
+        labels = [str(x) for x in sorted(map(int, list(labels)))]
+    else:
+        labels = sorted(list(labels))
+    return labels
+
+def dataset_vocab(dataset):
+    vocab = set()
+    for line in dataset:
+        for word in line[1]:
+            vocab.add(word)
+    vocab = [PAD, UNK] + list(vocab)
+    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
+        raise ValueError("Unexpected values for PAD and UNK!")
+    return vocab
+
+def sort_dataset_by_len(dataset):
+    """
+    returns a dict mapping length -> list of items of that length
+    an OrderedDict is used to that the mapping is sorted from smallest to largest
+    """
+    sorted_dataset = collections.OrderedDict()
+    lengths = sorted(list(set(len(x[1]) for x in dataset)))
+    for l in lengths:
+        sorted_dataset[l] = []
+    for item in dataset:
+        sorted_dataset[len(item[1])].append(item)
+    return sorted_dataset
+
+def shuffle_dataset(sorted_dataset):
+    """
+    Given a dataset sorted by len, sorts within each length to make
+    chunks of roughly the same size. Returns all items as a single list.
+    """
+    dataset = []
+    for l in sorted_dataset.keys():
+        items = list(sorted_dataset[l])
+        random.shuffle(items)
+        dataset.extend(items)
+    return dataset
+
+
+def check_labels(labels, dataset):
+    """
+    Check that all of the labels in the dataset are in the known labels.
+
+    Actually, unknown labels could be acceptable if we just treat the model as always wrong.
+    However, this is a good sanity check to make sure the datasets match
+    """
+    new_labels = dataset_labels(dataset)
+    not_found = [i for i in new_labels if i not in labels]
+    if not_found:
+        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(not_found))
+
diff --git a/stanza/tests/classifiers/test_classifier.py b/stanza/tests/classifiers/test_classifier.py
index 5e170f0f..c21bee60 100644
--- a/stanza/tests/classifiers/test_classifier.py
+++ b/stanza/tests/classifiers/test_classifier.py
@@ -85,7 +85,7 @@ def test_dataset_vocab(train_file):
     Converting a dataset to vocab should have a specific set of words along with PAD and UNK
     """
     train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
-    vocab = classifier.dataset_vocab(train_set)
+    vocab = data.dataset_vocab(train_set)
     expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
     assert set(vocab) == expected
 
@@ -94,7 +94,7 @@ def test_dataset_labels(train_file):
     Test the extraction of labels from a dataset
     """
     train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
-    labels = classifier.dataset_labels(train_set)
+    labels = data.dataset_labels(train_set)
     assert labels == ["0", "1", "2"]
 
 def build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=None):
@@ -125,7 +125,7 @@ def run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=Non
     """
     model, train_set, args = build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args)
     dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
-    labels = classifier.dataset_labels(train_set)
+    labels = data.dataset_labels(train_set)
     save_filename = os.path.join(args.save_dir, args.save_name)
     classifier.train_model(model, save_filename, args, train_set, dev_set, labels)
 
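After this commit, callers reach the helpers through `stanza.models.classifiers.data` instead of `stanza.models.classifier`. Below is a minimal usage sketch of the relocated functions; the three-item toy dataset is hypothetical, but the `(label, [words])` item shape and the function names match the code moved in this diff.

```python
# Minimal sketch of the relocated helpers (toy dataset is made up;
# real items come from data.read_dataset(filename, wordvec_type, min_len)).
import stanza.models.classifiers.data as data

dataset = [
    ("0", ["this", "movie", "was", "bad"]),
    ("1", ["an", "average", "film"]),
    ("2", ["great", "fun"]),
]

labels = data.dataset_labels(dataset)        # ["0", "1", "2"]; numeric labels are sorted numerically
vocab = data.dataset_vocab(dataset)          # [PAD, UNK] + the unique words in the dataset
by_len = data.sort_dataset_by_len(dataset)   # OrderedDict mapping sentence length -> items of that length
shuffled = data.shuffle_dataset(by_len)      # single list, shuffled within each length bucket
data.check_labels(labels, dataset)           # raises RuntimeError if the dataset has unknown labels
```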