github.com/stanfordnlp/stanza.git
Diffstat (limited to 'stanza/models/tokenization/utils.py')
-rw-r--r--  stanza/models/tokenization/utils.py  49
1 file changed, 33 insertions(+), 16 deletions(-)
diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py
index 2cd67e6b..ea7bda47 100644
--- a/stanza/models/tokenization/utils.py
+++ b/stanza/models/tokenization/utils.py
@@ -30,7 +30,7 @@ def load_mwt_dict(filename):
def process_sentence(sentence, mwt_dict=None):
sent = []
i = 0
- for tok, p, additional_info in sentence:
+ for tok, p, position_info in sentence:
expansion = None
if (p == 3 or p == 4) and mwt_dict is not None:
# MWT found, (attempt to) expand it!
@@ -39,20 +39,22 @@ def process_sentence(sentence, mwt_dict=None):
elif tok.lower() in mwt_dict:
expansion = mwt_dict[tok.lower()][0]
if expansion is not None:
- infostr = None if len(additional_info) == 0 else '|'.join([f"{k}={additional_info[k]}" for k in additional_info])
sent.append({ID: (i+1, i+len(expansion)), TEXT: tok})
- if infostr is not None: sent[-1][MISC] = infostr
+ if position_info is not None:
+ sent[-1][START_CHAR] = position_info[0]
+ sent[-1][END_CHAR] = position_info[1]
for etok in expansion:
sent.append({ID: (i+1, ), TEXT: etok})
i += 1
else:
if len(tok) <= 0:
continue
- if p == 3 or p == 4:
- additional_info['MWT'] = 'Yes'
- infostr = None if len(additional_info) == 0 else '|'.join([f"{k}={additional_info[k]}" for k in additional_info])
sent.append({ID: (i+1, ), TEXT: tok})
- if infostr is not None: sent[-1][MISC] = infostr
+ if position_info is not None:
+ sent[-1][START_CHAR] = position_info[0]
+ sent[-1][END_CHAR] = position_info[1]
+ if p == 3 or p == 4:  # this token is (part of) a multi-word token
+ sent[-1][MISC] = 'MWT=Yes'
i += 1
return sent
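
For reference, a minimal sketch (not part of the patch) of the per-token handling after this change: character offsets become dedicated START_CHAR/END_CHAR fields on the token instead of being serialized into a MISC "key=value" string, and the MWT flag is written directly as MISC='MWT=Yes'. The field-name strings below stand in for the constants imported in utils.py; make_token is a hypothetical helper used only for illustration.

ID, TEXT, MISC, START_CHAR, END_CHAR = 'id', 'text', 'misc', 'start_char', 'end_char'

def make_token(i, tok, p, position_info):
    # one non-expanded token; labels 3 and 4 mark tokens that participate in an MWT
    token = {ID: (i + 1,), TEXT: tok}
    if position_info is not None:
        token[START_CHAR] = position_info[0]   # offset of the first character
        token[END_CHAR] = position_info[1]     # offset one past the last character
    if p == 3 or p == 4:
        token[MISC] = 'MWT=Yes'
    return token

assert make_token(0, 'del', 3, (10, 13)) == {
    'id': (1,), 'text': 'del', 'start_char': 10, 'end_char': 13, 'misc': 'MWT=Yes'
}
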
@@ -117,7 +119,7 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
for i, p in enumerate(data_generator.sentences):
start = 0 if i == 0 else paragraphs[-1][2]
length = sum([len(x[0]) for x in p])
- paragraphs += [(i, start, start+length, length+1)] # para idx, start idx, end idx, length
+ paragraphs += [(i, start, start+length, length)] # para idx, start idx, end idx, length
paragraphs = list(sorted(paragraphs, key=lambda x: x[3], reverse=True))
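
A toy illustration (not the committed code) of the paragraph bookkeeping this hunk fixes: each entry is (paragraph index, start offset, end offset, length), and the stored length now equals end - start instead of carrying a stray +1. `sentences` below is a stand-in for `data_generator.sentences`, with plain strings in place of the real per-unit structures.

sentences = [["abc", "de"], ["fgh"]]       # two "paragraphs" of tokenizer units
paragraphs = []
for i, p in enumerate(sentences):
    start = 0 if i == 0 else paragraphs[-1][2]
    length = sum(len(x) for x in p)
    paragraphs.append((i, start, start + length, length))

# longest paragraphs first, presumably so each eval batch holds similarly sized paragraphs
paragraphs = sorted(paragraphs, key=lambda x: x[3], reverse=True)
assert paragraphs == [(0, 0, 5, 5), (1, 5, 8, 3)]
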
@@ -127,13 +129,15 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
eval_limit = max(3000, max_seqlen)
batch_size = trainer.args['batch_size']
+ skip_newline = trainer.args['skip_newline']
batches = int((len(paragraphs) + batch_size - 1) / batch_size)
- t = 0
for i in range(batches):
+ # At evaluation time, each paragraph is treated as a single "sentence", and batches of `batch_size` paragraphs
+ # are tokenized together. `offsets` here are used by the data generator to identify which paragraphs to use
+ # for the next batch of evaluation.
batchparas = paragraphs[i * batch_size : (i + 1) * batch_size]
offsets = [x[1] for x in batchparas]
- t += sum([x[3] for x in batchparas])
batch = data_generator.next(eval_offsets=offsets)
raw = batch[3]
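
Continuing the toy paragraphs list above: the offsets handed to the data generator are just the start positions of the paragraphs chosen for this batch (the unused running total `t` is what this hunk removes).

batch_size = 2
i = 0                                         # first (and here only) batch
batchparas = paragraphs[i * batch_size : (i + 1) * batch_size]
offsets = [x[1] for x in batchparas]          # start offset of each selected paragraph
assert offsets == [0, 5]
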
@@ -165,6 +169,8 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
if all([idx1 >= N for idx1, N in zip(idx, Ns)]):
break
+ # once we've made predictions on a certain number of characters for each paragraph (recorded in `adv`),
+ # we skip the first `adv` characters to make the updated batch
batch = data_generator.next(eval_offsets=adv, old_batch=batch)
pred = [np.concatenate(p, 0) for p in pred]
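
And a self-contained toy of the windowed decoding the new comment describes: long paragraphs are consumed one window at a time, partial predictions are accumulated per paragraph, and `adv` (the number of characters already predicted in each paragraph) tells the generator where to resume. The window size, paragraph lengths, and the fake "model" step are made up for illustration.

import numpy as np

eval_limit = 4                                     # stanza uses max(3000, max_seqlen)
paragraph_lengths = [10, 7, 3]                     # characters per paragraph in this batch
pred = [[] for _ in paragraph_lengths]             # accumulated prediction chunks
idx = [0] * len(paragraph_lengths)                 # characters predicted so far

while not all(i >= n for i, n in zip(idx, paragraph_lengths)):
    for j, n in enumerate(paragraph_lengths):
        step = min(eval_limit, n - idx[j])
        if step > 0:
            pred[j].append(np.zeros(step, dtype=int))   # stand-in for model output
            idx[j] += step
    adv = list(idx)    # the real loop calls data_generator.next(eval_offsets=adv, old_batch=batch)

pred = [np.concatenate(p, 0) for p in pred]
assert [len(p) for p in pred] == paragraph_lengths
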
@@ -189,6 +195,10 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
char_offset = 0
use_la_ittb_shorthand = trainer.args['shorthand'] == 'la_ittb'
+ UNK_ID = vocab.unit2id('<UNK>')
+
+ # Once everything is fed through the tokenizer model, it's time to decode the predictions
+ # into actual tokens and sentences that the rest of the pipeline uses
for j in range(len(paragraphs)):
raw = all_raw[j]
pred = all_preds[j]
@@ -203,7 +213,7 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
if use_la_ittb_shorthand and t in (":", ";"):
p = 2
offset += 1
- if vocab.unit2id(t) == vocab.unit2id('<UNK>'):
+ if vocab.unit2id(t) == UNK_ID:
oov_count += 1
current_tok += t
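
A toy sketch of the decoding loop these hunks belong to, using the label convention inferred from this file (not restated documentation): p == 0 continues the current token, p >= 1 closes a token, p in (2, 4) also closes the sentence, and p in (3, 4) marks a multi-word token. The text and labels below are made up.

text = "Hi there."
preds = [0, 1, 0, 0, 0, 0, 0, 0, 2]         # one predicted label per character
doc, current_sent, current_tok = [], [], ''
for ch, p in zip(text, preds):
    current_tok += ch
    if p >= 1:
        tok = current_tok.strip()
        if tok:
            current_sent.append((tok, p))
        current_tok = ''
        if p == 2 or p == 4:                # sentence boundary
            doc.append(current_sent)
            current_sent = []
assert doc == [[('Hi', 1), ('there.', 2)]]
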
@@ -218,15 +228,22 @@ def output_predictions(output_file, trainer, data_generator, vocab, mwt_dict, ma
tok_len = 0
for part in SPACE_SPLIT_RE.split(current_tok):
if len(part) == 0: continue
- st0 = text.index(part, char_offset) - char_offset
+ if skip_newline:
+ part_pattern = re.compile(r'\s*'.join(re.escape(c) for c in part))
+ match = part_pattern.search(text, char_offset)
+ st0 = match.start(0) - char_offset
+ partlen = match.end(0) - match.start(0)
+ else:
+ st0 = text.index(part, char_offset) - char_offset
+ partlen = len(part)
lstripped = part.lstrip()
if st < 0:
st = char_offset + st0 + (len(part) - len(lstripped))
- char_offset += st0 + len(part)
- additional_info = {START_CHAR: st, END_CHAR: char_offset}
+ char_offset += st0 + partlen
+ position_info = (st, char_offset)
else:
- additional_info = dict()
- current_sent.append((tok, p, additional_info))
+ position_info = None
+ current_sent.append((tok, p, position_info))
current_tok = ''
if (p == 2 or p == 4) and not no_ssplit:
doc.append(process_sentence(current_sent, mwt_dict))
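
Finally, a toy demonstration (separate from the patch) of the whitespace-tolerant matching added for skip_newline: when newlines are dropped before tokenization, a token's text may not appear verbatim in the original document, so its characters are joined with r'\s*' and searched as a pattern, and the matched span's length (which may include the swallowed whitespace) is what advances char_offset.

import re

text = "to-\nken and more"                  # raw document text, with a newline inside the token
part = "to-ken"                             # the token text as the tokenizer produced it
char_offset = 0

part_pattern = re.compile(r'\s*'.join(re.escape(c) for c in part))
match = part_pattern.search(text, char_offset)
st0 = match.start(0) - char_offset          # where the token starts, relative to char_offset
partlen = match.end(0) - match.start(0)     # its length in the raw text

assert text[match.start(0):match.end(0)] == "to-\nken"
assert partlen == len(part) + 1             # one extra character for the newline
# without skip_newline, text.index(part, char_offset) would raise ValueError here
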