Diffstat (limited to 'stanza/models/charlm.py')
-rw-r--r-- | stanza/models/charlm.py | 49
1 file changed, 29 insertions, 20 deletions
diff --git a/stanza/models/charlm.py b/stanza/models/charlm.py
index 6ffc14b8..36345121 100644
--- a/stanza/models/charlm.py
+++ b/stanza/models/charlm.py
@@ -13,6 +13,7 @@ import math
 import logging
 import time
 import os
+import lzma
 
 from stanza.models.common.char_model import CharacterLanguageModel
 from stanza.models.common.vocab import CharVocab
@@ -44,33 +45,41 @@ def get_batch(source, i, seq_len):
     target = source[:, i+1:i+1+seq_len].reshape(-1)
     return data, target
 
+def readlines(path):
+    if path.endswith(".xz"):
+        with lzma.open(path, mode='rt') as fin:
+            lines = fin.readlines()
+    else:
+        with open(path) as fin:
+            lines = fin.readlines() # preserve '\n'
+    return lines
+
 def build_vocab(path, cutoff=0):
     # Requires a large amount of memory, but only need to build once
+
+    # here we need some trick to deal with excessively large files
+    # for each file we accumulate the counter of characters, and
+    # at the end we simply pass a list of chars to the vocab builder
+    counter = Counter()
     if os.path.isdir(path):
-        # here we need some trick to deal with excessively large files
-        # for each file we accumulate the counter of characters, and
-        # at the end we simply pass a list of chars to the vocab builder
-        counter = Counter()
         filenames = sorted(os.listdir(path))
-        for filename in filenames:
-            lines = open(path + '/' + filename).readlines()
-            for line in lines:
-                counter.update(list(line))
-        # remove infrequent characters from vocab
-        for k in list(counter.keys()):
-            if counter[k] < cutoff:
-                del counter[k]
-        # a singleton list of all characters
-        data = [sorted([x[0] for x in counter.most_common()])]
-        vocab = CharVocab(data) # skip cutoff argument because this has been dealt with
     else:
-        lines = open(path).readlines() # reserve '\n'
-        data = [list(line) for line in lines]
-        vocab = CharVocab(data, cutoff=cutoff)
+        filenames = [path]
+    for filename in filenames:
+        lines = readlines(path + '/' + filename)
+        for line in lines:
+            counter.update(list(line))
+    # remove infrequent characters from vocab
+    for k in list(counter.keys()):
+        if counter[k] < cutoff:
+            del counter[k]
+    # a singleton list of all characters
+    data = [sorted([x[0] for x in counter.most_common()])]
+    vocab = CharVocab(data) # skip cutoff argument because this has been dealt with
     return vocab
 
 def load_file(path, vocab, direction):
-    lines = open(path).readlines() # reserve '\n'
+    lines = readlines(path)
     data = list(''.join(lines))
     idx = vocab['char'].map(data)
     if direction == 'backward': idx = idx[::-1]
@@ -90,7 +99,7 @@ def load_data(path, vocab, direction):
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--train_file', type=str, help="Input plaintext file")
-    parser.add_argument('--train_dir', type=str, help="If non-emtpy, load from directory with multiple training files")
+    parser.add_argument('--train_dir', type=str, help="If non-empty, load from directory with multiple training files")
     parser.add_argument('--eval_file', type=str, help="Input plaintext file for the dev/test set")
     parser.add_argument('--lang', type=str, help="Language")
     parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
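
For context, the sketch below (not part of the commit) shows the new readlines() helper in use: it picks lzma.open for .xz files and a plain open otherwise, so build_vocab can accumulate a single Counter over a mix of plain and compressed training files. The temporary directory and sample file contents are hypothetical; note the sketch joins directory and filename explicitly, whereas the diff passes path + '/' + filename even in the single-file branch where filenames = [path].

    import lzma
    import os
    import tempfile
    from collections import Counter

    def readlines(path):
        # choose the opener by extension, as the commit's helper does
        if path.endswith(".xz"):
            with lzma.open(path, mode='rt') as fin:
                return fin.readlines()
        with open(path) as fin:
            return fin.readlines()  # readlines() keeps the trailing '\n'

    with tempfile.TemporaryDirectory() as train_dir:
        # one plain and one xz-compressed training file (hypothetical contents)
        with open(os.path.join(train_dir, "a.txt"), "w") as fout:
            fout.write("abc\n")
        with lzma.open(os.path.join(train_dir, "b.txt.xz"), mode='wt') as fout:
            fout.write("abd\n")

        # accumulate character counts across files, as build_vocab now does
        counter = Counter()
        for filename in sorted(os.listdir(train_dir)):
            for line in readlines(os.path.join(train_dir, filename)):
                counter.update(list(line))
        assert counter == Counter({'a': 2, 'b': 2, '\n': 2, 'c': 1, 'd': 1})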