Diffstat (limited to 'stanza/models/tokenization/utils.py')
-rw-r--r-- | stanza/models/tokenization/utils.py | 49 |
1 file changed, 33 insertions, 16 deletions
diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py
index 2cd67e6b..ea7bda47 100644
--- a/stanza/models/tokenization/utils.py
+++ b/stanza/models/tokenization/utils.py
@@ -30,7 +30,7 @@ def load_mwt_dict(filename):
 def process_sentence(sentence, mwt_dict=None):
     sent = []
     i = 0
-    for tok, p, additional_info in sentence:
+    for tok, p, position_info in sentence:
         expansion = None
         if (p == 3 or p == 4) and mwt_dict is not None:
             # MWT found, (attempt to) expand it!
@@ -39,20 +39,22 @@ def process_sentence(sentence, mwt_dict=None):
             elif tok.lower() in mwt_dict:
                 expansion = mwt_dict[tok.lower()][0]
         if expansion is not None:
-            infostr = None if len(additional_info) == 0 else '|'.join([f"{k}={additional_info[k]}" for k in additional_info])
             sent.append({ID: (i+1, i+len(expansion)), TEXT: tok})
-            if infostr is not None: sent[-1][MISC] = infostr
+            if position_info is not None:
+                sent[-1][START_CHAR] = position_info[0]
+                sent[-1][END_CHAR] = position_info[1]
             for etok in expansion:
                 sent.append({ID: (i+1, ), TEXT: etok})
                 i += 1
         else:
             if len(tok) <= 0:
                 continue
-            if p == 3 or p == 4:
-                additional_info['MWT'] = 'Yes'
-            infostr = None if len(additional_info) == 0 else '|'.join([f"{k}={additional_info[k]}" for k in additional_info])
             sent.append({ID: (i+1, ), TEXT: tok})
-            if infostr is not None: sent[-1][MISC] = infostr
+            if position_info is not None:
+                sent[-1][START_CHAR] = position_info[0]
+                sent[-1][END_CHAR] = position_info[1]
+            if p == 3 or p == 4:# MARK
+                sent[-1][MISC] = 'MWT=Yes'
             i += 1
     return sent
 
@@ -117,7 +119,7 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
     for i, p in enumerate(data_generator.sentences):
         start = 0 if i == 0 else paragraphs[-1][2]
         length = sum([len(x[0]) for x in p])
-        paragraphs += [(i, start, start+length, length+1)] # para idx, start idx, end idx, length
+        paragraphs += [(i, start, start+length, length)] # para idx, start idx, end idx, length
 
     paragraphs = list(sorted(paragraphs, key=lambda x: x[3], reverse=True))
 
@@ -127,13 +129,15 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
     eval_limit = max(3000, max_seqlen)
 
     batch_size = trainer.args['batch_size']
+    skip_newline = trainer.args['skip_newline']
     batches = int((len(paragraphs) + batch_size - 1) / batch_size)
 
-    t = 0
     for i in range(batches):
+        # At evaluation time, each paragraph is treated as a single "sentence", and a batch of `batch_size` paragraphs
+        # are tokenized together. `offsets` here are used by the data generator to identify which paragraphs to use
+        # for the next batch of evaluation.
         batchparas = paragraphs[i * batch_size : (i + 1) * batch_size]
         offsets = [x[1] for x in batchparas]
-        t += sum([x[3] for x in batchparas])
 
         batch = data_generator.next(eval_offsets=offsets)
         raw = batch[3]
@@ -165,6 +169,8 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
             if all([idx1 >= N for idx1, N in zip(idx, Ns)]):
                 break
 
+            # once we've made predictions on a certain number of characters for each paragraph (recorded in `adv`),
+            # we skip the first `adv` characters to make the updated batch
             batch = data_generator.next(eval_offsets=adv, old_batch=batch)
 
         pred = [np.concatenate(p, 0) for p in pred]
@@ -189,6 +195,10 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
     char_offset = 0
     use_la_ittb_shorthand = trainer.args['shorthand'] == 'la_ittb'
 
+    UNK_ID = vocab.unit2id('<UNK>')
+
+    # Once everything is fed through the tokenizer model, it's time to decode the predictions
+    # into actual tokens and sentences that the rest of the pipeline uses
     for j in range(len(paragraphs)):
         raw = all_raw[j]
         pred = all_preds[j]
@@ -203,7 +213,7 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
             if use_la_ittb_shorthand and t in (":", ";"):
                 p = 2
             offset += 1
-            if vocab.unit2id(t) == vocab.unit2id('<UNK>'):
+            if vocab.unit2id(t) == UNK_ID:
                 oov_count += 1
 
             current_tok += t
@@ -218,15 +228,22 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
                 tok_len = 0
                 for part in SPACE_SPLIT_RE.split(current_tok):
                     if len(part) == 0: continue
-                    st0 = text.index(part, char_offset) - char_offset
+                    if skip_newline:
+                        part_pattern = re.compile(r'\s*'.join(re.escape(c) for c in part))
+                        match = part_pattern.search(text, char_offset)
+                        st0 = match.start(0) - char_offset
+                        partlen = match.end(0) - match.start(0)
+                    else:
+                        st0 = text.index(part, char_offset) - char_offset
+                        partlen = len(part)
                     lstripped = part.lstrip()
                     if st < 0:
                         st = char_offset + st0 + (len(part) - len(lstripped))
-                    char_offset += st0 + len(part)
-                additional_info = {START_CHAR: st, END_CHAR: char_offset}
+                    char_offset += st0 + partlen
+                position_info = (st, char_offset)
                 else:
-                    additional_info = dict()
-                current_sent.append((tok, p, additional_info))
+                    position_info = None
+                current_sent.append((tok, p, position_info))
                 current_tok = ''
                 if (p == 2 or p == 4) and not no_ssplit:
                     doc.append(process_sentence(current_sent, mwt_dict))
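Two themes run through this change. The first is in process_sentence: the old code serialized every entry of the free-form additional_info dict into the MISC column as k=v pairs, while the new code carries a plain (start, end) tuple and writes it to dedicated START_CHAR/END_CHAR fields, leaving only the MWT=Yes flag in MISC. A simplified, runnable sketch of the non-expansion branch (the key constants stand in for the names imported from stanza's common doc module, and the sample triples are invented):

    ID, TEXT, MISC = 'id', 'text', 'misc'
    START_CHAR, END_CHAR = 'start_char', 'end_char'

    def process_sentence_sketch(sentence):
        # Non-expansion branch only: each (token, prediction, position_info)
        # triple becomes one token dict; p == 3 or 4 marks a potential MWT.
        sent = []
        for i, (tok, p, position_info) in enumerate(sentence):
            word = {ID: (i + 1,), TEXT: tok}
            if position_info is not None:
                word[START_CHAR] = position_info[0]
                word[END_CHAR] = position_info[1]
            if p == 3 or p == 4:
                word[MISC] = 'MWT=Yes'
            sent.append(word)
        return sent

    print(process_sentence_sketch([("Cannot", 4, (0, 6)), (".", 2, (6, 7))]))
    # [{'id': (1,), 'text': 'Cannot', 'start_char': 0, 'end_char': 6, 'misc': 'MWT=Yes'},
    #  {'id': (2,), 'text': '.', 'start_char': 6, 'end_char': 7}]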
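Second, the bookkeeping fix in output_predictions: each paragraph is recorded as a (para idx, start idx, end idx, length) tuple, and the stored length previously carried a +1, apparently only to feed the now-removed t counter; with t gone, the true length is kept and paragraphs are sorted longest-first so similarly sized paragraphs batch together. A toy illustration (the sentences value is invented; only the unit lengths matter):

    sentences = [[("ab", 0)], [("abcdef", 0), ("gh", 0)], [("xyz", 0)]]
    paragraphs = []
    for i, p in enumerate(sentences):
        start = 0 if i == 0 else paragraphs[-1][2]
        length = sum(len(x[0]) for x in p)
        paragraphs.append((i, start, start + length, length))  # no more +1
    print(paragraphs)                                            # [(0, 0, 2, 2), (1, 2, 10, 8), (2, 10, 13, 3)]
    print(sorted(paragraphs, key=lambda x: x[3], reverse=True))  # longest first: [(1, ...), (2, ...), (0, ...)]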
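Finally, the skip_newline branch: if newlines were stripped before the text reached the model, a predicted token can straddle a line break in the original raw text, so a plain text.index(part, char_offset) would miss it or report the wrong span. Interleaving \s* between the escaped characters lets the match absorb that whitespace, and partlen then measures the span in the original text. A standalone sketch (the find_part helper and sample text are illustrative, not part of the commit):

    import re

    def find_part(text, part, char_offset, skip_newline):
        # Mirrors the diff's two branches: a whitespace-tolerant regex search
        # vs. a plain substring lookup.  Returns (absolute start, matched length).
        if skip_newline:
            part_pattern = re.compile(r'\s*'.join(re.escape(c) for c in part))
            match = part_pattern.search(text, char_offset)
            return match.start(0), match.end(0) - match.start(0)
        return text.index(part, char_offset), len(part)

    text = "inter-\nnational waters"   # invented raw text with a newline mid-token
    print(find_part(text, "inter-national", 0, True))   # (0, 15): the span absorbs '\n'
    print(find_part(text, "waters", 0, False))          # (16, 6)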