diff options
author | Nicholas Leonard <nick@nikopia.org> | 2015-04-29 00:40:07 +0300 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2015-05-05 21:58:49 +0300 |
commit | 3fc211c48b5ad270f374cb18733b0354cc4b9335 (patch) | |
tree | c5084b7f3363721d2b90440c2c1ff979a6f72c1b /ParallelCriterion.lua | |
parent | 59d56be820159c8620062a95182cada6e2384ffc (diff) |
ParallelCriterion
Diffstat (limited to 'ParallelCriterion.lua')
-rw-r--r-- | ParallelCriterion.lua | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua new file mode 100644 index 0000000..bee1f9c --- /dev/null +++ b/ParallelCriterion.lua @@ -0,0 +1,52 @@ +local ParallelCriterion, parent = torch.class('nn.ParallelCriterion', 'nn.Criterion') + +function ParallelCriterion:__init(repeatTarget) + parent.__init(self) + self.criterions = {} + self.weights = {} + self.gradInput = {} + self.repeatTarget = repeatTarget +end + +function ParallelCriterion:add(criterion, weight) + weight = weight or 1 + table.insert(self.criterions, criterion) + table.insert(self.weights, weight) + return self +end + +function ParallelCriterion:updateOutput(input, target) + self.output = 0 + if not self.repeatTarget then + for i,criterion in ipairs(self.criterions) do + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target[i]) + end + else + for i,criterion in ipairs(self.criterions) do + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target) + end + end + return self.output +end + +function ParallelCriterion:updateGradInput(input, target) + if not self.repeatTarget then + for i,criterion in ipairs(self.criterions) do + self.gradInput[i] = input[i].new() or self.gradInput[i] + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target[i])) + end + else + for i,criterion in ipairs(self.criterions) do + self.gradInput[i] = input[i].new() or self.gradInput[i] + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target)) + end + end + return self.gradInput +end + +function ParallelCriterion:type(type) + self.gradInput = {} + return parent.type(self, type) +end |