diff options
-rw-r--r--  MultiCriterion.lua     |  7
-rw-r--r--  ParallelCriterion.lua  | 52
-rwxr-xr-x  doc/criterion.md       | 48
-rw-r--r--  init.lua               |  1
-rw-r--r--  test.lua               | 60
5 files changed, 164 insertions, 4 deletions
diff --git a/MultiCriterion.lua b/MultiCriterion.lua index e83b97e..4ad5a73 100644 --- a/MultiCriterion.lua +++ b/MultiCriterion.lua @@ -30,3 +30,10 @@ function MultiCriterion:updateGradInput(input, target) end return self.gradInput end + +function MultiCriterion:type(type) + for i,criterion in ipairs(self.criterions) do + criterion:type(type) + end + return parent.type(self, type) +end diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua new file mode 100644 index 0000000..bee1f9c --- /dev/null +++ b/ParallelCriterion.lua @@ -0,0 +1,52 @@ +local ParallelCriterion, parent = torch.class('nn.ParallelCriterion', 'nn.Criterion') + +function ParallelCriterion:__init(repeatTarget) + parent.__init(self) + self.criterions = {} + self.weights = {} + self.gradInput = {} + self.repeatTarget = repeatTarget +end + +function ParallelCriterion:add(criterion, weight) + weight = weight or 1 + table.insert(self.criterions, criterion) + table.insert(self.weights, weight) + return self +end + +function ParallelCriterion:updateOutput(input, target) + self.output = 0 + if not self.repeatTarget then + for i,criterion in ipairs(self.criterions) do + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target[i]) + end + else + for i,criterion in ipairs(self.criterions) do + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target) + end + end + return self.output +end + +function ParallelCriterion:updateGradInput(input, target) + if not self.repeatTarget then + for i,criterion in ipairs(self.criterions) do + self.gradInput[i] = input[i].new() or self.gradInput[i] + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target[i])) + end + else + for i,criterion in ipairs(self.criterions) do + self.gradInput[i] = input[i].new() or self.gradInput[i] + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], 
criterion:updateGradInput(input[i],target)) + end + end + return self.gradInput +end + +function ParallelCriterion:type(type) + self.gradInput = {} + return parent.type(self, type) +end diff --git a/doc/criterion.md b/doc/criterion.md index 3d3aefd..d0c9366 100755 --- a/doc/criterion.md +++ b/doc/criterion.md @@ -20,10 +20,10 @@ target, they compute a gradient according to a given loss function. * [`L1HingeEmbeddingCriterion`](#nn.L1HingeEmbeddingCriterion): L1 distance between two inputs; * [`CosineEmbeddingCriterion`](#nn.CosineEmbeddingCriterion): cosine distance between two inputs; * Miscelaneus criterions: - * [`MultiCriterion`](#nn.MultiCriterion): a weighted sum of other criterions; + * [`MultiCriterion`](#nn.MultiCriterion) : a weighted sum of other criterions each applied to the same input and target; + * [`ParallelCriterion`](#nn.ParallelCriterion) : a weighted sum of other criterions each applied to a different input and target; * [`MarginRankingCriterion`](#nn.MarginRankingCriterion): ranks two inputs; - <a name="nn.Criterion"/> ## Criterion ## @@ -344,10 +344,50 @@ This returns a Criterion which is a weighted sum of other Criterion. Criterions are added using the method: ```lua -criterion:add(singleCriterion, weight) +criterion:add(singleCriterion [, weight]) +``` + +where `weight` is a scalar (default 1). Each criterion is applied to the same `input` and `target`. + +Example : + +```lua +input = torch.rand(2,10) +target = torch.IntTensor{1,8} +nll = nn.ClassNLLCriterion() +nll2 = nn.CrossEntropyCriterion() +mc = nn.MultiCriterion():add(nll, 0.5):add(nll2) +output = mc:forward(input, target) +``` + +<a name="nn.ParallelCriterion"/> +## ParallelCriterion ## + +```lua +criterion = nn.ParallelCriterion([repeatTarget]) ``` -where `weight` is a scalar. +This returns a Criterion which is a weighted sum of other Criterion. 
+Criterions are added using the method: + +```lua +criterion:add(singleCriterion [, weight]) +``` + +where `weight` is a scalar (default 1). The criterion expects an `input` and `target` table. +Each criterion is applied to the commensurate `input` and `target` element in the tables. +However, if `repeatTarget=true`, the `target` is repeatedly presented to each criterion (with a different `input`). + +Example : + +```lua +input = {torch.rand(2,10), torch.randn(2,10)} +target = {torch.IntTensor{1,8}, torch.randn(2,10)} +nll = nn.ClassNLLCriterion() +mse = nn.MSECriterion() +pc = nn.ParallelCriterion():add(nll, 0.5):add(mse) +output = pc:forward(input, target) +``` <a name="nn.HingeEmbeddingCriterion"/> @@ -121,6 +121,7 @@ include('L1Penalty.lua') include('WeightedMSECriterion.lua') include('BCECriterion.lua') include('CrossEntropyCriterion.lua') +include('ParallelCriterion.lua') include('StochasticGradient.lua') @@ -811,6 +811,66 @@ function nntest.MarginRankingCriterion() mytester:assert(torch.type(gradInput2[2]) == 'torch.FloatTensor', "MRC:type() error 2") end +function nntest.ParallelCriterion() + local input = {torch.rand(2,10), torch.randn(2,10)} + local target = {torch.IntTensor{1,8}, torch.randn(2,10)} + local nll = nn.ClassNLLCriterion() + local mse = nn.MSECriterion() + local pc = nn.ParallelCriterion():add(nll, 0.5):add(mse) + local output = pc:forward(input, target) + local output2 = nll:forward(input[1], target[1])/2 + mse:forward(input[2], target[2]) + mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion forward error") + local gradInput = pc:backward(input, target) + local gradInput2 = {nll:backward(input[1], target[1]):clone():div(2), mse:backward(input[2], target[2])} + mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion backward error 1") + mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion backward error 2") + -- test type + pc:float() + gradInput[1], gradInput[2] = 
gradInput[1]:clone(), gradInput[2]:clone() + local input3 = {input[1]:float(), input[2]:float()} + local target3 = {target[1]:float(), target[2]:float()} + local output3 = pc:forward(input3, target3) + local gradInput3 = pc:backward(input3, target3) + mytester:assert(math.abs(output3 - output) < 0.00001, "ParallelCriterion forward error type") + mytester:assertTensorEq(gradInput[1]:float(), gradInput3[1], 0.000001, "ParallelCriterion backward error 1 type") + mytester:assertTensorEq(gradInput[2]:float(), gradInput3[2], 0.000001, "ParallelCriterion backward error 2 type") + -- test repeatTarget + local input = {torch.rand(2,10), torch.randn(2,10)} + local target = torch.randn(2,10) + local mse = nn.MSECriterion() + local pc = nn.ParallelCriterion(true):add(mse, 0.5):add(mse:clone()) + local output = pc:forward(input, target) + local output2 = mse:forward(input[1], target)/2 + mse:forward(input[2], target) + mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion repeatTarget forward error") + local gradInput = pc:backward(input, target) + local gradInput2 = {mse:backward(input[1], target):clone():div(2), mse:backward(input[2], target)} + mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion repeatTarget backward error 1") + mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion repeatTarget backward error 2") +end + +function nntest.MultiCriterion() + local input = torch.rand(2,10) + local target = torch.IntTensor{1,8} + local nll = nn.ClassNLLCriterion() + local nll2 = nn.CrossEntropyCriterion() + local mc = nn.MultiCriterion():add(nll, 0.5):add(nll2) + local output = mc:forward(input, target) + local output2 = nll:forward(input, target)/2 + nll2:forward(input, target) + mytester:assert(math.abs(output2 - output) < 0.00001, "MultiCriterion forward error") + local gradInput = mc:backward(input, target) + local gradInput2 = nll:backward(input, target):clone():div(2):add(nll2:backward(input, 
target)) + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "MultiCriterion backward error ") + -- test type + mc:float() + gradInput = gradInput:clone() + local input3 = input:float() + local target3 = target:float() + local output3 = mc:forward(input3, target3) + local gradInput3 = mc:backward(input3, target3) + mytester:assert(math.abs(output3 - output) < 0.00001, "MultiCriterion forward error type") + mytester:assertTensorEq(gradInput:float(), gradInput3, 0.000001, "MultiCriterion backward error type") +end + function nntest.WeightedMSECriterion() local input = torch.rand(10) local target = input:clone():add(torch.rand(10)) |