github.com/torch/nn.git
-rw-r--r--  MultiCriterion.lua     |  7
-rw-r--r--  ParallelCriterion.lua  | 52
-rwxr-xr-x  doc/criterion.md       | 48
-rw-r--r--  init.lua               |  1
-rw-r--r--  test.lua               | 60
5 files changed, 164 insertions, 4 deletions
diff --git a/MultiCriterion.lua b/MultiCriterion.lua
index e83b97e..4ad5a73 100644
--- a/MultiCriterion.lua
+++ b/MultiCriterion.lua
@@ -30,3 +30,10 @@ function MultiCriterion:updateGradInput(input, target)
end
return self.gradInput
end
+
+function MultiCriterion:type(type)
+   for i,criterion in ipairs(self.criterions) do
+      criterion:type(type)
+   end
+   return parent.type(self, type)
+end
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua
new file mode 100644
index 0000000..bee1f9c
--- /dev/null
+++ b/ParallelCriterion.lua
@@ -0,0 +1,52 @@
+local ParallelCriterion, parent = torch.class('nn.ParallelCriterion', 'nn.Criterion')
+
+function ParallelCriterion:__init(repeatTarget)
+   parent.__init(self)
+   self.criterions = {}
+   self.weights = {}
+   self.gradInput = {}
+   self.repeatTarget = repeatTarget
+end
+
+function ParallelCriterion:add(criterion, weight)
+   weight = weight or 1
+   table.insert(self.criterions, criterion)
+   table.insert(self.weights, weight)
+   return self
+end
+
+function ParallelCriterion:updateOutput(input, target)
+   self.output = 0
+   if not self.repeatTarget then
+      for i,criterion in ipairs(self.criterions) do
+         self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target[i])
+      end
+   else
+      for i,criterion in ipairs(self.criterions) do
+         self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target)
+      end
+   end
+   return self.output
+end
+
+function ParallelCriterion:updateGradInput(input, target)
+   if not self.repeatTarget then
+      for i,criterion in ipairs(self.criterions) do
+         -- reuse the gradInput buffer from the previous call when available
+         self.gradInput[i] = self.gradInput[i] or input[i].new()
+         self.gradInput[i]:resizeAs(input[i]):zero()
+         self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target[i]))
+      end
+   else
+      for i,criterion in ipairs(self.criterions) do
+         self.gradInput[i] = self.gradInput[i] or input[i].new()
+         self.gradInput[i]:resizeAs(input[i]):zero()
+         self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target))
+      end
+   end
+   return self.gradInput
+end
+
+function ParallelCriterion:type(type)
+   -- drop cached gradInput buffers so they are reallocated in the new type
+   self.gradInput = {}
+   return parent.type(self, type)
+end
diff --git a/doc/criterion.md b/doc/criterion.md
index 3d3aefd..d0c9366 100755
--- a/doc/criterion.md
+++ b/doc/criterion.md
@@ -20,10 +20,10 @@ target, they compute a gradient according to a given loss function.
* [`L1HingeEmbeddingCriterion`](#nn.L1HingeEmbeddingCriterion): L1 distance between two inputs;
* [`CosineEmbeddingCriterion`](#nn.CosineEmbeddingCriterion): cosine distance between two inputs;
* Miscellaneous criterions:
- * [`MultiCriterion`](#nn.MultiCriterion): a weighted sum of other criterions;
+ * [`MultiCriterion`](#nn.MultiCriterion): a weighted sum of other criterions, each applied to the same input and target;
+ * [`ParallelCriterion`](#nn.ParallelCriterion): a weighted sum of other criterions, each applied to a different input and target;
* [`MarginRankingCriterion`](#nn.MarginRankingCriterion): ranks two inputs;
-
<a name="nn.Criterion"/>
## Criterion ##
@@ -344,10 +344,50 @@ This returns a Criterion which is a weighted sum of other Criterion.
Criterions are added using the method:
```lua
-criterion:add(singleCriterion, weight)
+criterion:add(singleCriterion [, weight])
+```
+
+where `weight` is a scalar (default 1). Each criterion is applied to the same `input` and `target`.
+
+Example:
+
+```lua
+input = torch.rand(2,10)
+target = torch.IntTensor{1,8}
+nll = nn.ClassNLLCriterion()
+nll2 = nn.CrossEntropyCriterion()
+mc = nn.MultiCriterion():add(nll, 0.5):add(nll2)
+output = mc:forward(input, target)
+```
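+
+`MultiCriterion` also propagates type conversions (e.g. `:float()`) to every added criterion, via the `MultiCriterion:type` method added in this commit. A minimal sketch, mirroring the new unit test and continuing the example above:
+
+```lua
+-- convert the container and, through :type(), every added criterion
+mc:float()
+-- inputs and targets must be converted to the same type
+output = mc:forward(input:float(), target:float())
+```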
+
+<a name="nn.ParallelCriterion"/>
+## ParallelCriterion ##
+
+```lua
+criterion = nn.ParallelCriterion([repeatTarget])
```
-where `weight` is a scalar.
+This returns a Criterion which is a weighted sum of other Criterions.
+Criterions are added using the method:
+
+```lua
+criterion:add(singleCriterion [, weight])
+```
+
+where `weight` is a scalar (default 1). The criterion expects both `input` and `target` to be tables.
+Each criterion is applied to the corresponding `input` and `target` elements of these tables.
+However, if `repeatTarget=true`, the same `target` is presented to each criterion (each paired with a different `input`); a `repeatTarget` sketch follows the example below.
+
+Example:
+
+```lua
+input = {torch.rand(2,10), torch.randn(2,10)}
+target = {torch.IntTensor{1,8}, torch.randn(2,10)}
+nll = nn.ClassNLLCriterion()
+mse = nn.MSECriterion()
+pc = nn.ParallelCriterion():add(nll, 0.5):add(mse)
+output = pc:forward(input, target)
+```
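+
+With `repeatTarget=true`, a single `target` is shared by every criterion. A minimal sketch, adapted from the unit test added in this commit:
+
+```lua
+input = {torch.rand(2,10), torch.randn(2,10)}
+target = torch.randn(2,10)  -- one target, presented to each criterion
+mse = nn.MSECriterion()
+pc = nn.ParallelCriterion(true):add(mse, 0.5):add(mse:clone())
+output = pc:forward(input, target)
+```
+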
<a name="nn.HingeEmbeddingCriterion"/>
diff --git a/init.lua b/init.lua
index c2b2996..b1d36db 100644
--- a/init.lua
+++ b/init.lua
@@ -121,6 +121,7 @@ include('L1Penalty.lua')
include('WeightedMSECriterion.lua')
include('BCECriterion.lua')
include('CrossEntropyCriterion.lua')
+include('ParallelCriterion.lua')
include('StochasticGradient.lua')
diff --git a/test.lua b/test.lua
index a28674a..9414a66 100644
--- a/test.lua
+++ b/test.lua
@@ -811,6 +811,66 @@ function nntest.MarginRankingCriterion()
mytester:assert(torch.type(gradInput2[2]) == 'torch.FloatTensor', "MRC:type() error 2")
end
+function nntest.ParallelCriterion()
+ local input = {torch.rand(2,10), torch.randn(2,10)}
+ local target = {torch.IntTensor{1,8}, torch.randn(2,10)}
+ local nll = nn.ClassNLLCriterion()
+ local mse = nn.MSECriterion()
+ local pc = nn.ParallelCriterion():add(nll, 0.5):add(mse)
+ local output = pc:forward(input, target)
+ local output2 = nll:forward(input[1], target[1])/2 + mse:forward(input[2], target[2])
+ mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion forward error")
+ local gradInput = pc:backward(input, target)
+ local gradInput2 = {nll:backward(input[1], target[1]):clone():div(2), mse:backward(input[2], target[2])}
+ mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion backward error 1")
+ mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion backward error 2")
+ -- test type
+ pc:float()
+ gradInput[1], gradInput[2] = gradInput[1]:clone(), gradInput[2]:clone()
+ local input3 = {input[1]:float(), input[2]:float()}
+ local target3 = {target[1]:float(), target[2]:float()}
+ local output3 = pc:forward(input3, target3)
+ local gradInput3 = pc:backward(input3, target3)
+ mytester:assert(math.abs(output3 - output) < 0.00001, "ParallelCriterion forward error type")
+ mytester:assertTensorEq(gradInput[1]:float(), gradInput3[1], 0.000001, "ParallelCriterion backward error 1 type")
+ mytester:assertTensorEq(gradInput[2]:float(), gradInput3[2], 0.000001, "ParallelCriterion backward error 2 type")
+ -- test repeatTarget
+ local input = {torch.rand(2,10), torch.randn(2,10)}
+ local target = torch.randn(2,10)
+ local mse = nn.MSECriterion()
+ local pc = nn.ParallelCriterion(true):add(mse, 0.5):add(mse:clone())
+ local output = pc:forward(input, target)
+ local output2 = mse:forward(input[1], target)/2 + mse:forward(input[2], target)
+ mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion repeatTarget forward error")
+ local gradInput = pc:backward(input, target)
+ local gradInput2 = {mse:backward(input[1], target):clone():div(2), mse:backward(input[2], target)}
+ mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion repeatTarget backward error 1")
+ mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion repeatTarget backward error 2")
+end
+
+function nntest.MultiCriterion()
+ local input = torch.rand(2,10)
+ local target = torch.IntTensor{1,8}
+ local nll = nn.ClassNLLCriterion()
+ local nll2 = nn.CrossEntropyCriterion()
+ local mc = nn.MultiCriterion():add(nll, 0.5):add(nll2)
+ local output = mc:forward(input, target)
+ local output2 = nll:forward(input, target)/2 + nll2:forward(input, target)
+ mytester:assert(math.abs(output2 - output) < 0.00001, "MultiCriterion forward error")
+ local gradInput = mc:backward(input, target)
+ local gradInput2 = nll:backward(input, target):clone():div(2):add(nll2:backward(input, target))
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "MultiCriterion backward error ")
+ -- test type
+ mc:float()
+ gradInput = gradInput:clone()
+ local input3 = input:float()
+ local target3 = target:float()
+ local output3 = mc:forward(input3, target3)
+ local gradInput3 = mc:backward(input3, target3)
+ mytester:assert(math.abs(output3 - output) < 0.00001, "MultiCriterion forward error type")
+ mytester:assertTensorEq(gradInput:float(), gradInput3, 0.000001, "MultiCriterion backward error type")
+end
+
function nntest.WeightedMSECriterion()
local input = torch.rand(10)
local target = input:clone():add(torch.rand(10))