Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-09-14 08:31:33 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-09-14 08:31:33 +0400
commit937f52bdd070f59db90720c5993b28abd44f6aac (patch)
tree310f10457ad5237f8f79e2634fb2b0908729071b /SuperCriterion.lua
parentd6da0e6b62f24d0581255751d3ee22e6f3765035 (diff)
SuperCriterion accepts multiple targets.
Diffstat (limited to 'SuperCriterion.lua')
-rw-r--r--SuperCriterion.lua26
1 files changed, 20 insertions, 6 deletions
diff --git a/SuperCriterion.lua b/SuperCriterion.lua
index ce429e2..c73d716 100644
--- a/SuperCriterion.lua
+++ b/SuperCriterion.lua
@@ -15,17 +15,31 @@ end
function SuperCriterion:forward(input, target)
self.output = 0
- for i,criterion in ipairs(self.criterions) do
- self.output = self.output + self.weights[i]*criterion:forward(input[i],target)
+ if type(target) == 'table' then
+ for i,criterion in ipairs(self.criterions) do
+ self.output = self.output + self.weights[i]*criterion:forward(input[i],target[i])
+ end
+ else
+ for i,criterion in ipairs(self.criterions) do
+ self.output = self.output + self.weights[i]*criterion:forward(input[i],target)
+ end
end
return self.output
end
function SuperCriterion:backward(input, target)
- for i,criterion in ipairs(self.criterions) do
- self.gradInput[i] = torch.Tensor() or self.gradInput[i]
- self.gradInput[i]:resizeAs(input[i]):zero()
- self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target) )
+ if type(target) == 'table' then
+ for i,criterion in ipairs(self.criterions) do
+ self.gradInput[i] = torch.Tensor() or self.gradInput[i]
+ self.gradInput[i]:resizeAs(input[i]):zero()
+ self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target[i]))
+ end
+ else
+ for i,criterion in ipairs(self.criterions) do
+ self.gradInput[i] = torch.Tensor() or self.gradInput[i]
+ self.gradInput[i]:resizeAs(input[i]):zero()
+ self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target))
+ end
end
return self.gradInput
end