diff options
author | Guillaume Klein <guillaumekln@users.noreply.github.com> | 2017-10-20 11:57:33 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-10-20 11:57:33 +0300 |
commit | 5496a8928e5ece9503b23337bc52c25f04fa44d6 (patch) | |
tree | 94862d83cb2893721bf116b5c266489ebeeac185 /onmt | |
parent | f661a6077a3cf97c7ba3e765695a037dc03330c0 (diff) |
Add keep rules in file sampling to keep all sentences from training files (#408)
* Add keep rules in file sampling to use a file without sampling
* Print a warning when the sampling size cannot be reached
Diffstat (limited to 'onmt')
-rw-r--r-- | onmt/data/Preprocessor.lua | 125 |
1 files changed, 89 insertions, 36 deletions
diff --git a/onmt/data/Preprocessor.lua b/onmt/data/Preprocessor.lua index cfd77716..0cb89b66 100644 --- a/onmt/data/Preprocessor.lua +++ b/onmt/data/Preprocessor.lua @@ -234,7 +234,7 @@ local function ruleMatch(s, rule) end end -function Preprocessor:parseDirectory(args, datalist, dist_rules, type) +function Preprocessor:parseDirectory(args, datalist, dist_rules, keep_rules, type) local dir = args[type.."_dir"] assert(dir ~= '', 'missing \''..type..'_dir\' parameter') _G.logger:info('Parsing '..type..' data from directory \''..dir..'\':') @@ -265,6 +265,7 @@ function Preprocessor:parseDirectory(args, datalist, dist_rules, type) if error == 0 then local fdesc = { countLines, flist } fdesc.fname = fprefix + fdesc.weight = 0 return _G.__threadid, 0, fdesc else return _G.__threadid, error, errors @@ -299,49 +300,73 @@ function Preprocessor:parseDirectory(args, datalist, dist_rules, type) _G.logger:info(totalCount..' sentences, in '..#list_files..' files, in '..type..' directory') _G.logger:info('') + local keepCount = 0 + + if #keep_rules > 0 then + _G.logger:info('Matching files with keep rules:') + for i = 1, #list_files do + if list_files[i].weight == 0 then + for rule_idx = 1, #keep_rules do + if ruleMatch(list_files[i].fname, keep_rules[rule_idx][1]) then + keepCount = keepCount + list_files[i][1] + list_files[i].rule_idx = rule_idx + list_files[i].weight = math.huge + _G.logger:info(" * file '%s' is covered by the keep rule %d", + list_files[i].fname, list_files[i].rule_idx or 0) + break + end + end + end + end + _G.logger:info('') + + -- Files matched with keep rules are not part of the global sampling. 
+ totalCount = totalCount - keepCount + end + if #dist_rules > 0 then - _G.logger:info('Matching files with sample distribution rules:') + _G.logger:info('Matching files with sample rules:') local weight_norm = 0 local weight_rule = {} + for i = 1, #list_files do - local rule_idx = 1 - while rule_idx <= #dist_rules do - local fname = list_files[i].fname - if ruleMatch(fname, dist_rules[rule_idx][1]) then - list_files[i].rule_idx = rule_idx - if not weight_rule[rule_idx] then - weight_norm = weight_norm + dist_rules[rule_idx][2] - weight_rule[rule_idx] = 0 + if list_files[i].weight == 0 then + for rule_idx = 1, #dist_rules do + if ruleMatch(list_files[i].fname, dist_rules[rule_idx][1]) then + list_files[i].rule_idx = rule_idx + if not weight_rule[rule_idx] then + weight_norm = weight_norm + dist_rules[rule_idx][2] + weight_rule[rule_idx] = 0 + end + weight_rule[rule_idx] = weight_rule[rule_idx] + list_files[i][1] + break end - weight_rule[rule_idx] = weight_rule[rule_idx] + list_files[i][1] - break end - rule_idx = rule_idx + 1 - end - if rule_idx > #dist_rules then - _G.logger:warning(" * file '"..list_files[i].fname.."' is not covered by rules - will not be used") - list_files[i].weight = 0 end end + local sum_weight = 0 for i = 1, #list_files do - if list_files[i].rule_idx then + if list_files[i].rule_idx and list_files[i].weight ~= math.huge then local rule_idx = list_files[i].rule_idx list_files[i].weight = dist_rules[rule_idx][2] / weight_norm * list_files[i][1] / weight_rule[rule_idx] sum_weight = sum_weight + list_files[i].weight end end - -- final normalization of weights + for i = 1, #list_files do - list_files[i].weight = list_files[i].weight / sum_weight - if list_files[i].weight > 0 then - _G.logger:info(" * file '%s' uniform weight: %.3f, (rule: %d) distribution weight: %.3f", - list_files[i].fname, - 100 * list_files[i][1] / totalCount, - list_files[i].rule_idx or 0, - 100 * list_files[i].weight) + if list_files[i].weight ~= math.huge then + 
list_files[i].weight = list_files[i].weight / sum_weight + if list_files[i].weight > 0 then + _G.logger:info(" * file '%s' is covered by the sampling rule %d - uniform weight: %.4f, distribution weight: %.4f", + list_files[i].fname, + list_files[i].rule_idx or 0, + 100 * list_files[i][1] / totalCount, + 100 * list_files[i].weight) + end end end + _G.logger:info('') else for i = 1, #list_files do @@ -349,7 +374,14 @@ function Preprocessor:parseDirectory(args, datalist, dist_rules, type) end end - return totalCount, list_files + for i = 1, #list_files do + if list_files[i].weight == 0 then + _G.logger:warning(" * file '%s' is not covered by any rules and will not be used", + list_files[i].fname) + end + end + + return totalCount, keepCount, list_files end -- helper functions for threading @@ -462,6 +494,7 @@ function Preprocessor:__init(args, dataType) end self.dist_rules = {} + self.keep_rules = {} if args.gsample_dist ~= '' then local f = io.input(args.gsample_dist) while true do @@ -469,7 +502,11 @@ function Preprocessor:__init(args, dataType) if not dist_rule then break end local trule = onmt.utils.String.split(dist_rule, " ") onmt.utils.Error.assert(#trule == 2, "invalid syntax for sample distribution rule: "..dist_rule) - table.insert(self.dist_rules, trule) + if trule[2] == "*" then + table.insert(self.keep_rules, trule) + else + table.insert(self.dist_rules, trule) + end end end -- list and check training files @@ -478,10 +515,11 @@ function Preprocessor:__init(args, dataType) if not args.dry_run then onmt.utils.Error.assert(isempty(self.vocabs) == 0, 'For directory mode, vocabs should be predefined') end - self.totalCount, self.list_train = self:parseDirectory(self.args, Preprocessor.getDataList(self.dataType), self.dist_rules, 'train') + self.totalCount, self.keepCount, self.list_train = self:parseDirectory(self.args, Preprocessor.getDataList(self.dataType), self.dist_rules, self.keep_rules, 'train') else onmt.utils.Error.assert(isempty(self.trains) == 0) 
self.totalCount = onmt.utils.FileReader.countLines(self.args[self.trains[1]], args.idx_files) + self.keepCount = 0 local list_files = { self.args[self.trains[1]] } for i = 2, #self.trains do table.insert(list_files, args[self.trains[i]]) @@ -1012,18 +1050,33 @@ function Preprocessor:makeData(dataset, dicts) local sampledCount = self.args.gsample if sampledCount < 1 then sampledCount = sampledCount * self.totalCount + else + sampledCount = sampledCount - self.keepCount + end + if self.totalCount == 0 and self.keepCount < self.args.gsample then + _G.logger:warning('You requested a sample of %d sentences but no files matched any sampling rules and only %d sentences are selected by keep rules. There could be issues with your distribution rules.', + self.args.gsample, self.keepCount) + end + if sampledCount < 0 then + _G.logger:error('You requested a sample of %d sentences but %d are already reserved by keep rules. You should configure a larger sample or keep less sentences.', + self.args.gsample, self.keepCount) + os.exit(1) end -- check how many sentences per file for _, f in ipairs(self.list_train) do - local n = math.ceil(sampledCount * f.weight) - local t = torch.LongTensor(n) - if n > 0 then - for i = 1, n do - t[i] = torch.random(1, f[1]) + if f.weight == math.huge then + table.insert(sample_file, tds.Vec(torch.range(1, f[1]):totable())) + else + local n = math.ceil(sampledCount * f.weight) + local t = torch.LongTensor(n) + if n > 0 then + for i = 1, n do + t[i] = torch.random(1, f[1]) + end + t = torch.sort(t) end - t = torch.sort(t) + table.insert(sample_file, tds.Vec(t:totable())) end - table.insert(sample_file, tds.Vec(t:totable())) end end |