diff options
author | Jean Senellart <jean@senellart.com> | 2017-10-31 09:50:43 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-10-31 09:50:43 +0300 |
commit | 7db30090a7ff4055a6062a03b2637ffb8b32374a (patch) | |
tree | 4a992c8feb48e7494d426f6c7d7fb57062031715 /onmt | |
parent | 48aa1fb233c366616b74602c4b7b7f2f3e4d36e9 (diff) |
Placeholders in protected sequences (#416)
Diffstat (limited to 'onmt')
-rw-r--r-- | onmt/data/Preprocessor.lua | 3 | ||||
-rw-r--r-- | onmt/data/Vocabulary.lua | 9 | ||||
-rw-r--r-- | onmt/lm/LM.lua | 3 | ||||
-rw-r--r-- | onmt/tagger/Tagger.lua | 3 | ||||
-rw-r--r-- | onmt/translate/Translator.lua | 3 | ||||
-rw-r--r-- | onmt/utils/Features.lua | 17 | ||||
-rw-r--r-- | onmt/utils/FileReader.lua | 2 | ||||
-rw-r--r-- | onmt/utils/Placeholders.lua | 25 | ||||
-rw-r--r-- | onmt/utils/init.lua | 1 |
9 files changed, 50 insertions, 16 deletions
diff --git a/onmt/data/Preprocessor.lua b/onmt/data/Preprocessor.lua index b8d05a45..63691377 100644 --- a/onmt/data/Preprocessor.lua +++ b/onmt/data/Preprocessor.lua @@ -577,7 +577,8 @@ local function processSentence(n, idx, tokens, parallelCheck, isValid, isInputVe vectors[i]:insert(tokens[i]) else local words, feats = onmt.utils.Features.extract(tokens[i]) - local vec = dicts[i].words:convertToIdx(words, table.unpack(constants[i])) + local vocabs = onmt.utils.Placeholders.norm(words) + local vec = dicts[i].words:convertToIdx(vocabs, table.unpack(constants[i])) local pruned = vec:eq(onmt.Constants.UNK):sum() / vec:size(1) prunedRatio[i] = prunedRatio[i] * (#vectors[i] / (#vectors[i] + 1)) + pruned / (#vectors[i] + 1) diff --git a/onmt/data/Vocabulary.lua b/onmt/data/Vocabulary.lua index a445b58e..fc27baaa 100644 --- a/onmt/data/Vocabulary.lua +++ b/onmt/data/Vocabulary.lua @@ -28,9 +28,12 @@ function Vocabulary.make(filename, validFunc, idxFile) lineId = lineId + 1 if validFunc(sent) then - local words, features, numFeatures + local features, numFeatures + local vocabs local _, err = pcall(function () + local words words, features, numFeatures = onmt.utils.Features.extract(sent) + vocabs = onmt.utils.Placeholders.norm(words) end) if err then @@ -47,8 +50,8 @@ function Vocabulary.make(filename, validFunc, idxFile) 'all sentences must have the same numbers of additional features (' .. filename .. ':' .. lineId .. ')') end - for i = 1, #words do - wordVocab:add(words[i]) + for i = 1, #vocabs do + wordVocab:add(vocabs[i]) for j = 1, numFeatures do featuresVocabs[j]:add(features[j][i]) diff --git a/onmt/lm/LM.lua b/onmt/lm/LM.lua index 461c0574..ec04879b 100644 --- a/onmt/lm/LM.lua +++ b/onmt/lm/LM.lua @@ -61,8 +61,9 @@ function LM:buildInput(tokens) local data = {} local words, features = onmt.utils.Features.extract(tokens) + local vocabs = onmt.utils.Placeholders.norm(words) - data.words = words + data.words = vocabs data.features = features return data diff --git a/onmt/tagger/Tagger.lua b/onmt/tagger/Tagger.lua index d9258cf3..64d152dd 100644 --- a/onmt/tagger/Tagger.lua +++ b/onmt/tagger/Tagger.lua @@ -49,8 +49,9 @@ function Tagger:buildInput(tokens) data.vectors = torch.Tensor(tokens) else local words, features = onmt.utils.Features.extract(tokens) + local vocabs = onmt.utils.Placeholders.norm(words) - data.words = words + data.words = vocabs if #features > 0 then data.features = features diff --git a/onmt/translate/Translator.lua b/onmt/translate/Translator.lua index e064624e..773bdf9f 100644 --- a/onmt/translate/Translator.lua +++ b/onmt/translate/Translator.lua @@ -185,8 +185,9 @@ function Translator:buildInput(tokens) data.vectors = torch.Tensor(tokens) else local words, features = onmt.utils.Features.extract(tokens) + local vocabs = onmt.utils.Placeholders.norm(words) - data.words = words + data.words = vocabs if #features > 0 then data.features = features diff --git a/onmt/utils/Features.lua b/onmt/utils/Features.lua index 5cb8abdd..e33ea34d 100644 --- a/onmt/utils/Features.lua +++ b/onmt/utils/Features.lua @@ -1,32 +1,33 @@ -- tds is lazy loaded. local tds ---[[ Separate words and features (if any). ]] +--[[ Separate words, features (if any) and normalize placeholders. ]] local function extract(tokens) local words = {} local features = {} local numFeatures = nil for t = 1, #tokens do - local field = onmt.utils.String.split(tokens[t], '│') - local word = field[1] + local fields = onmt.utils.String.split(tokens[t], '│') + local word = fields[1] if word:len() > 0 then + table.insert(words, word) if numFeatures == nil then - numFeatures = #field - 1 + numFeatures = #fields - 1 else - assert(#field - 1 == numFeatures, + assert(#fields - 1 == numFeatures, 'all words must have the same number of features') end - if #field > 1 then - for i = 2, #field do + if #fields > 1 then + for i = 2, #fields do if features[i - 1] == nil then features[i - 1] = {} end - table.insert(features[i - 1], field[i]) + table.insert(features[i - 1], fields[i]) end end end diff --git a/onmt/utils/FileReader.lua b/onmt/utils/FileReader.lua index 58206bfe..f8ce6708 100644 --- a/onmt/utils/FileReader.lua +++ b/onmt/utils/FileReader.lua @@ -30,7 +30,7 @@ function FileReader:next(doTokenize) if not self.featSequence then if doTokenize then - for word in line:gmatch'([^%s]+)' do + for word in line:gmatch'([^ ]+)' do table.insert(sent, word) end else diff --git a/onmt/utils/Placeholders.lua b/onmt/utils/Placeholders.lua new file mode 100644 index 00000000..0d8c2c33 --- /dev/null +++ b/onmt/utils/Placeholders.lua @@ -0,0 +1,25 @@ +local function norm(t) + if type(t) == "table" then + local v = {} + local vrep = {} + for _, tokt in ipairs(t) do + local vt, vtrep + vt, vtrep = norm(tokt) + table.insert(v, vt) + table.insert(vrep, vtrep) + end + return v, vrep + end + if t:sub(1, string.len('⦅')) == '⦅' then + local p = t:find('⦆') + assert(p, 'invalid placeholder tag: '..t) + local tcontent = t:sub(string.len('⦅')+1, p-1) + local fields = onmt.utils.String.split(tcontent, ':') + return '⦅'..fields[1]..t:sub(p), fields[2] or fields[1] + end + return t +end + +return { + norm = norm +} diff --git a/onmt/utils/init.lua b/onmt/utils/init.lua index 5dd2e674..3bb2ea05 100644 --- a/onmt/utils/init.lua +++ b/onmt/utils/init.lua @@ -11,6 +11,7 @@ utils.Memory = require('onmt.utils.Memory') utils.MemoryOptimizer = require('onmt.utils.MemoryOptimizer') utils.Parallel = require('onmt.utils.Parallel') utils.Features = require('onmt.utils.Features') +utils.Placeholders = require('onmt.utils.Placeholders') utils.Logger = require('onmt.utils.Logger') utils.Profiler = require('onmt.utils.Profiler') utils.ExtendedCmdLine = require('onmt.utils.ExtendedCmdLine') |