Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'MultiCriterion.lua')
-rw-r--r--MultiCriterion.lua32
1 files changed, 32 insertions, 0 deletions
diff --git a/MultiCriterion.lua b/MultiCriterion.lua
new file mode 100644
index 0000000..e83b97e
--- /dev/null
+++ b/MultiCriterion.lua
@@ -0,0 +1,32 @@
+local MultiCriterion, parent = torch.class('nn.MultiCriterion', 'nn.Criterion')
+
+function MultiCriterion:__init()
+ parent.__init(self)
+ self.criterions = {}
+ self.weights = torch.DoubleStorage()
+end
+
+function MultiCriterion:add(criterion, weight)
+ weight = weight or 1
+ table.insert(self.criterions, criterion)
+ self.weights:resize(#self.criterions, true)
+ self.weights[#self.criterions] = weight
+ return self
+end
+
+function MultiCriterion:updateOutput(input, target)
+ self.output = 0
+ for i=1,#self.criterions do
+ self.output = self.output + self.weights[i]*self.criterions[i]:updateOutput(input, target)
+ end
+ return self.output
+end
+
+function MultiCriterion:updateGradInput(input, target)
+ self.gradInput:resizeAs(input)
+ self.gradInput:zero()
+ for i=1,#self.criterions do
+ self.gradInput:add(self.weights[i], self.criterions[i]:updateGradInput(input, target))
+ end
+ return self.gradInput
+end