github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>   2022-10-28 09:28:01 +0300
committer  John Bauer <horatio@gmail.com>   2022-10-28 09:28:01 +0300
commit     bf4204f948dd8db716acfa000e20c8576abd6734 (patch)
tree       41faf1d576497c4c95714daf2f39b83827183331
parent     9f5a85298630513ebc2b727529002a75e5fc1486 (diff)
Add a rotation to make N non-overlapping dev sets with the remainder being train for vlsp22
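With --n_splits N, each rotation uses 1/N of the trees as its dev set and the remaining (N-1)/N as train: rotation k cyclically shifts the tree list by k/N before the usual in-order split, so the N dev sets cover disjoint contiguous slices of the corpus. The test file is still written empty as a placeholder until official VLSP test data is released.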
 stanza/utils/datasets/constituency/prepare_con_dataset.py | 13 +++++++++++--
 stanza/utils/datasets/constituency/vtb_split.py           |  5 ++++-
 2 files changed, 15 insertions(+), 3 deletions(-)
diff --git a/stanza/utils/datasets/constituency/prepare_con_dataset.py b/stanza/utils/datasets/constituency/prepare_con_dataset.py
index 01426f61..92fe559e 100644
--- a/stanza/utils/datasets/constituency/prepare_con_dataset.py
+++ b/stanza/utils/datasets/constituency/prepare_con_dataset.py
@@ -171,6 +171,7 @@ def process_vlsp22(paths, dataset_name, *args):
     parser = argparse.ArgumentParser()
     parser.add_argument('--subdir', default='VLSP_2022', type=str, help='Where to find the data - allows for using previous versions, if needed')
     parser.add_argument('--no_convert_brackets', default=True, action='store_false', dest='convert_brackets', help="Don't convert the VLSP parens RKBT & LKBT to PTB parens")
+    parser.add_argument('--n_splits', default=None, type=int, help='Split the data into this many pieces. Relevant as there is no set training/dev split and no official test data yet, so this allows for N models on N different dev sets')
     args = parser.parse_args(args=list(*args))
 
     if os.path.exists(args.subdir):
@@ -189,8 +190,16 @@ def process_vlsp22(paths, dataset_name, *args):
     with tempfile.TemporaryDirectory() as tmp_output_path:
         vtb_convert.convert_files(vlsp_files, tmp_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets)
         # This produces a 0 length test set, just as a placeholder until the actual test set is released
-        vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0.9, dev_size=0.1)
-        _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], dataset_name)
+        if args.n_splits:
+            dev_size = 1.0 / args.n_splits
+            train_size = 1.0 - dev_size
+            for rotation in range(args.n_splits):
+                rotation_name = "%s-%d" % (dataset_name, rotation)
+                vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=train_size, dev_size=dev_size, rotation=(rotation, args.n_splits))
+                _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], rotation_name)
+        else:
+            vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0.9, dev_size=0.1)
+            _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], dataset_name)
         with open(test_file, "w"):
             # create an empty test file - currently we don't have actual test data for VLSP 21
             pass
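Presumably this is driven through the module's usual entry point, e.g. "python -m stanza.utils.datasets.constituency.prepare_con_dataset vi_vlsp22 --n_splits 5" (the vi_vlsp22 short name is an assumption here). Per the rotation_name pattern above, that would write five datasets, vi_vlsp22-0 through vi_vlsp22-4, each split with train_size=0.8 and dev_size=0.2.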
diff --git a/stanza/utils/datasets/constituency/vtb_split.py b/stanza/utils/datasets/constituency/vtb_split.py
index d1539d9f..842eb41f 100644
--- a/stanza/utils/datasets/constituency/vtb_split.py
+++ b/stanza/utils/datasets/constituency/vtb_split.py
@@ -67,7 +67,7 @@ def get_num_samples(org_dir, file_names):
     return count
 
 
-def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15):
+def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None):
     os.makedirs(split_dir, exist_ok=True)
 
     if train_size + dev_size >= 1.0:
@@ -108,6 +108,9 @@ def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15):
             new_trees = [x.strip() for x in new_trees]
             new_trees = [x for x in new_trees if x]
             trees.extend(new_trees)
+    if rotation is not None and rotation[0] > 0:
+        rotation_start = len(trees) * rotation[0] // rotation[1]
+        trees = trees[rotation_start:] + trees[:rotation_start]
     tree_iter = iter(trees)
     for write_path, count_limit in zip(output_names, output_limits):
         with open(write_path, 'w', encoding='utf-8') as writer:
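
As a sanity check on the rotation arithmetic, here is a minimal standalone sketch (not part of the patch; rotated_split is a hypothetical helper, and it assumes n_splits <= len(trees)) of why the cyclic shift yields N pairwise disjoint dev sets:

def rotated_split(trees, rotation, n_splits):
    """Mimic split_files with dev_size = 1/n_splits and an empty test set."""
    # same cyclic shift as the patched split_files above
    rotation_start = len(trees) * rotation // n_splits
    rotated = trees[rotation_start:] + trees[:rotation_start]
    # the real code turns the train/dev fractions into counts; here the
    # dev block is simply the trailing 1/n_splits of the rotated list
    n_dev = len(trees) // n_splits
    return rotated[:-n_dev], rotated[-n_dev:]   # (train, dev)

# With 10 trees and n_splits=5, the five dev sets come out as
# [8, 9], [0, 1], [2, 3], [4, 5], [6, 7]: each tree lands in exactly
# one dev set, and the other 80% of each rotation is train.
trees = list(range(10))
dev_sets = [rotated_split(trees, k, 5)[1] for k in range(5)]
assert sorted(x for dev in dev_sets for x in dev) == trees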