github.com/stanfordnlp/stanza.git

Diffstat (limited to 'stanza/utils/datasets/ner/prepare_ner_file.py')
 stanza/utils/datasets/ner/prepare_ner_file.py | 74 ++++++++++++++++++
 1 file changed, 74 insertions(+), 0 deletions(-)
diff --git a/stanza/utils/datasets/ner/prepare_ner_file.py b/stanza/utils/datasets/ner/prepare_ner_file.py
new file mode 100644
index 00000000..e5fe2220
--- /dev/null
+++ b/stanza/utils/datasets/ner/prepare_ner_file.py
@@ -0,0 +1,74 @@
+"""
+This script converts NER data from the CoNLL03 format to the latest CoNLL-U format. The script assumes that in the
+input column format data, the token is always in the first column, while the NER tag is always in the last column.
+"""
+
+import argparse
+import json
+
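+# A valid data row has at least 2 columns (token and NER tag) and at most 5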
+MIN_NUM_FIELD = 2
+MAX_NUM_FIELD = 5
+
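+# CoNLL03 files mark document boundaries with -DOCSTART- lines, which carry no sentence content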
+DOC_START_TOKEN = '-DOCSTART-'
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Convert conll03 format data into the json format used by the NER model.")
+    parser.add_argument('input', help='Input conll03 format data filename.')
+    parser.add_argument('output', help='Output json filename.')
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = parse_args()
+    process_dataset(args.input, args.output)
+
+def process_dataset(input_filename, output_filename):
+    sentences = load_conll03(input_filename)
+    print("{} examples loaded from {}".format(len(sentences), input_filename))
+
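+    # Each sentence becomes a list of {'text': ..., 'ner': ...} dicts; the document is the list of all sentences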
+    document = []
+    for (words, tags) in sentences:
+        sent = []
+        for w, t in zip(words, tags):
+            sent.append({'text': w, 'ner': t})
+        document.append(sent)
+
+    with open(output_filename, 'w') as outfile:
+        json.dump(document, outfile)
+    print("Generated json file {}.".format(output_filename))
+
+# TODO: make skip_doc_start an argument
+def load_conll03(filename, skip_doc_start=True):
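+    """Read a CoNLL03-format file and return a list of (tokens, ner_tags) pairs, one per sentence.
+
+    Sentences are separated by blank lines; -DOCSTART- markers are skipped when skip_doc_start is True.
+    """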
+    cached_lines = []
+    examples = []
+    with open(filename) as infile:
+        for line in infile:
+            line = line.strip()
+            if skip_doc_start and DOC_START_TOKEN in line:
+                continue
+            if len(line) > 0:
+                array = line.split()
+                if len(array) >= MIN_NUM_FIELD:
+                    cached_lines.append(line)
+            elif len(cached_lines) > 0:
+                example = process_cache(cached_lines)
+                examples.append(example)
+                cached_lines = []
+    if len(cached_lines) > 0:
+        examples.append(process_cache(cached_lines))
+    return examples
+
+def process_cache(cached_lines):
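+    """Split cached column lines into parallel token and NER tag lists (first column = token, last column = NER tag)."""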
+    tokens = []
+    ner_tags = []
+    for line in cached_lines:
+        array = line.split()
+        assert MIN_NUM_FIELD <= len(array) <= MAX_NUM_FIELD
+        tokens.append(array[0])
+        ner_tags.append(array[-1])
+    return (tokens, ner_tags)
+
+if __name__ == '__main__':
+    main()
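
Usage (a minimal sketch; the input and output filenames here are hypothetical):

    python stanza/utils/datasets/ner/prepare_ner_file.py eng.conll03 eng.json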