diff options
author | Jean Senellart <jean@senellart.com> | 2017-09-29 07:18:41 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-29 07:18:41 +0300 |
commit | 57558f2ec98a5ca5af8e3226500fa67421e27182 (patch) | |
tree | 616142cc438d70bbf15f9599dc08f8be38fdfc30 /onmt | |
parent | 08e15964e4cf26f3882461fae138302e42053684 (diff) |
Fix384 (#387)
Automatically build the case feature vocabulary dictionary if it is missing
Diffstat (limited to 'onmt')
-rw-r--r-- | onmt/data/Preprocessor.lua | 7 | ||||
-rw-r--r-- | onmt/data/Vocabulary.lua | 25 |
2 files changed, 26 insertions, 6 deletions
diff --git a/onmt/data/Preprocessor.lua b/onmt/data/Preprocessor.lua index e909102d..b9aee618 100644 --- a/onmt/data/Preprocessor.lua +++ b/onmt/data/Preprocessor.lua @@ -922,7 +922,9 @@ function Preprocessor:getVocabulary() self.args.features_vocabs_prefix, function(s) return isValid(s, self.args.src_seq_length or self.args.seq_length) end, self.args.keep_frequency, - self.args.idx_files) + self.args.idx_files, + self.args.tok_src_case_feature) + end if self.dataType ~= 'monotext' then -- use the first target file to count target features @@ -935,7 +937,8 @@ function Preprocessor:getVocabulary() self.args.features_vocabs_prefix, function(s) return isValid(s, self.args.tgt_seq_length) end, self.args.keep_frequency, - self.args.idx_files) + self.args.idx_files, + self.args.tok_tgt_case_feature) end return dicts end diff --git a/onmt/data/Vocabulary.lua b/onmt/data/Vocabulary.lua index 7bd69d9d..4a42c0dd 100644 --- a/onmt/data/Vocabulary.lua +++ b/onmt/data/Vocabulary.lua @@ -1,4 +1,5 @@ local path = require('pl.path') +local case = require('tools.utils.case') --[[ Vocabulary management utility functions. ]] local Vocabulary = torch.class("Vocabulary") @@ -66,10 +67,15 @@ function Vocabulary.make(filename, validFunc, idxFile) return wordVocab, featuresVocabs end -function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency, featuresVocabsFiles, validFunc, keepFrequency, idxFile) +function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency, featuresVocabsFiles, validFunc, keepFrequency, idxFile, case_feature) local wordVocab local featuresVocabs = {} local numFeatures = countFeatures(dataFile, idxFile) + local correctedNumFeatures = numFeatures + + if numFeatures == 0 and case_feature then + correctedNumFeatures = 1 + end if vocabFile:len() > 0 then -- If given, load existing word dictionary. @@ -79,7 +85,7 @@ function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency _G.logger:info(' * Loaded ' .. 
wordVocab:size() .. ' ' .. name .. ' words') end - if featuresVocabsFiles:len() > 0 and numFeatures > 0 then + if featuresVocabsFiles:len() > 0 and correctedNumFeatures > 0 then -- If given, discover existing features dictionaries. local j = 1 @@ -100,11 +106,22 @@ function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency assert(#featuresVocabs > 0, 'dictionary \'' .. featuresVocabsFiles .. '.' .. name .. '_feature_1.dict\' not found') - assert(#featuresVocabs == numFeatures, - 'the data contains ' .. numFeatures .. ' ' .. name + assert(#featuresVocabs == correctedNumFeatures, + 'the data contains ' .. correctedNumFeatures .. ' ' .. name .. ' features but only ' .. #featuresVocabs .. ' dictionaries were found') end + if #featuresVocabs == 0 and case_feature then + -- build default case feature + _G.logger:info(' * Building default case feature vocabularies...') + featuresVocabs[1] = onmt.utils.Dict.new() + local regCaseFeat = case.getFeatures() + featuresVocabs[1]:addSpecials({onmt.Constants.PAD, onmt.Constants.UNK, onmt.Constants.BOS, onmt.Constants.EOS}) + for _, f in ipairs(regCaseFeat) do + featuresVocabs[1]:add(f) + end + end + if wordVocab == nil or keepFrequency or (#featuresVocabs == 0 and numFeatures > 0) then -- If a dictionary is still missing, generate it. _G.logger:info(' * Building ' .. name .. ' vocabularies...') |