github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>    2022-08-31 02:17:00 +0300
committer John Bauer <horatio@gmail.com>    2022-08-31 03:44:10 +0300
commit    d30b7de0eed2230f4efe7a7ee3f1f7dbb8f2eb0f (patch)
tree      479a3c520ea68a7d4d43a92e0f143b9a3281bf9e
parent    7b476eb85cb7881aa4c33b79c5e419c11116a592 (diff)
Refactor a bunch of data manipulation methods to data.py
-rw-r--r--  stanza/models/classifier.py                  80
-rw-r--r--  stanza/models/classifiers/data.py            66
-rw-r--r--  stanza/tests/classifiers/test_classifier.py   6
3 files changed, 78 insertions, 74 deletions
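
The commit moves five dataset helpers (dataset_labels, dataset_vocab, sort_dataset_by_len, shuffle_dataset, check_labels) out of stanza/models/classifier.py and into stanza/models/classifiers/data.py; every call site in classifier.py now goes through the data module. A minimal sketch of the resulting call pattern, assuming a training file is at hand (the file name and word-vector type below are illustrative, not taken from the diff):

    from stanza.models.classifiers import data
    from stanza.models.classifiers.classifier_args import WVType

    # hypothetical training file; read_dataset and WVType are names taken from the diff
    train_set = data.read_dataset("train.json", WVType.OTHER, min_len=1)

    labels = data.dataset_labels(train_set)                 # sorted label names
    extra_vocab = data.dataset_vocab(train_set)             # [PAD, UNK] + every word seen
    train_set_by_len = data.sort_dataset_by_len(train_set)  # length -> items of that length
    shuffled = data.shuffle_dataset(train_set_by_len)       # flat list, shuffled within each length
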
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index a1db15ac..a2f85966 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -1,6 +1,5 @@
import argparse
import ast
-import collections
import logging
import os
import random
@@ -15,7 +14,6 @@ from stanza.models.common import loss
from stanza.models.common import utils
from stanza.models.common.foundation_cache import load_bert, load_charlm
from stanza.models.common.pretrain import Pretrain
-from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
from stanza.models.pos.vocab import CharVocab
from stanza.models.classifiers.classifier_args import WVType, ExtraVectors
@@ -243,55 +241,6 @@ def parse_args(args=None):
return args
-def dataset_labels(dataset):
-    """
-    Returns a sorted list of label names
-    """
-    labels = set([x[0] for x in dataset])
-    if all(re.match("^[0-9]+$", label) for label in labels):
-        # if all of the labels are integers, sort numerically
-        # maybe not super important, but it would be nicer than having
-        # 10 before 2
-        labels = [str(x) for x in sorted(map(int, list(labels)))]
-    else:
-        labels = sorted(list(labels))
-    return labels
-
-def dataset_vocab(dataset):
-    vocab = set()
-    for line in dataset:
-        for word in line[1]:
-            vocab.add(word)
-    vocab = [PAD, UNK] + list(vocab)
-    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
-        raise ValueError("Unexpected values for PAD and UNK!")
-    return vocab
-
-def sort_dataset_by_len(dataset):
-    """
-    returns a dict mapping length -> list of items of that length
-    an OrderedDict is used so that the mapping is sorted from smallest to largest
-    """
-    sorted_dataset = collections.OrderedDict()
-    lengths = sorted(list(set(len(x[1]) for x in dataset)))
-    for l in lengths:
-        sorted_dataset[l] = []
-    for item in dataset:
-        sorted_dataset[len(item[1])].append(item)
-    return sorted_dataset
-
-def shuffle_dataset(sorted_dataset):
-    """
-    Given a dataset sorted by len, sorts within each length to make
-    chunks of roughly the same size. Returns all items as a single list.
-    """
-    dataset = []
-    for l in sorted_dataset.keys():
-        items = list(sorted_dataset[l])
-        random.shuffle(items)
-        dataset.extend(items)
-    return dataset
-
def confusion_dataset(model, dataset):
"""
Returns a confusion matrix
@@ -303,7 +252,7 @@ def confusion_dataset(model, dataset):
model.eval()
index_label_map = {x: y for (x, y) in enumerate(model.labels)}
- dataset_lengths = sort_dataset_by_len(dataset)
+ dataset_lengths = data.sort_dataset_by_len(dataset)
confusion_matrix = {}
for label in model.labels:
@@ -340,7 +289,7 @@ def score_dataset(model, dataset, label_map=None,
if label_map is None:
label_map = {x: y for (y, x) in enumerate(model.labels)}
correct = 0
- dataset_lengths = sort_dataset_by_len(dataset)
+ dataset_lengths = data.sort_dataset_by_len(dataset)
for length in dataset_lengths.keys():
# TODO: possibly break this up into smaller batches
@@ -395,17 +344,6 @@ def score_dev_set(model, dev_set, dev_eval_scoring):
else:
raise ValueError("Unknown scoring method {}".format(dev_eval_scoring))
-def check_labels(labels, dataset):
-    """
-    Check that all of the labels in the dataset are in the known labels.
-    Actually, unknown labels could be acceptable if we just treat the model as always wrong.
-    However, this is a good sanity check to make sure the datasets match
-    """
-    new_labels = dataset_labels(dataset)
-    not_found = [i for i in new_labels if i not in labels]
-    if not_found:
-        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(not_found))
-
def checkpoint_name(filename, epoch, dev_scoring, score):
"""
Build an informative checkpoint name from a base name, epoch #, and accuracy
@@ -465,7 +403,7 @@ def train_model(model, model_file, args, train_set, dev_set, labels):
raise ValueError("Unknown loss function {}".format(args.loss))
loss_function.to(device)
- train_set_by_len = sort_dataset_by_len(train_set)
+ train_set_by_len = data.sort_dataset_by_len(train_set)
best_score = 0
if args.load_name:
@@ -491,7 +429,7 @@ def train_model(model, model_file, args, train_set, dev_set, labels):
for epoch in range(args.max_epochs):
running_loss = 0.0
epoch_loss = 0.0
- shuffled = shuffle_dataset(train_set_by_len)
+ shuffled = data.shuffle_dataset(train_set_by_len)
model.train()
random.shuffle(batch_starts)
for batch_num, start_batch in enumerate(batch_starts):
@@ -614,8 +552,8 @@ def build_new_model(args, train_set):
charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)
- labels = dataset_labels(train_set)
- extra_vocab = dataset_vocab(train_set)
+ labels = data.dataset_labels(train_set)
+ extra_vocab = data.dataset_vocab(train_set)
bert_model, bert_tokenizer = load_bert(args.bert_model)
@@ -647,7 +585,7 @@ def main(args=None):
if args.train:
train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
logger.info("Using training set: %s" % args.train_file)
- logger.info("Training set has %d labels" % len(dataset_labels(train_set)))
+ logger.info("Training set has %d labels" % len(data.dataset_labels(train_set)))
tlogger.setLevel(logging.DEBUG)
elif not args.load_name:
if args.save_name:
@@ -683,13 +621,13 @@ def main(args=None):
logger.info("Using dev set: %s", args.dev_file)
logger.info("Training set has %d items", len(train_set))
logger.info("Dev set has %d items", len(dev_set))
- check_labels(model.labels, dev_set)
+ data.check_labels(model.labels, dev_set)
train_model(model, model_file, args, train_set, dev_set, model.labels)
test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
logger.info("Using test set: %s" % args.test_file)
- check_labels(model.labels, test_set)
+ data.check_labels(model.labels, test_set)
if args.test_remap_labels is None:
confusion_matrix = confusion_dataset(model, test_set)
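
Inside classifier.py, confusion_dataset, score_dataset, and train_model now call data.sort_dataset_by_len and data.shuffle_dataset for the length-bucketing step. A toy illustration of what those two helpers return (the items below are invented; each item is a (label, word list) pair as produced by read_dataset):

    from stanza.models.classifiers import data

    toy_set = [("0", ["fine"]), ("1", ["not", "bad"]), ("0", ["very", "good"])]

    by_len = data.sort_dataset_by_len(toy_set)
    # OrderedDict mapping length -> items: {1: [("0", ["fine"])], 2: [the two 2-word items]}

    flat = data.shuffle_dataset(by_len)
    # one flat list again: shortest items first, order shuffled within each length
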
diff --git a/stanza/models/classifiers/data.py b/stanza/models/classifiers/data.py
index 85d44703..3dfdf283 100644
--- a/stanza/models/classifiers/data.py
+++ b/stanza/models/classifiers/data.py
@@ -1,11 +1,14 @@
"""Stanza models classifier data functions."""
+import collections
import logging
import json
+import random
import re
from typing import List
import stanza.models.classifiers.classifier_args as classifier_args
+from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
logger = logging.getLogger('stanza')
@@ -61,3 +64,66 @@ def read_dataset(dataset, wordvec_type: classifier_args.WVType, min_len: int) ->
if min_len:
lines = [x for x in lines if len(x[1]) >= min_len]
return lines
+
+def dataset_labels(dataset):
+    """
+    Returns a sorted list of label names
+    """
+    labels = set([x[0] for x in dataset])
+    if all(re.match("^[0-9]+$", label) for label in labels):
+        # if all of the labels are integers, sort numerically
+        # maybe not super important, but it would be nicer than having
+        # 10 before 2
+        labels = [str(x) for x in sorted(map(int, list(labels)))]
+    else:
+        labels = sorted(list(labels))
+    return labels
+
+def dataset_vocab(dataset):
+    vocab = set()
+    for line in dataset:
+        for word in line[1]:
+            vocab.add(word)
+    vocab = [PAD, UNK] + list(vocab)
+    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
+        raise ValueError("Unexpected values for PAD and UNK!")
+    return vocab
+
+def sort_dataset_by_len(dataset):
+    """
+    returns a dict mapping length -> list of items of that length
+    an OrderedDict is used so that the mapping is sorted from smallest to largest
+    """
+    sorted_dataset = collections.OrderedDict()
+    lengths = sorted(list(set(len(x[1]) for x in dataset)))
+    for l in lengths:
+        sorted_dataset[l] = []
+    for item in dataset:
+        sorted_dataset[len(item[1])].append(item)
+    return sorted_dataset
+
+def shuffle_dataset(sorted_dataset):
+    """
+    Given a dataset sorted by len, sorts within each length to make
+    chunks of roughly the same size. Returns all items as a single list.
+    """
+    dataset = []
+    for l in sorted_dataset.keys():
+        items = list(sorted_dataset[l])
+        random.shuffle(items)
+        dataset.extend(items)
+    return dataset
+
+
+def check_labels(labels, dataset):
+    """
+    Check that all of the labels in the dataset are in the known labels.
+
+    Actually, unknown labels could be acceptable if we just treat the model as always wrong.
+    However, this is a good sanity check to make sure the datasets match
+    """
+    new_labels = dataset_labels(dataset)
+    not_found = [i for i in new_labels if i not in labels]
+    if not_found:
+        raise RuntimeError('Dataset contains labels which the model does not know about:' + str(not_found))
+
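
The label helpers keep their old behavior in the new module: dataset_labels sorts numerically when every label looks like an integer, dataset_vocab prepends PAD and UNK, and check_labels raises if the dataset mentions a label the model does not know. A small sketch with an invented dataset:

    from stanza.models.classifiers import data

    toy_set = [("10", ["great", "movie"]), ("2", ["meh"]), ("2", ["it", "was", "fine"])]

    data.dataset_labels(toy_set)        # ["2", "10"] -- numeric sort, not lexicographic
    data.dataset_vocab(toy_set)         # [PAD, UNK, ...] plus every word in the dataset

    data.check_labels(["0", "1", "2"], toy_set)
    # raises RuntimeError: label "10" is not among the model's known labels
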
diff --git a/stanza/tests/classifiers/test_classifier.py b/stanza/tests/classifiers/test_classifier.py
index 5e170f0f..c21bee60 100644
--- a/stanza/tests/classifiers/test_classifier.py
+++ b/stanza/tests/classifiers/test_classifier.py
@@ -85,7 +85,7 @@ def test_dataset_vocab(train_file):
Converting a dataset to vocab should have a specific set of words along with PAD and UNK
"""
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
- vocab = classifier.dataset_vocab(train_set)
+ vocab = data.dataset_vocab(train_set)
expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
assert set(vocab) == expected
@@ -94,7 +94,7 @@ def test_dataset_labels(train_file):
Test the extraction of labels from a dataset
"""
train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
- labels = classifier.dataset_labels(train_set)
+ labels = data.dataset_labels(train_set)
assert labels == ["0", "1", "2"]
def build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=None):
@@ -125,7 +125,7 @@ def run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=Non
"""
model, train_set, args = build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args)
dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
- labels = classifier.dataset_labels(train_set)
+ labels = data.dataset_labels(train_set)
save_filename = os.path.join(args.save_dir, args.save_name)
classifier.train_model(model, save_filename, args, train_set, dev_set, labels)
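
The test changes are mechanical: the same assertions, now pointed at the data module instead of classifier. A hypothetical pytest-style check in the same spirit (not part of the diff):

    import stanza.models.classifiers.data as data

    def test_dataset_labels_toy():
        # invented toy dataset; integer-like labels come back numerically sorted
        toy_set = [("1", ["a"]), ("0", ["b", "c"])]
        assert data.dataset_labels(toy_set) == ["0", "1"]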