Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-04-29 00:40:07 +0300
committerNicholas Leonard <nick@nikopia.org>2015-05-05 21:58:49 +0300
commit3fc211c48b5ad270f374cb18733b0354cc4b9335 (patch)
treec5084b7f3363721d2b90440c2c1ff979a6f72c1b /ParallelCriterion.lua
parent59d56be820159c8620062a95182cada6e2384ffc (diff)
ParallelCriterion
Diffstat (limited to 'ParallelCriterion.lua')
-rw-r--r--ParallelCriterion.lua52
1 files changed, 52 insertions, 0 deletions
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua
new file mode 100644
index 0000000..bee1f9c
--- /dev/null
+++ b/ParallelCriterion.lua
@@ -0,0 +1,52 @@
+local ParallelCriterion, parent = torch.class('nn.ParallelCriterion', 'nn.Criterion')
+
+function ParallelCriterion:__init(repeatTarget)
+ parent.__init(self)
+ self.criterions = {}
+ self.weights = {}
+ self.gradInput = {}
+ self.repeatTarget = repeatTarget
+end
+
+function ParallelCriterion:add(criterion, weight)
+ weight = weight or 1
+ table.insert(self.criterions, criterion)
+ table.insert(self.weights, weight)
+ return self
+end
+
+function ParallelCriterion:updateOutput(input, target)
+ self.output = 0
+ if not self.repeatTarget then
+ for i,criterion in ipairs(self.criterions) do
+ self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target[i])
+ end
+ else
+ for i,criterion in ipairs(self.criterions) do
+ self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target)
+ end
+ end
+ return self.output
+end
+
+function ParallelCriterion:updateGradInput(input, target)
+ if not self.repeatTarget then
+ for i,criterion in ipairs(self.criterions) do
+ self.gradInput[i] = input[i].new() or self.gradInput[i]
+ self.gradInput[i]:resizeAs(input[i]):zero()
+ self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target[i]))
+ end
+ else
+ for i,criterion in ipairs(self.criterions) do
+ self.gradInput[i] = input[i].new() or self.gradInput[i]
+ self.gradInput[i]:resizeAs(input[i]):zero()
+ self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target))
+ end
+ end
+ return self.gradInput
+end
+
+function ParallelCriterion:type(type)
+ self.gradInput = {}
+ return parent.type(self, type)
+end