Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/onmt
diff options
context:
space:
mode:
author: Jean Senellart <jean@senellart.com> 2017-10-31 09:50:43 +0300
committer: GitHub <noreply@github.com> 2017-10-31 09:50:43 +0300
commit: 7db30090a7ff4055a6062a03b2637ffb8b32374a (patch)
tree: 4a992c8feb48e7494d426f6c7d7fb57062031715 /onmt
parent: 48aa1fb233c366616b74602c4b7b7f2f3e4d36e9 (diff)
Placeholders in protected sequences (#416)
Diffstat (limited to 'onmt')
-rw-r--r-- onmt/data/Preprocessor.lua | 3
-rw-r--r-- onmt/data/Vocabulary.lua | 9
-rw-r--r-- onmt/lm/LM.lua | 3
-rw-r--r-- onmt/tagger/Tagger.lua | 3
-rw-r--r-- onmt/translate/Translator.lua | 3
-rw-r--r-- onmt/utils/Features.lua | 17
-rw-r--r-- onmt/utils/FileReader.lua | 2
-rw-r--r-- onmt/utils/Placeholders.lua | 25
-rw-r--r-- onmt/utils/init.lua | 1
9 files changed, 50 insertions, 16 deletions
diff --git a/onmt/data/Preprocessor.lua b/onmt/data/Preprocessor.lua
index b8d05a45..63691377 100644
--- a/onmt/data/Preprocessor.lua
+++ b/onmt/data/Preprocessor.lua
@@ -577,7 +577,8 @@ local function processSentence(n, idx, tokens, parallelCheck, isValid, isInputVe
vectors[i]:insert(tokens[i])
else
local words, feats = onmt.utils.Features.extract(tokens[i])
- local vec = dicts[i].words:convertToIdx(words, table.unpack(constants[i]))
+ local vocabs = onmt.utils.Placeholders.norm(words)
+ local vec = dicts[i].words:convertToIdx(vocabs, table.unpack(constants[i]))
local pruned = vec:eq(onmt.Constants.UNK):sum() / vec:size(1)
prunedRatio[i] = prunedRatio[i] * (#vectors[i] / (#vectors[i] + 1)) + pruned / (#vectors[i] + 1)
diff --git a/onmt/data/Vocabulary.lua b/onmt/data/Vocabulary.lua
index a445b58e..fc27baaa 100644
--- a/onmt/data/Vocabulary.lua
+++ b/onmt/data/Vocabulary.lua
@@ -28,9 +28,12 @@ function Vocabulary.make(filename, validFunc, idxFile)
lineId = lineId + 1
if validFunc(sent) then
- local words, features, numFeatures
+ local features, numFeatures
+ local vocabs
local _, err = pcall(function ()
+ local words
words, features, numFeatures = onmt.utils.Features.extract(sent)
+ vocabs = onmt.utils.Placeholders.norm(words)
end)
if err then
@@ -47,8 +50,8 @@ function Vocabulary.make(filename, validFunc, idxFile)
'all sentences must have the same numbers of additional features (' .. filename .. ':' .. lineId .. ')')
end
- for i = 1, #words do
- wordVocab:add(words[i])
+ for i = 1, #vocabs do
+ wordVocab:add(vocabs[i])
for j = 1, numFeatures do
featuresVocabs[j]:add(features[j][i])
diff --git a/onmt/lm/LM.lua b/onmt/lm/LM.lua
index 461c0574..ec04879b 100644
--- a/onmt/lm/LM.lua
+++ b/onmt/lm/LM.lua
@@ -61,8 +61,9 @@ function LM:buildInput(tokens)
local data = {}
local words, features = onmt.utils.Features.extract(tokens)
+ local vocabs = onmt.utils.Placeholders.norm(words)
- data.words = words
+ data.words = vocabs
data.features = features
return data
diff --git a/onmt/tagger/Tagger.lua b/onmt/tagger/Tagger.lua
index d9258cf3..64d152dd 100644
--- a/onmt/tagger/Tagger.lua
+++ b/onmt/tagger/Tagger.lua
@@ -49,8 +49,9 @@ function Tagger:buildInput(tokens)
data.vectors = torch.Tensor(tokens)
else
local words, features = onmt.utils.Features.extract(tokens)
+ local vocabs = onmt.utils.Placeholders.norm(words)
- data.words = words
+ data.words = vocabs
if #features > 0 then
data.features = features
diff --git a/onmt/translate/Translator.lua b/onmt/translate/Translator.lua
index e064624e..773bdf9f 100644
--- a/onmt/translate/Translator.lua
+++ b/onmt/translate/Translator.lua
@@ -185,8 +185,9 @@ function Translator:buildInput(tokens)
data.vectors = torch.Tensor(tokens)
else
local words, features = onmt.utils.Features.extract(tokens)
+ local vocabs = onmt.utils.Placeholders.norm(words)
- data.words = words
+ data.words = vocabs
if #features > 0 then
data.features = features
diff --git a/onmt/utils/Features.lua b/onmt/utils/Features.lua
index 5cb8abdd..e33ea34d 100644
--- a/onmt/utils/Features.lua
+++ b/onmt/utils/Features.lua
@@ -1,32 +1,33 @@
-- tds is lazy loaded.
local tds
---[[ Separate words and features (if any). ]]
+--[[ Separate words, features (if any) and normalize placeholders. ]]
local function extract(tokens)
local words = {}
local features = {}
local numFeatures = nil
for t = 1, #tokens do
- local field = onmt.utils.String.split(tokens[t], '│')
- local word = field[1]
+ local fields = onmt.utils.String.split(tokens[t], '│')
+ local word = fields[1]
if word:len() > 0 then
+
table.insert(words, word)
if numFeatures == nil then
- numFeatures = #field - 1
+ numFeatures = #fields - 1
else
- assert(#field - 1 == numFeatures,
+ assert(#fields - 1 == numFeatures,
'all words must have the same number of features')
end
- if #field > 1 then
- for i = 2, #field do
+ if #fields > 1 then
+ for i = 2, #fields do
if features[i - 1] == nil then
features[i - 1] = {}
end
- table.insert(features[i - 1], field[i])
+ table.insert(features[i - 1], fields[i])
end
end
end
diff --git a/onmt/utils/FileReader.lua b/onmt/utils/FileReader.lua
index 58206bfe..f8ce6708 100644
--- a/onmt/utils/FileReader.lua
+++ b/onmt/utils/FileReader.lua
@@ -30,7 +30,7 @@ function FileReader:next(doTokenize)
if not self.featSequence then
if doTokenize then
- for word in line:gmatch'([^%s]+)' do
+ for word in line:gmatch'([^ ]+)' do
table.insert(sent, word)
end
else
diff --git a/onmt/utils/Placeholders.lua b/onmt/utils/Placeholders.lua
new file mode 100644
index 00000000..0d8c2c33
--- /dev/null
+++ b/onmt/utils/Placeholders.lua
@@ -0,0 +1,25 @@
+local function norm(t)
+ if type(t) == "table" then
+ local v = {}
+ local vrep = {}
+ for _, tokt in ipairs(t) do
+ local vt, vtrep
+ vt, vtrep = norm(tokt)
+ table.insert(v, vt)
+ table.insert(vrep, vtrep)
+ end
+ return v, vrep
+ end
+ if t:sub(1, string.len('⦅')) == '⦅' then
+ local p = t:find('⦆')
+ assert(p, 'invalid placeholder tag: '..t)
+ local tcontent = t:sub(string.len('⦅')+1, p-1)
+ local fields = onmt.utils.String.split(tcontent, ':')
+ return '⦅'..fields[1]..t:sub(p), fields[2] or fields[1]
+ end
+ return t
+end
+
+return {
+ norm = norm
+}
diff --git a/onmt/utils/init.lua b/onmt/utils/init.lua
index 5dd2e674..3bb2ea05 100644
--- a/onmt/utils/init.lua
+++ b/onmt/utils/init.lua
@@ -11,6 +11,7 @@ utils.Memory = require('onmt.utils.Memory')
utils.MemoryOptimizer = require('onmt.utils.MemoryOptimizer')
utils.Parallel = require('onmt.utils.Parallel')
utils.Features = require('onmt.utils.Features')
+utils.Placeholders = require('onmt.utils.Placeholders')
utils.Logger = require('onmt.utils.Logger')
utils.Profiler = require('onmt.utils.Profiler')
utils.ExtendedCmdLine = require('onmt.utils.ExtendedCmdLine')