Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/onmt
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2017-10-20 11:57:33 +0300
committerGitHub <noreply@github.com>2017-10-20 11:57:33 +0300
commit5496a8928e5ece9503b23337bc52c25f04fa44d6 (patch)
tree94862d83cb2893721bf116b5c266489ebeeac185 /onmt
parentf661a6077a3cf97c7ba3e765695a037dc03330c0 (diff)
Add keep rules in file sampling to keep all sentences from training files (#408)
* Add keep rules in file sampling to use a file without sampling * Print a warning when the sampling size cannot be reached
Diffstat (limited to 'onmt')
-rw-r--r--onmt/data/Preprocessor.lua125
1 file changed, 89 insertions, 36 deletions
diff --git a/onmt/data/Preprocessor.lua b/onmt/data/Preprocessor.lua
index cfd77716..0cb89b66 100644
--- a/onmt/data/Preprocessor.lua
+++ b/onmt/data/Preprocessor.lua
@@ -234,7 +234,7 @@ local function ruleMatch(s, rule)
end
end
-function Preprocessor:parseDirectory(args, datalist, dist_rules, type)
+function Preprocessor:parseDirectory(args, datalist, dist_rules, keep_rules, type)
local dir = args[type.."_dir"]
assert(dir ~= '', 'missing \''..type..'_dir\' parameter')
_G.logger:info('Parsing '..type..' data from directory \''..dir..'\':')
@@ -265,6 +265,7 @@ function Preprocessor:parseDirectory(args, datalist, dist_rules, type)
if error == 0 then
local fdesc = { countLines, flist }
fdesc.fname = fprefix
+ fdesc.weight = 0
return _G.__threadid, 0, fdesc
else
return _G.__threadid, error, errors
@@ -299,49 +300,73 @@ function Preprocessor:parseDirectory(args, datalist, dist_rules, type)
_G.logger:info(totalCount..' sentences, in '..#list_files..' files, in '..type..' directory')
_G.logger:info('')
+ local keepCount = 0
+
+ if #keep_rules > 0 then
+ _G.logger:info('Matching files with keep rules:')
+ for i = 1, #list_files do
+ if list_files[i].weight == 0 then
+ for rule_idx = 1, #keep_rules do
+ if ruleMatch(list_files[i].fname, keep_rules[rule_idx][1]) then
+ keepCount = keepCount + list_files[i][1]
+ list_files[i].rule_idx = rule_idx
+ list_files[i].weight = math.huge
+ _G.logger:info(" * file '%s' is covered by the keep rule %d",
+ list_files[i].fname, list_files[i].rule_idx or 0)
+ break
+ end
+ end
+ end
+ end
+ _G.logger:info('')
+
+ -- Files matched with keep rules are not part of the global sampling.
+ totalCount = totalCount - keepCount
+ end
+
if #dist_rules > 0 then
- _G.logger:info('Matching files with sample distribution rules:')
+ _G.logger:info('Matching files with sample rules:')
local weight_norm = 0
local weight_rule = {}
+
for i = 1, #list_files do
- local rule_idx = 1
- while rule_idx <= #dist_rules do
- local fname = list_files[i].fname
- if ruleMatch(fname, dist_rules[rule_idx][1]) then
- list_files[i].rule_idx = rule_idx
- if not weight_rule[rule_idx] then
- weight_norm = weight_norm + dist_rules[rule_idx][2]
- weight_rule[rule_idx] = 0
+ if list_files[i].weight == 0 then
+ for rule_idx = 1, #dist_rules do
+ if ruleMatch(list_files[i].fname, dist_rules[rule_idx][1]) then
+ list_files[i].rule_idx = rule_idx
+ if not weight_rule[rule_idx] then
+ weight_norm = weight_norm + dist_rules[rule_idx][2]
+ weight_rule[rule_idx] = 0
+ end
+ weight_rule[rule_idx] = weight_rule[rule_idx] + list_files[i][1]
+ break
end
- weight_rule[rule_idx] = weight_rule[rule_idx] + list_files[i][1]
- break
end
- rule_idx = rule_idx + 1
- end
- if rule_idx > #dist_rules then
- _G.logger:warning(" * file '"..list_files[i].fname.."' is not covered by rules - will not be used")
- list_files[i].weight = 0
end
end
+
local sum_weight = 0
for i = 1, #list_files do
- if list_files[i].rule_idx then
+ if list_files[i].rule_idx and list_files[i].weight ~= math.huge then
local rule_idx = list_files[i].rule_idx
list_files[i].weight = dist_rules[rule_idx][2] / weight_norm * list_files[i][1] / weight_rule[rule_idx]
sum_weight = sum_weight + list_files[i].weight
end
end
- -- final normalization of weights
+
for i = 1, #list_files do
- list_files[i].weight = list_files[i].weight / sum_weight
- if list_files[i].weight > 0 then
- _G.logger:info(" * file '%s' uniform weight: %.3f, (rule: %d) distribution weight: %.3f",
- list_files[i].fname,
- 100 * list_files[i][1] / totalCount,
- list_files[i].rule_idx or 0,
- 100 * list_files[i].weight)
+ if list_files[i].weight ~= math.huge then
+ list_files[i].weight = list_files[i].weight / sum_weight
+ if list_files[i].weight > 0 then
+ _G.logger:info(" * file '%s' is covered by the sampling rule %d - uniform weight: %.4f, distribution weight: %.4f",
+ list_files[i].fname,
+ list_files[i].rule_idx or 0,
+ 100 * list_files[i][1] / totalCount,
+ 100 * list_files[i].weight)
+ end
end
end
+
_G.logger:info('')
else
for i = 1, #list_files do
@@ -349,7 +374,14 @@ function Preprocessor:parseDirectory(args, datalist, dist_rules, type)
end
end
- return totalCount, list_files
+ for i = 1, #list_files do
+ if list_files[i].weight == 0 then
+ _G.logger:warning(" * file '%s' is not covered by any rules and will not be used",
+ list_files[i].fname)
+ end
+ end
+
+ return totalCount, keepCount, list_files
end
-- helper functions for threading
@@ -462,6 +494,7 @@ function Preprocessor:__init(args, dataType)
end
self.dist_rules = {}
+ self.keep_rules = {}
if args.gsample_dist ~= '' then
local f = io.input(args.gsample_dist)
while true do
@@ -469,7 +502,11 @@ function Preprocessor:__init(args, dataType)
if not dist_rule then break end
local trule = onmt.utils.String.split(dist_rule, " ")
onmt.utils.Error.assert(#trule == 2, "invalid syntax for sample distribution rule: "..dist_rule)
- table.insert(self.dist_rules, trule)
+ if trule[2] == "*" then
+ table.insert(self.keep_rules, trule)
+ else
+ table.insert(self.dist_rules, trule)
+ end
end
end
-- list and check training files
@@ -478,10 +515,11 @@ function Preprocessor:__init(args, dataType)
if not args.dry_run then
onmt.utils.Error.assert(isempty(self.vocabs) == 0, 'For directory mode, vocabs should be predefined')
end
- self.totalCount, self.list_train = self:parseDirectory(self.args, Preprocessor.getDataList(self.dataType), self.dist_rules, 'train')
+ self.totalCount, self.keepCount, self.list_train = self:parseDirectory(self.args, Preprocessor.getDataList(self.dataType), self.dist_rules, self.keep_rules, 'train')
else
onmt.utils.Error.assert(isempty(self.trains) == 0)
self.totalCount = onmt.utils.FileReader.countLines(self.args[self.trains[1]], args.idx_files)
+ self.keepCount = 0
local list_files = { self.args[self.trains[1]] }
for i = 2, #self.trains do
table.insert(list_files, args[self.trains[i]])
@@ -1012,18 +1050,33 @@ function Preprocessor:makeData(dataset, dicts)
local sampledCount = self.args.gsample
if sampledCount < 1 then
sampledCount = sampledCount * self.totalCount
+ else
+ sampledCount = sampledCount - self.keepCount
+ end
+ if self.totalCount == 0 and self.keepCount < self.args.gsample then
+ _G.logger:warning('You requested a sample of %d sentences but no files matched any sampling rules and only %d sentences are selected by keep rules. There could be issues with your distribution rules.',
+ self.args.gsample, self.keepCount)
+ end
+ if sampledCount < 0 then
+ _G.logger:error('You requested a sample of %d sentences but %d are already reserved by keep rules. You should configure a larger sample or keep less sentences.',
+ self.args.gsample, self.keepCount)
+ os.exit(1)
end
-- check how many sentences per file
for _, f in ipairs(self.list_train) do
- local n = math.ceil(sampledCount * f.weight)
- local t = torch.LongTensor(n)
- if n > 0 then
- for i = 1, n do
- t[i] = torch.random(1, f[1])
+ if f.weight == math.huge then
+ table.insert(sample_file, tds.Vec(torch.range(1, f[1]):totable()))
+ else
+ local n = math.ceil(sampledCount * f.weight)
+ local t = torch.LongTensor(n)
+ if n > 0 then
+ for i = 1, n do
+ t[i] = torch.random(1, f[1])
+ end
+ t = torch.sort(t)
end
- t = torch.sort(t)
+ table.insert(sample_file, tds.Vec(t:totable()))
end
- table.insert(sample_file, tds.Vec(t:totable()))
end
end