Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/onmt
diff options
context:
space:
mode:
author: Jean Senellart <jean@senellart.com> 2017-10-01 11:28:33 +0300
committer: GitHub <noreply@github.com> 2017-10-01 11:28:33 +0300
commit: b3b26fc92673b425e5f474196fca12412e18954c (patch)
tree: 6547c25bdb57bd9fd57720730d1fa9cc5f71184c /onmt
parent: ff99f9127c94ecf55312f2cb678fe968c3609d9a (diff)
Fix #384 (#389)
* fix #384
Diffstat (limited to 'onmt')
-rw-r--r-- onmt/data/Preprocessor.lua | 6
-rw-r--r-- onmt/data/Vocabulary.lua | 25
2 files changed, 25 insertions, 6 deletions
diff --git a/onmt/data/Preprocessor.lua b/onmt/data/Preprocessor.lua
index e909102d..0bb50b96 100644
--- a/onmt/data/Preprocessor.lua
+++ b/onmt/data/Preprocessor.lua
@@ -922,7 +922,8 @@ function Preprocessor:getVocabulary()
self.args.features_vocabs_prefix,
function(s) return isValid(s, self.args.src_seq_length or self.args.seq_length) end,
self.args.keep_frequency,
- self.args.idx_files)
+ self.args.idx_files,
+ self.args.tok_src_case_feature)
end
if self.dataType ~= 'monotext' then
-- use the first target file to count target features
@@ -935,7 +936,8 @@ function Preprocessor:getVocabulary()
self.args.features_vocabs_prefix,
function(s) return isValid(s, self.args.tgt_seq_length) end,
self.args.keep_frequency,
- self.args.idx_files)
+ self.args.idx_files,
+ self.args.tok_tgt_case_feature)
end
return dicts
end
diff --git a/onmt/data/Vocabulary.lua b/onmt/data/Vocabulary.lua
index 7bd69d9d..a445b58e 100644
--- a/onmt/data/Vocabulary.lua
+++ b/onmt/data/Vocabulary.lua
@@ -1,4 +1,5 @@
local path = require('pl.path')
+local case = require('tools.utils.case')
--[[ Vocabulary management utility functions. ]]
local Vocabulary = torch.class("Vocabulary")
@@ -66,10 +67,15 @@ function Vocabulary.make(filename, validFunc, idxFile)
return wordVocab, featuresVocabs
end
-function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency, featuresVocabsFiles, validFunc, keepFrequency, idxFile)
+function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency, featuresVocabsFiles, validFunc, keepFrequency, idxFile, case_feature)
local wordVocab
local featuresVocabs = {}
local numFeatures = countFeatures(dataFile, idxFile)
+ local correctedNumFeatures = numFeatures
+
+ if numFeatures == 0 and case_feature then
+ correctedNumFeatures = 1
+ end
if vocabFile:len() > 0 then
-- If given, load existing word dictionary.
@@ -79,7 +85,7 @@ function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency
_G.logger:info(' * Loaded ' .. wordVocab:size() .. ' ' .. name .. ' words')
end
- if featuresVocabsFiles:len() > 0 and numFeatures > 0 then
+ if featuresVocabsFiles:len() > 0 and correctedNumFeatures > 0 then
-- If given, discover existing features dictionaries.
local j = 1
@@ -100,11 +106,22 @@ function Vocabulary.init(name, dataFile, vocabFile, vocabSize, wordsMinFrequency
assert(#featuresVocabs > 0,
'dictionary \'' .. featuresVocabsFiles .. '.' .. name .. '_feature_1.dict\' not found')
- assert(#featuresVocabs == numFeatures,
- 'the data contains ' .. numFeatures .. ' ' .. name
+ assert(#featuresVocabs == correctedNumFeatures,
+ 'the data contains ' .. correctedNumFeatures .. ' ' .. name
.. ' features but only ' .. #featuresVocabs .. ' dictionaries were found')
end
+ if #featuresVocabs == 0 and case_feature then
+ -- build default case feature
+ _G.logger:info(' * Building default case feature vocabularies...')
+ featuresVocabs[1] = onmt.utils.Dict.new({onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
+ onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD})
+ local regCaseFeat = case.getFeatures()
+ for _, f in ipairs(regCaseFeat) do
+ featuresVocabs[1]:add(f)
+ end
+ end
+
if wordVocab == nil or keepFrequency or (#featuresVocabs == 0 and numFeatures > 0) then
-- If a dictionary is still missing, generate it.
_G.logger:info(' * Building ' .. name .. ' vocabularies...')