diff options
author | John Bauer <horatio@gmail.com> | 2022-10-28 09:28:01 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-10-28 09:28:01 +0300 |
commit | bf4204f948dd8db716acfa000e20c8576abd6734 (patch) | |
tree | 41faf1d576497c4c95714daf2f39b83827183331 | |
parent | 9f5a85298630513ebc2b727529002a75e5fc1486 (diff) |
Add a rotation to make N non-overlapping dev sets with the remainder being train for vlsp22
-rw-r--r-- | stanza/utils/datasets/constituency/prepare_con_dataset.py | 13 | ||||
-rw-r--r-- | stanza/utils/datasets/constituency/vtb_split.py | 5 |
2 files changed, 15 insertions, 3 deletions
diff --git a/stanza/utils/datasets/constituency/prepare_con_dataset.py b/stanza/utils/datasets/constituency/prepare_con_dataset.py index 01426f61..92fe559e 100644 --- a/stanza/utils/datasets/constituency/prepare_con_dataset.py +++ b/stanza/utils/datasets/constituency/prepare_con_dataset.py @@ -171,6 +171,7 @@ def process_vlsp22(paths, dataset_name, *args): parser = argparse.ArgumentParser() parser.add_argument('--subdir', default='VLSP_2022', type=str, help='Where to find the data - allows for using previous versions, if needed') parser.add_argument('--no_convert_brackets', default=True, action='store_false', dest='convert_brackets', help="Don't convert the VLSP parens RKBT & LKBT to PTB parens") + parser.add_argument('--n_splits', default=None, type=int, help='Split the data into this many pieces. Relevant as there is no set training/dev split and no official test data yet, so this allows for N models on N different dev sets') args = parser.parse_args(args=list(*args)) if os.path.exists(args.subdir): @@ -189,8 +190,16 @@ def process_vlsp22(paths, dataset_name, *args): with tempfile.TemporaryDirectory() as tmp_output_path: vtb_convert.convert_files(vlsp_files, tmp_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets) # This produces a 0 length test set, just as a placeholder until the actual test set is released - vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0.9, dev_size=0.1) - _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], dataset_name) + if args.n_splits: + dev_size = 1.0 / args.n_splits + train_size = 1.0 - dev_size + for rotation in range(args.n_splits): + rotation_name = "%s-%d" % (dataset_name, rotation) + vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], rotation_name, train_size=train_size, dev_size=dev_size, rotation=(rotation, args.n_splits)) + _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], rotation_name) + else: + vtb_split.split_files(tmp_output_path, paths["CONSTITUENCY_DATA_DIR"], dataset_name, train_size=0.9, dev_size=0.1) + _, _, test_file = vtb_split.create_paths(paths["CONSTITUENCY_DATA_DIR"], dataset_name) with open(test_file, "w"): # create an empty test file - currently we don't have actual test data for VLSP 21 pass diff --git a/stanza/utils/datasets/constituency/vtb_split.py b/stanza/utils/datasets/constituency/vtb_split.py index d1539d9f..842eb41f 100644 --- a/stanza/utils/datasets/constituency/vtb_split.py +++ b/stanza/utils/datasets/constituency/vtb_split.py @@ -67,7 +67,7 @@ def get_num_samples(org_dir, file_names): return count -def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15): +def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None): os.makedirs(split_dir, exist_ok=True) if train_size + dev_size >= 1.0: @@ -108,6 +108,9 @@ def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0. new_trees = [x.strip() for x in new_trees] new_trees = [x for x in new_trees if x] trees.extend(new_trees) + if rotation is not None and rotation[0] > 0: + rotation_start = len(trees) * rotation[0] // rotation[1] + trees = trees[rotation_start:] + trees[:rotation_start] tree_iter = iter(trees) for write_path, count_limit in zip(output_names, output_limits): with open(write_path, 'w', encoding='utf-8') as writer: |